大一統視角理解擴散模型Understanding Diffusion Models: A Unified Perspective(1)
這篇文章是近期筆者閱讀擴散模型的一些技術(shù)博客和概覽的一篇梳理。主要參考的內容來(lái)自Calvin luo的論文,針對的對象主要是對擴散模型已經(jīng)有一些基礎了解的讀者。Calvin luo 的這篇論文為理解擴散模型提供了一個(gè)統一的視角,尤其是其中的數理公式推導非常詳盡,本文將試圖盡量簡(jiǎn)要地概括一遍大一統視角下的擴散模型的推導過(guò)程。在結尾處,筆者附上了一些推導過(guò)程中的強假設的思考和疑惑,并簡(jiǎn)要討論了下擴散模型應用在自然語(yǔ)言處理時(shí)的一些思考。
本篇閱讀筆記一共參考了以下技術(shù)博客。其中如果不了解擴散模型的讀者可以考慮先閱讀lilian-weng的科普博客。Calvin-Luo的這篇介紹性論文在書(shū)寫(xiě)的時(shí)候經(jīng)過(guò)了包括Jonathan Ho(DDPM作者), SongYang博士 和一系列相關(guān)擴散模型論文的發(fā)表者的審核,非常值得一讀。
1. What are Diffusion Models? by Lilian Weng
2. Generative Modeling by Estimating Gradients of the Data Distribution by Song Yang
3. Understanding Diffusion Models: A Unified Perspective by Calvin Luo
生成模型希望可以生成符合真實(shí)分布(或給定數據集)的數據。我們常見(jiàn)的幾種生成模型有GANs,Flow-based Models, VAEs, Energy-Based Models 以及我們今天希望討論的擴散模型Diffusion Models. 其中擴散模型和變分自編碼器VAEs, 和基于能量的模型EBMs有一些聯(lián)系和區別,筆者會(huì )在接下來(lái)的章節闡述。
常見(jiàn)的幾種生成模型
在介紹擴散模型前,我們先來(lái)回顧一下變分自編碼器VAE。我們知道VAE最大的特點(diǎn)是引入了一個(gè)潛在向量的分布來(lái)輔助建模真實(shí)的數據分布。那么為什么我們要引入潛在向量?有兩個(gè)直觀(guān)的原因,一個(gè)是直接建模高維表征十分困難,常常需要引入很強的先驗假設并且有維度詛咒的問(wèn)題存在。另外一個(gè)是直接學(xué)習低維的潛在向量,一方面起到了維度壓縮的作用,一方面也希望能夠在低維空間上探索具有語(yǔ)義化的結構信息(例如圖像領(lǐng)域里的GAN往往可以通過(guò)操控具體的某個(gè)維度影響輸出圖像的某個(gè)具體特征)。
引入了潛在向量后,我們可以將我們的目標分布的對數似然logP(x),也稱(chēng)為“證據evidence“寫(xiě)成下列形式:
ELBO的推理過(guò)程
其中,我們重點(diǎn)關(guān)注式15. 等式的左邊是生成模型想要接近的真實(shí)數據分布(evidence),等式右邊由兩項組成,其中第二項的KL散度因為恒大于零,所以不等式恒成立。如果在等式右邊減去該KL散度,則我們得到了真實(shí)數據分布的下界,即證據下界ELBO。對ELBO進(jìn)行進(jìn)一步的展開(kāi),我們就可以得到VAE的優(yōu)化目標
ELBO等式的展開(kāi)
對該證據下界的變形的形式,我們可以直觀(guān)地這么理解:證據下界等價(jià)于這么一個(gè)過(guò)程,我們用編碼器將輸入x編碼為一個(gè)后驗的潛在向量分布q(z|x)。我們希望這個(gè)向量分布盡可能地和真實(shí)的潛在向量分布p(z)相似,所以用KL散度約束,這也可以避免學(xué)習到的后驗分布q(z|x)坍塌成一個(gè)狄拉克delta函數(式19的右側)。而得到的潛在向量我們用一個(gè)****重構出原數據,對應的是式19的左邊P(x|z)。
VAE為什么叫變分自編碼器。變分的部分來(lái)自于尋找最優(yōu)的潛在向量分布q(z|x)的這個(gè)過(guò)程。自編碼器的部分是上面提到的對輸入數據的編碼,再解碼為原數據的行為。
那么提煉一下為什么VAE可以比較好地貼合原數據的分布?因為根據上述的公式推導我們發(fā)現:原數據分布的對數似然(稱(chēng)為證據evidence)可以寫(xiě)成證據下界加上我們希望近似的后驗潛在向量分布和真實(shí)的潛在向量分布間的KL散度(即式15)。如果把該式寫(xiě)為A = B+C的形式。因為evidence(即A)是個(gè)常數(與我們要學(xué)習的參數無(wú)關(guān)),所以最大化B,也就是我們的證據下界,等價(jià)于最小化C,也即是我們希望擬合的分布和真實(shí)分布間的差別。而因為證據下界,我們可以重新寫(xiě)成式19那樣一個(gè)自編碼器的形式,我們也就得到了自編碼器的訓練目標。優(yōu)化該目標,等價(jià)于近似真實(shí)數據分布,也等價(jià)于用變分手法來(lái)優(yōu)化后驗潛在向量分布q(z|x)的過(guò)程。
但VAE自身依然有很多問(wèn)題。一個(gè)最明顯的就是我們如何選定后驗分布q_phi(z|x)。絕大多數的VAE實(shí)現里,這個(gè)后驗分布被選定為了一個(gè)多維高斯分布。但這個(gè)選擇更多的是為了計算和優(yōu)化的方便而選擇。這樣的簡(jiǎn)單形式極大地限制了模型逼近真實(shí)后驗分布的能力。VAE的原作者kingma曾經(jīng)有篇非常經(jīng)典的工作就是通過(guò)引入normalization flow[1]在改進(jìn)后驗分布的表達能力。而擴散模型同樣可以看做是對后驗分布q_phi(z|x)的改進(jìn)。
Hierarchical VAE下圖展示了一個(gè)變分自編碼器里,潛在向量和輸入間的閉環(huán)關(guān)系。即從輸入中提取低維的潛在向量后,我們可以通過(guò)這個(gè)潛在向量重構出輸入。
VAE里潛在向量與輸入的關(guān)系
很明顯,我們認為這個(gè)低維的潛在向量里一定是高效地編碼了原數據分布的一些重要特性,才使得我們的****可以成功重構出原數據分布里的各式數據。那么如果我們遞歸式地對這個(gè)潛在向量再次計算“潛在向量的潛在向量”,我們就得到了一個(gè)多層的HVAE,其中每一層的潛在向量條件于所有前序的潛在向量。但是在這篇文章里,我們主要關(guān)注具有馬爾可夫性質(zhì)的層級變分自編碼器MHVAE,即每一層的潛在向量?jì)H條件于前一層的潛在向量。
MHVAE里的潛在向量只條件于上一層
對于該MHVAE,我們可以通過(guò)馬爾可夫假設得到以下二式
23和24式是用鏈式法則對依賴(lài)圖里的關(guān)系的拆解
對于該MHVAE,我們可以用以下步驟推導其證據下界
MHVAE的變分下界推導
我們之所以在談?wù)摂U散模型之前,要花如此大的篇幅介紹VAE,并引出MHVAE的證據下界推導是因為我們可以非常自然地將擴散模型視為一種特殊的MHVAE,該MHVAE滿(mǎn)足以下三點(diǎn)限制(注意以下三點(diǎn)限制也是整個(gè)擴散模型推斷的基礎):
- 潛在向量Z的維度和輸入X的維度保持一致。
- 每一個(gè)時(shí)間步的潛在向量都被編碼為一個(gè)僅依賴(lài)于上一個(gè)時(shí)間步的潛在向量的高斯分布。
- 每一個(gè)時(shí)間步的潛在向量的高斯分布的參數,隨時(shí)間步變化,且滿(mǎn)足最終時(shí)間步的高斯分布滿(mǎn)足標準高斯分布的限制。
因為第一點(diǎn)維度一致的原因,在不影響理解的基礎上,我們將MHVAE里的Zt表示為Xt(其中x0為原始輸入),則我們可以將MHVAE的層級潛在向量依賴(lài)圖,重新畫(huà)為以下形式(即將擴散模型的中間擴散過(guò)程當做潛在向量的層級建模過(guò)程):
擴散過(guò)程的直觀(guān)解釋?zhuān)涸跀祿0上不斷加高斯噪聲直至退化為純噪聲圖像Xt
直至這里,我們終于見(jiàn)到了我們熟悉的擴散模型的形式。
而在將上面的公式25-28里的Zt與Xt替換后,我們可以得到VDM里證據下界的推導公式里的前四行,即公式34-37。并且在此基礎上,我們可以繼續往下推導。37至38行的變換是鏈式法則的等價(jià)替換(或上述公式23和24的變換),38至39行是連乘過(guò)程的重組,39至40行是對齊連乘符號的區間,40至41行應用了Log乘法的性質(zhì),41至42繼續運用該性質(zhì)進(jìn)一步拆分,42至43行是因為和的期望等于期望的和,43至44是因為期望目標與部分時(shí)間步的概率無(wú)關(guān)可以直接省去,44至45步是應用了KL散度的定義進(jìn)行了重組。
VDM的證據下界推導
至此,我們又一次將原數據分布的對數似然,轉化為了證據下界(公式37),并將其轉化為了幾項非常直觀(guān)的損失函數的加和形式(公式45),他們分別為:
- 重構項,即從潛在向量x1到原數據x0的變化。在VAE里該重構項寫(xiě)為logP(x|z),而在這里我們寫(xiě)做logP(x0|x1)
- 先驗匹配項?;貞浳覀兩鲜鎏岬降腗HVAE里最終時(shí)間步的高斯分布應建立為標準高斯分布
- 一致項。該項損失是為了使得前向加噪過(guò)程和后向去噪的過(guò)程中,Xt的分布保持一致。直觀(guān)上講,對一個(gè)更混亂圖像的去噪應一致于對一個(gè)更清晰的圖像的加噪。而因為一致項的損失是定義于所有時(shí)間步上的,這也是三項損失里最耗時(shí)計算的一項。
雖然以上的公式推導給了我們一個(gè)非常直觀(guān)的證據下界,并且由于每一項都是以期望來(lái)計算,所以天然適用蒙特卡洛方法來(lái)近似,但如果優(yōu)化該證據下界依然存在幾個(gè)問(wèn)題:
- 我們的一致項損失是一項建立在兩個(gè)隨機變量(Xt-1, Xt+1)上的期望。他們的蒙特卡洛估計的方差大概率比建立在單個(gè)獨立變量上的蒙特卡洛估計的方差大。
- 我們的一致項是定義于所有時(shí)間步上的KL散度的期望和。對于T取值較高的情況(通常擴散模型T取2000左右),該期望的方差也會(huì )很大。
所以我們需要重新推導一個(gè)證據下界。而這個(gè)推導的關(guān)鍵將著(zhù)眼于以下這個(gè)觀(guān)察:我們可以將擴散過(guò)程的正向加噪過(guò)程q(xt|xt-1)重寫(xiě)為q(xt|xt-1, x0)。之所以這樣重寫(xiě)的原因是基于馬爾可夫假設,這兩個(gè)式子完全等價(jià)。于是對這個(gè)式子使用貝葉斯法則,我們可以得到式46.
對前向加噪過(guò)程使用馬爾可夫假設和貝葉斯法則后的公式
基于公式46,我們可以重寫(xiě)上面的證據下界(式37)為以下形式:其中式47,48和式37,38一致。式49開(kāi)始,分母的連乘拆解由從T開(kāi)始改為從1開(kāi)始。式50基于上文提及的馬爾可夫假設對分母添加了x0的依賴(lài)。式51用log的性質(zhì)拆分了對數的目標。式52代入了式46做了替換。式53將劃掉的分母部分連乘單獨提取出來(lái)后發(fā)現各項可約剩下式54部分的log(q(x1|x0)/q(xT|x0))。式54用log的性質(zhì)消去了q(x1|x0)得到了式55。式56用log的性質(zhì)拆分重組了公式,式57如同前述式43-44的變換,省去了無(wú)關(guān)的時(shí)間步。式58則用了KL散度的性質(zhì)。
應用了馬爾可夫假設的擴散模型證據下界推導1
應用了馬爾可夫假設的擴散模型證據下界推導2
至此,我們應用了馬爾可夫假設得到了一個(gè)更優(yōu)的證據下界推導。該證據下界同樣包含幾項直觀(guān)的損失函數:
- 重構項。該重構項與上面提及的重構項一致。
- 先驗匹配項。與上面提及的形式略有差別,但同樣是基于最終時(shí)間步應為標準高斯的先驗假設
- 去噪匹配項。與上面提及的一致項的最大區別在于不再是對兩個(gè)隨機變量的期望。并且直觀(guān)上理解p(xt-1|xt)代表的是后向的去噪過(guò)程,而q(xt-1|xt, x0)代表的是已知原始圖像和目標噪聲圖像的前向加噪過(guò)程。該加噪過(guò)程作為目標信號,來(lái)監督后向的去噪過(guò)程。該項解決了期望建立于兩個(gè)隨機變量上的問(wèn)題。
注意,以上的推導完全基于馬爾可夫的性質(zhì)所以適用于所有MHVAE,所以當T=1的時(shí)候,以上的證據下界和VAE所推導出的證據下界完全一致!并且本文之所以稱(chēng)為大一統視角,是因為對于該證據下界里的去噪匹配項,不同的論文有不同的優(yōu)化方式。但歸根結底,他們的本質(zhì)互相等價(jià),且皆由該式展開(kāi)推導得到。下面我們會(huì )從擴散模型的角度做公式推導,來(lái)展開(kāi)計算去噪匹配項。(注意第一版的推導里的一致項,也完全可以通過(guò)下一節的方式得到q和p的表達式,再通過(guò)KL來(lái)計算解析式)
*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。
物聯(lián)網(wǎng)相關(guān)文章:物聯(lián)網(wǎng)是什么