大一統視角理解擴散模型Understanding Diffusion Models: A Unified Perspective(2)
在擴散模型里,有幾個(gè)重要的假設。其中一個(gè)就是每一步擴散過(guò)程的變換,都是對前一步結果的高斯變換(上一節M(mǎn)HVAE的限制條件2):
與MHVAE不同,編碼器側的潛在向量分布并不經(jīng)過(guò)學(xué)習得到,而是固定為線(xiàn)性高斯模型
這一點(diǎn)和VAE有很大不同。VAE里編碼器側的潛在向量的分布是通過(guò)模型訓練得到的。而擴散模型里,前向加噪過(guò)程里的每一步都是基于上一步結果的高斯變換。其中 alpha_t 一般當作超參設置得到。這點(diǎn)對于我們計算擴散模型的證據下界有很大幫助。因為我們可以基于輸入x0確切地知道前向過(guò)程里的某一步的具體狀態(tài),從而監督我們的預測。
基于式31,我們可以遞歸式地對x0不斷加噪變換,得到最終xt的表達式:
xt可以寫(xiě)為關(guān)于x0的一個(gè)高斯分布的采樣結果
所以對于式58里噪音匹配項里的監督信號,我們可以重寫(xiě)成以下形式,其中根據式70,我們可以得到q(xt|x0)和q(xt-1|x0)的表達式,而q(xt|xt-1, x0)因為是前向擴散過(guò)程,可以應用馬爾可夫性質(zhì)看做q(xt|xt-1)使用式31得到具體表達式。
式58里的監督信號可以通過(guò)x0計算具體的值
代入每一項q所代表的高斯函數表達式后,我們最后可以得到一個(gè)新的高斯分布表達式,其中每一項都是具體可求的:
q(xt-1|xt,x0)的解析形式
參考已經(jīng)證明了前向加噪過(guò)程可以寫(xiě)為一個(gè)高斯分布了。在擴散模型的初始論文[2]里提到,對于一個(gè)連續的高斯擴散過(guò)程,其逆過(guò)程與前向過(guò)程的方程形式(functional form)一致。所以我們將對去噪匹配項里的p_theta(xt-1|xt)也采用高斯分布的形式(更加具體的一些推導放在了末尾的補充里)。注意式58里,對兩個(gè)高斯分布求KL散度,其解析解的形式如下:
兩個(gè)高斯分布的KL散度解析解
我們現在已知其中一個(gè)高斯分布(左側)的參數,現在如果我們令右側的高斯分布和左側高斯分布的方差保持一致。那么優(yōu)化該KL散度的解析式將簡(jiǎn)化為以下形式:
式58的噪音匹配項簡(jiǎn)化為最小化前后向均值的預測誤差
如此一來(lái)式58的噪音匹配項就被簡(jiǎn)化為最小化前后向均值的預測誤差(式92)。讀者請注意,以下的大一統的三個(gè)角度來(lái)看待Diffusion model,實(shí)質(zhì)上都是對式92里mu_q的不同變形所推論出來(lái)的。 其中mu_q是關(guān)于xt, x0的函數,而mu_theta是關(guān)于xt和t的函數。其中通過(guò)式84,我們有mu_q的準確計算結果,而因為mu_theta是關(guān)于xt的函數。我們可以將其寫(xiě)為類(lèi)似式84的形式(注意,有關(guān)為什么可以忽略方差并且讓均值選取這個(gè)形式放在了最末尾的補充討論里。但關(guān)于這個(gè)形式的選擇的深層原因實(shí)質(zhì)上開(kāi)辟了一個(gè)全新的領(lǐng)域來(lái)研究,并且關(guān)于該領(lǐng)域的研究直接導向了擴散模型之后的一系列加速采樣技術(shù)的出現)
將后向預測的均值寫(xiě)為類(lèi)似前向加噪的形式
比較式84與94可知,x_hat是我們通過(guò)噪音數據xt來(lái)預測原始數據x0的神經(jīng)網(wǎng)絡(luò )。那么我們可以將式58里證據下界的噪音匹配項,最終寫(xiě)為
噪聲匹配項的最終形式
那么,我們最后得到擴散模型的優(yōu)化,最終表現為訓練一個(gè)神經(jīng)網(wǎng)絡(luò ),以任意時(shí)間步的噪音圖像為輸入,來(lái)預測最初的原始圖像!此時(shí)優(yōu)化目標轉化為了最小化預測誤差。同時(shí)式58上的對所有時(shí)間步的噪音匹配項求和的優(yōu)化,可以近似為對每一時(shí)間步上的預測誤差的期望的最小值,而該優(yōu)化目標可以通過(guò)隨機采樣近似:
該優(yōu)化目標可以通過(guò)隨機采樣實(shí)現
為什么Calvin Luo的這篇論文叫做大一統視角來(lái)看待擴散模型?以上我們花了不菲的篇幅論證了擴散模型的優(yōu)化目標可以最終轉化為訓練一個(gè)神經(jīng)網(wǎng)絡(luò )在任意時(shí)間步從xt預測原始輸入x0。以下我們將論述如何通過(guò)對mu_q不同的推導得到類(lèi)似的角度看待擴散模型。
首先,我們已經(jīng)知道給定每個(gè)時(shí)間步的噪聲系數alpha_t之后,我們可以由初始輸入x0遞歸得到xt。同理,給定xt我們也可以求得x0。那么對式69重置后,我們可以得到式115.
將式69里的xt和x0關(guān)系重置后可得式115
重新將式115代入式84里,我們所得的關(guān)于時(shí)間步t的真實(shí)均值表達式mu_q后,我們可以得到以下推導:
在推導真實(shí)均值時(shí)替換x0
注意在上一次推導的過(guò)程中,mu_q里的xt在計算kl散度的解析式時(shí)被抵消掉了,而x0我們采取的是用神經(jīng)網(wǎng)絡(luò )直接擬合的策略。而在這一次的推導過(guò)程中,x0被替換成了關(guān)于xt的表達式(關(guān)于alpha_bar和epsilon_0)后,我們可以得到mu_q的新的表達式,依舊關(guān)于xt,只是不再與x0相關(guān),而是與epsilon_0相關(guān)(式124)。其中,和式94一樣,我們忽略方差(將其設為與前向一致)并將希望擬合的mu_theta寫(xiě)成與真實(shí)均值mu_q一樣的形式,只是將epsilon_0替換為神經(jīng)網(wǎng)絡(luò )的擬合項后我們可以得到式125。
與上次推導時(shí)替換x0為神經(jīng)網(wǎng)絡(luò )所擬合項一樣,這次換為擬合初始噪聲項
將我們新得到的兩個(gè)均值表達式重新代入KL散度的表達式里,xt再次被抵消掉(因為mu_theta和mu_q選取的形式一致)最終只剩下epsilon_0和epsilon_theta的差值。注意式130和式99的相似性!
最終對證據下界里的去噪匹配項的優(yōu)化可以寫(xiě)成關(guān)于初始噪聲和其擬合項的差的最小化
至此,我們得到了對擴散模型的第二種直觀(guān)理解。對于一個(gè)變分擴散模型VDM,我們優(yōu)化該模型的證據下界既等價(jià)于優(yōu)化其在所有時(shí)間步上對初始圖像的預測誤差的期望,也等價(jià)于優(yōu)化在所有時(shí)間步上對噪聲的預測誤差的期望! 事實(shí)上DDPM采取的做法就是式130的做法(注意DDPM里的表達式實(shí)際上用的是epsilon_t,關(guān)于這點(diǎn)在文末也會(huì )討論)。
下面筆者將概括第三種看待VDM的推導方式。這種方式主要來(lái)自于SongYang博士的系列論文,非常直觀(guān)。并且該系列論文將擴散模型這種離散的多步去噪過(guò)程統一成了一個(gè)連續的隨機微分方程(SDE)的特殊形式。SongYang博士因此獲得了ICLR2021的最佳論文獎!后續來(lái)自清華大學(xué)的基于將該SDE轉化為常微分方程O(píng)DE后的采樣提速論文,也獲得了2022ICLR的最佳論文獎!關(guān)于該論文的一些細節和直觀(guān)理解,SongYang博士在他自己的博客里給出了非常精彩和直觀(guān)的講解。有興趣的讀者可以點(diǎn)開(kāi)本文初始的第二個(gè)鏈接查看。以下只對大一統視角下的第三種視角做簡(jiǎn)短的概括。
第三種推導方式主要基于Tweedie's formula.該公式主要闡述了對于一個(gè)指數家族的分布的真實(shí)均值,在給定了采樣樣本后,可以通過(guò)采樣樣本的最大似然概率(即經(jīng)驗均值)加上一個(gè)關(guān)于分數(score)預估的校正項來(lái)預估。注意score在這里的定義是真實(shí)數據分布的對數似然關(guān)于輸入xt的梯度。即
score的定義
根據Tweedie's formula,對于一個(gè)高斯變量z~N(mu_z, sigma_z)來(lái)說(shuō), 該高斯變量的真實(shí)均值的預估是:
Tweedie's formula對高斯變量的應用
我們知道在訓練時(shí),模型的輸入xt關(guān)于x0的表達式如下
上文里的式70
我們也知道根據Tweedie's formula的高斯變量的真實(shí)均值預估我們可以得到下式
將式70的方差代入Tweedie's formula
那么聯(lián)立兩式的關(guān)于均值的表達式后,我們可以得到x0關(guān)于score的表達式133
將x0寫(xiě)為關(guān)于score的表達式
如上一種推導方式所做的一樣,再一次重新將x0的表達式代入式84對真實(shí)均值mu_q的表達式里:(注意式135到136的變形主要在分子里最右邊的alpha_bar_t到alpha_t, 約去了根號下alpha_bar_t-1)
將x0的關(guān)于score表達式代入式84
同樣,將mu_theta采取和mu_q一樣的形式,并用神經(jīng)網(wǎng)絡(luò )s_theta來(lái)近似score后, 我們得到了新的mu_theta的表達式143。
關(guān)于score的mu_theta的表達式
再再再同樣,和上種推導里的做法一樣,我們再將新的mu_theta, mu_q代入證據下界里KL散度的損失項我們可以得到一個(gè)最終的優(yōu)化目標
將新的mu的表達式代入證據下界的優(yōu)化目標里
事實(shí)上,比較式148和式130的形式,可以說(shuō)是非常的接近了。那么我們的score function delta_p(xt)和初始噪聲epsilon_0是否有關(guān)聯(lián)呢?聯(lián)立關(guān)于x0的兩個(gè)表達式133和115我們可以得到
score function和初始噪聲間的關(guān)系
讀者如果將式151代入148會(huì )發(fā)現和式130等價(jià)!直觀(guān)上來(lái)講,score function描述的是如何在數據空間里最大化似然概率的更新向量。而又因為初始噪聲是在原輸入的基礎上加入的,那么往噪聲的反方向(也是最佳方向)更新實(shí)質(zhì)上等價(jià)于去噪的過(guò)程。而數學(xué)上講,對score function的建模也等價(jià)于對初始噪聲乘上負系數的建模!
至此我們終于將擴散模型的三個(gè)形式的所有推導整理完畢!即對變分擴散模型VDM的訓練等價(jià)于訓練一個(gè)神經(jīng)網(wǎng)絡(luò )來(lái)預測原輸入x0,也等價(jià)于預測噪聲epsilon, 也等價(jià)于預測初始輸入在特定時(shí)間步的score delta_logp(xt)。
讀到這里,相比讀者也已經(jīng)發(fā)現,不同的推導所得出的不同結果,都來(lái)自于對證據下界里去噪匹配項的不同推導過(guò)程。而不同的變形,基本上都是利用了MHVAE里最開(kāi)始提到的三點(diǎn)基本假設所得。
Drawbacks to Consider盡管擴散模型在最近兩年成功出圈,引爆了業(yè)界,學(xué)術(shù)界甚至普通人對文本生成圖像的AI模型的關(guān)注,但擴散模型這個(gè)體系本身依舊存在著(zhù)一些缺陷:
- 擴散模型本身盡管理論框架已經(jīng)比較完善,公式推導也十分優(yōu)美。但仍然非常不直觀(guān)。最起碼從一個(gè)完全噪聲的輸入不斷優(yōu)化的這個(gè)過(guò)程和人類(lèi)的思維過(guò)程相去甚遠。
- 擴散模型和GAN或者VAE相比,所學(xué)的潛在向量不具備任何語(yǔ)義和結構的可解釋性。上文提到了擴散模型可以看做是特殊的MHVAE,但里面每一層的潛在向量間都是線(xiàn)性高斯的形式,變化有限。
- 而擴散模型的潛在向量要求維度與輸入一致這一點(diǎn),則更加死地限制住了潛在向量的表征能力。
- 擴散模型的多步迭代導致了擴散模型的生成往往耗時(shí)良久。
不過(guò)學(xué)術(shù)界對以上的一些難題其實(shí)也提出了不少解決方案。比如擴散模型的可解釋性問(wèn)題。筆者最近就發(fā)現了一些工作將score-matching直接應用在了普通VAE的潛在向量的采樣上。這是一個(gè)非常自然的創(chuàng )新點(diǎn),就和數年前的flow-based-vae一樣。而耗時(shí)良久的問(wèn)題,今年ICLR的最佳論文也將采樣這個(gè)問(wèn)題加速和壓縮到了幾十步內就可以生成非常高質(zhì)量的結果。
但是對于擴散模型在文本生成領(lǐng)域的應用最近似乎還不多,除了prefix-tuning的作者xiang-lisa-li的一篇論文[3]
之外筆者暫未關(guān)注到任何工作。而具體來(lái)講,如果將擴散模型直接用在文本生成上,仍有諸多不便。比如輸入的尺寸在整個(gè)擴散過(guò)程必須保持一致就決定了使用者必須事先決定好想生成的文本的長(cháng)度。而且做有引導的條件生成還好,要用擴散模型訓練出一個(gè)開(kāi)放域的文本生成模型恐怕難度不低。
本篇筆記著(zhù)重的是在探討大一統角度下的擴散模型推斷。但具體對score matching如何訓練,如何引導擴散模型生成我們想要的條件分布還沒(méi)有寫(xiě)出來(lái)。筆者打算在下一篇探討最近一些將擴散模型應用在受控文本生成領(lǐng)域的方法調研里詳細記錄和比較一下
補充- 關(guān)于為什么擴散核是高斯變換的擴散過(guò)程的逆過(guò)程也是高斯變換的問(wèn)題,來(lái)自清華大神的一篇知乎回答里[4] 給出了比較直觀(guān)的解釋。其中第二行是將p_t-1和p_t近似。第三行是對logpt(x_t-1)使用一階泰勒展開(kāi)消去了logpt(xt)。第四行是直接代入了q(xt|xt-1)的表達式。于是我們得到了一個(gè)高斯分布的表達式。
擴散的逆過(guò)程也是高斯分布
- 在式94和式125,我們都將對真實(shí)高斯分布q的均值mu_q的近似mu_theta建模成了與我們所推導出的mu_q一致的形式,并且將方差設置為了與q的方差一致的形式。直觀(guān)上來(lái)講,這樣建模的好處很多,一方面是根據KL散度對兩個(gè)高斯分布的解析式來(lái)說(shuō),這樣我們可以約掉和抵消掉絕大部分的項,簡(jiǎn)化了建模。另一方面真實(shí)分布和近似分布都依賴(lài)于xt。在訓練時(shí)我們的輸入就是xt,采取和真實(shí)分布形式一樣的表達式?jīng)]有泄漏任何信息。并且在工程上DDPM也驗證了類(lèi)似的簡(jiǎn)化是事實(shí)上可行的。但實(shí)際上可以這樣做的原因背后是從2021年以來(lái)的一系列論文里復雜的數理證明所在解釋的目標。 同樣引用清華大佬[4]的回答:
DDPM里簡(jiǎn)化去噪的高斯分布的做法其實(shí)蘊含著(zhù)深刻的道理
- 在DDPM里,其最終的優(yōu)化目標是epsilon_t而不是epsilon_0。即預測的誤差到底是初始誤差還是某個(gè)時(shí)間步上的初始誤差。誰(shuí)對誰(shuí)錯?實(shí)際上這個(gè)誤解來(lái)源于我們對xt關(guān)于x0的表達式的求解中的誤解。從式63開(kāi)始的連續幾步推導,都應用到了一個(gè)高斯性質(zhì),即兩個(gè)獨立高斯分布的和的均值與方差等于原分布的均值和與方差和。而實(shí)質(zhì)上我們在應用重參數化技巧求xt的過(guò)程中,是遞歸式的不斷引入了新的epsilon來(lái)替換遞歸中的x_n里的epsilon。那么到最后,我們所得到的epsilon無(wú)非是一個(gè)囊括了所有擴散過(guò)程中的epsilon。這個(gè)噪聲即可以說(shuō)是t,也可以說(shuō)是0,甚至最準確來(lái)說(shuō)應該不等于任何一個(gè)時(shí)間步,就叫做噪聲就好!
DDPM的優(yōu)化目標
- 關(guān)于對證據下界的不同簡(jiǎn)化形式。其中我們提到第二種對噪聲的近似是DDPM所采用的建模方式。但是對初始輸入的近似其實(shí)也有論文采用。也就是上文提及的將擴散模型應用在可控文本生成的論文里[3]所采用的形式。該論文每輪直接預測初始Word-embedding。而第三種score-matching的角度可以參照SongYang博士的系列論文[5]來(lái)看。里面的優(yōu)化函數的形式用的是第三種。
- 本篇筆記著(zhù)重于講述擴散模型的變分下界的公式推導,關(guān)于擴散模型與能量模型,朗之萬(wàn)動(dòng)力學(xué),隨機微分方程等一系列名詞的關(guān)系本篇筆記并無(wú)涉及。 筆者將在另外一篇筆記里梳理相關(guān)的理解。
參考
- ^Improving Variational Inference with Inverse Autoregressive Flow https://arxiv.org/abs/1606.04934
- ^Deep Unsupervised Learning using Nonequilibrium Thermodynamics https://arxiv.org/abs/1503.03585
- ^abDiffusion-LM Improves Controllable Text Generation https://arxiv.org/abs/2205.14217
- ^abdiffusion model最近在圖像生成領(lǐng)域大紅大紫,如何看待它的風(fēng)頭開(kāi)始超過(guò)GAN?- 我想唱high C的回答 - 知乎 https://www.zhihu.com/question/536012286/answer/2533146567
- ^SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS https://arxiv.org/abs/2011.13456
*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。