Soft Diffusion:谷歌新框架從通用擴散過(guò)程中正確調度、學(xué)習和采樣
近來(lái),擴散模型成為 AI 領(lǐng)域的研究熱點(diǎn)。谷歌研究院和 UT-Austin 的研究者在最新的一項研究中充分考慮了「損壞」過(guò)程,并提出了一個(gè)用于更通用損壞過(guò)程的擴散模型設計框架。
我們知道,基于分數的模型和去噪擴散概率模型(DDPM)是兩類(lèi)強大的生成模型,它們通過(guò)反轉擴散過(guò)程來(lái)產(chǎn)生樣本。這兩類(lèi)模型已經(jīng)在 Yang Song 等研究者的論文《Score-based generative modeling through stochastic differential equations》中統一到了單一的框架下,并被廣泛地稱(chēng)為擴散模型。
目前,擴散模型在包括圖像、音頻、視頻生成以及解決逆問(wèn)題等一系列應用中取得了巨大的成功。Tero Karras 等研究者在論文《Elucidating the design space of diffusionbased generative models》中對擴散模型的設計空間進(jìn)行了分析,并確定了 3 個(gè)階段,分別為 i) 選擇噪聲水平的調度,ii) 選擇網(wǎng)絡(luò )參數化(每個(gè)參數化生成一個(gè)不同的損失函數),iii) 設計采樣算法。
近日,在谷歌研究院和 UT-Austin 合作的一篇 arXiv 論文《Soft Diffusion: Score Matching for General Corruptions》中,幾位研究者認為擴散模型仍有一個(gè)重要的步驟:損壞(corrupt)。一般來(lái)說(shuō),損壞是一個(gè)添加不同幅度噪聲的過(guò)程,對于 DDMP 還需要重縮放。雖然有人嘗試使用不同的分布來(lái)進(jìn)行擴散,但仍缺乏一個(gè)通用的框架。因此,研究者提出了一個(gè)用于更通用損壞過(guò)程的擴散模型設計框架。
具體地,他們提出了一個(gè)名為 Soft Score Matching 的新訓練目標和一種新穎的采樣方法 Momentum Sampler。理論結果表明,對于滿(mǎn)足正則條件的損壞過(guò)程,Soft Score MatchIng 能夠學(xué)習它們的分數(即似然梯度),擴散必須將任何圖像轉換為具有非零似然的任何圖像。
在實(shí)驗部分,研究者在 CelebA 以及 CIFAR-10 上訓練模型,其中在 CelebA 上訓練的模型實(shí)現了線(xiàn)性擴散模型的 SOTA FID 分數——1.85。同時(shí)與使用原版高斯去噪擴散訓練的模型相比,研究者訓練的模型速度顯著(zhù)更快。
論文地址:https://arxiv.org/pdf/2209.05442.pdf
方法概覽
通常來(lái)說(shuō),擴散模型通過(guò)反轉逐漸增加噪聲的損壞過(guò)程來(lái)生成圖像。研究者展示了如何學(xué)習對涉及線(xiàn)性確定性退化和隨機加性噪聲的擴散進(jìn)行反轉。
具體地,研究者展示了使用更通用損壞模型訓練擴散模型的框架,包含有三個(gè)部分,分別為新的訓練目標 Soft Score Matching、新穎采樣方法 Momentum Sampler 和損壞機制的調度。
首先來(lái)看訓練目標 Soft Score Matching,這個(gè)名字的靈感來(lái)自于軟過(guò)濾,是一種攝影術(shù)語(yǔ),指的是去除精細細節的過(guò)濾器。它以一種可證明的方式學(xué)習常規線(xiàn)性損壞過(guò)程的分數,還在網(wǎng)絡(luò )中合并入了過(guò)濾過(guò)程,并訓練模型來(lái)預測損壞后與擴散觀(guān)察相匹配的圖像。
只要擴散將非零概率指定為任何干凈、損壞的圖像對,則該訓練目標可以證明學(xué)習到了分數。另外,當損壞中存在加性噪聲時(shí),這一條件總是可以得到滿(mǎn)足。
具體地,研究者探究了如下形式的損壞過(guò)程。
在過(guò)程中,研究者發(fā)現噪聲在實(shí)證(即更好的結果)和理論(即為了學(xué)習分數)這兩方面都很重要。這也成為了其與反轉確定性損壞的并發(fā)工作 Cold Diffusion 的關(guān)鍵區別。
其次是采樣方法 Momentum Sampling。研究者證明,采樣器的選擇對生成樣本質(zhì)量具有顯著(zhù)影響。他們提出了 Momentum Sampler,用于反轉通用線(xiàn)性損壞過(guò)程。該采樣器使用了不同擴散水平的損壞的凸組合,并受到了優(yōu)化中動(dòng)量方法的啟發(fā)。
這一采樣方法受到了上文 Yang Song 等人論文提出的擴散模型連續公式化的啟發(fā)。Momentum Sampler 的算法如下所示。
下圖直觀(guān)展示了不同采樣方法對生成樣本質(zhì)量的影響。圖左使用 Naive Sampler 采樣的圖像似乎有重復且缺少細節,而圖右 Momentum Sampler 顯著(zhù)提升了采樣質(zhì)量和 FID 分數。
最后是調度。即使退化的類(lèi)型是預定義的(如模糊),決定在每個(gè)擴散步驟中損壞多少并非易事。研究者提出一個(gè)原則性工具來(lái)指導損壞過(guò)程的設計。為了找到調度,他們將沿路徑分布之間的 Wasserstein 距離最小化。直觀(guān)地講,研究者希望從完全損壞的分布平穩過(guò)渡到干凈的分布。
實(shí)驗結果
研究者在 CelebA-64 和 CIFAR-10 上評估了提出的方法,這兩個(gè)數據集都是圖像生成的標準基線(xiàn)。實(shí)驗的主要目的是了解損壞類(lèi)型的作用。
研究者首先嘗試使用模糊和低幅噪聲進(jìn)行損壞。結果表明,他們提出的模型在 CelebA 上實(shí)現了 SOTA 結果,即 FID 分數為 1.85,超越了所有其他僅添加噪聲以及可能重縮放圖像的方法。此外在 CIFAR-10 上獲得的 FID 分數為 4.64,雖未達到 SOTA 但也具有競爭力。
此外,在 CIFAR-10 和 CelebA 數據集上,研究者的方法在另一項指標采樣時(shí)間上也表現更好。另一個(gè)額外的好處是具有顯著(zhù)的計算優(yōu)勢。與圖像生成去噪方法相比,去模糊(幾乎沒(méi)有噪聲)似乎是一種更有效的操縱。
下圖展示了 FID 分數如何隨著(zhù)函數評估數量(Number of Function Evaluations, NFE)而變。從結果可以看到,在 CIFAR-10 和 CelebA 數據集上,研究者的模型可以使用明顯更少的步驟來(lái)獲得與標準高斯去噪擴散模型相同或更好的質(zhì)量。
*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。