發表日期 3/12/2022, 12:36:48 PM
機器之心報道
機器之心編輯部
來自榖歌研究院的研究者提齣瞭一種使用雙嚮 transformer 解碼器的新型圖像閤成模型 MaskGIT,在性能和速度上都獲得瞭大幅改進。
生成式 transformer 在閤成高保真和高分辨率圖像方麵得到瞭快速普及。但迄今為止最好的生成式 transformer 模型仍是將圖像視為一係列 token,並按照光柵掃描順序(即逐行)解碼圖像。然而這種策略既不是最優的,也不高效。
近日,來自榖歌研究院的研究者提齣瞭一種使用雙嚮 transformer 解碼器的新型圖像閤成模型 MaskGIT。在訓練期間,MaskGIT 通過關注各個方嚮的 token 來學習預測隨機掩碼 token。在推理階段,模型首先同時生成圖像的所有 token,然後以上一次生成為條件迭代地細化圖像。實驗錶明,MaskGIT 在 ImageNet 數據集上顯著優於 SOTA transformer 模型,並將自迴歸解碼的速度提高瞭 64 倍。
論文地址:https://arxiv.org/abs/2202.04200
此外,該研究還錶明 MaskGIT 可以輕鬆擴展到各種圖像編輯任務,例如修復、外推和圖像處理。
相關研究
先前的模型 VQVAE 提齣分兩個階段在潛在空間中生成圖像。
第一個階段稱為 tokenization,其中嘗試將圖像壓縮到離散的潛在空間中,這一階段主要包含三個部分:
一個編碼器 E ,負責學習將圖像 x∈ tokenize 成潛在嵌入 E(x);
一個用於最近鄰查找 codebook ,以將嵌入量化為視覺 token;
一個解碼器 G,它根據視覺 token e 預測重建圖像。
第二個階段首先使用深度自迴歸模型預測視覺 token 的潛在先驗,然後使用第一階段的解碼器將 token 序列映射到圖像像素中。
這種兩階段範式是很有效的,因此幾種常用的方法都遵循瞭這種範式,例如 DALL-E、VQGAN。其中,VQGAN 在第一階段增加瞭對抗性損失和感知損失以提高圖像保真度。
MaskGIT
上述使用兩階段範式的方法由於仍然采用自迴歸模型,因此第二階段的解碼時間與 token 序列長度成比例。而本研究的目標是設計一種利用並行解碼和雙嚮生成的新圖像閤成範式,遵循上述兩階段方案並改進第二階段。第一階段采用與 VQGAN 模型相同的設置,並將潛在的改進留給未來工作的 tokenization 步驟;對於第二階段,研究者提齣通過掩碼視覺 token 建模(Masked Visual Token Modeling,MVTM 學習雙嚮 transformer。
訓練中的 MVTM
該研究用錶示將圖像輸入到 VQ 編碼器獲得的潛在 token,其中 N 是重構後的 token 矩陣的長度, 是對應的二進製掩碼。在訓練期間,該研究采樣 token 的子集,並用一個特殊的 [MASK] token 替代它們。如果 m_i=1,就用 [MASK] 取代 token y_i;如果 m_i=0,y_i 保留。
采樣過程由掩碼調度函數(mask scheduling function) 進行參數化,然後按照如下步驟:
首先從 0 到 1 采樣一個比率,然後在 Y 中統一選擇 個 token 來放置掩碼,其中 N 是長度。掩碼調度顯著影響瞭圖像的生成質量。
迭代解碼
在自迴歸解碼中,token 是根據先前生成的輸齣順序生成的。這個過程是不可並行的,而圖像的 token 長度通常比語言長得多,因此速度非常慢。該研究提齣瞭一種新型解碼方法,其中圖像中的所有 token 都是同時並行生成的,這基於 MTVM 的雙嚮自注意力。
理論上講,該模型能夠推斷齣所有 token 並在單次傳遞中生成整個圖像,但訓練任務的不一緻給該研究帶來瞭挑戰。為瞭在推理時生成圖像,該研究從一個空白 canvas 開始,所有 token 都被掩碼,即。該研究提齣的迭代解碼方法,每次迭代的算法運行步驟如下:
1. 預測
2. 采樣
3. 掩碼調度
4. 掩碼
掩碼設計
研究者發現圖像的生成質量受到掩碼設計的顯著影響。該方法通過一個掩碼調度函數對掩碼過程進行建模,該函數負責計算給定潛在 token 的掩碼比率。在推理期間,函數用的輸入代錶解碼的進度;在訓練期間,該研究在 [0,1) 中隨機采樣一個比率 r 來模擬各種解碼場景。
實驗
該研究從質量、效率和靈活性方麵對 MaskGIT 在圖像生成方麵進行瞭實驗評估。
類條件圖像閤成
該研究在 ImageNet 256 X 256 和 ImageNet 512 X 512 上評估瞭 MaskGIT 模型在類條件(class-conditional)圖像閤成任務上的性能,主要結果如下錶 1 所示。
質量。在 ImageNet 256 X 256 上,不使用任何特殊的采樣策略,MaskGIT 在 FID 和 IS 方麵都顯著優於 VQGAN。
速度。該研究通過評估每個模型生成樣本所需的步驟數(前嚮傳遞)來評估模型速度。如錶 1 所示,在所有基於非 GAN 的模型中,MaskGIT 在兩種分辨率上所需的步驟最少。
為瞭進一步證實 MaskGIT 和自迴歸模型之間的速度差異,該研究對 MaskGIT 和 VQGAN 的解碼過程進行瞭運行時比較。如下圖 4 所示,MaskGIT 將 VQGAN 顯著加速瞭 30-64 倍,隨著圖像分辨率(以及輸入 token 長度)的增加,加速變得更加明顯。
多樣性。除瞭樣本質量外,該研究還將分類準確率得分 (CAS) 和 Precision/Recall 作為評估樣本多樣性的兩個指標。與 BigGAN 的樣本相比,MaskGIT 的樣本更加多樣化,具有更多種光照、姿態、規模和語境,如下圖 5 所示。
圖像編輯應用
該研究展示瞭 MaskGIT 在三個圖像編輯任務上的直接應用:類條件圖像編輯、圖像修復和圖像擴展(outpainting)。如果將任務看作對初始二進製掩碼 M MaskGIT 在其迭代解碼中使用約束,那麼這三個任務幾乎都可以輕鬆地轉換為 MaskGIT 可以處理的任務。
該研究錶明,無需修改架構或任何特定於任務的訓練,MaskGIT 就能夠在所有三個應用程序上産生非常優秀的結果。此外,MaskGIT 在圖像修復和擴展方麵獲得瞭與專用模型相當的性能。
在類條件圖像編輯任務上,該研究定義瞭一個新的類條件圖像編輯任務來展示 MaskGIT 的靈活性。模型在給定類的邊界框內重新生成特定內容,同時保留語境,即框外的內容。由於違背瞭預測順序,因此自迴歸方法是不可行的。
然而,對於 MaskGIT,如果將邊界框區域視為迭代解碼算法的初始掩碼的輸入,這個問題就迎刃而解瞭。下圖 6 給齣瞭一些示例結果。
錶 2 比較瞭幾種方法的定量結果。MaskGIT 在 FID 和 IS 中均以顯著優勢擊敗 DeepFill 和 HiFill,同時獲得接近 SOTA 修復方法 CoModGAN 的分數。
如下圖 7 所示,MaskGIT 還能夠在給定相同輸入和不同種子的情況下閤成不同的結果。
消融實驗
為瞭驗證新設計的效用,該研究在 ImageNet 256×256 的默認設置上進行瞭消融實驗。MaskGIT 的一個關鍵設計是用於訓練和迭代解碼的掩碼調度函數,實驗結果如下錶 3 和圖 8 所示。
值得注意的是,如圖 8 所示,在相同的設置下,更多的迭代不一定更好:隨著迭代次數 T 的增加,除瞭對數函數在整個過程中都錶現不佳以外,其他所有函數都達到瞭一個「sweet spot」位置,即模型的性能在再次惡化之前達到峰值。