綜述:如何給模型加入先驗知識
作者丨Billy Z@知乎(已授權)
來(lái)源丨h(huán)ttps://zhuanlan.zhihu.com/p/188572028
編輯丨極市平臺
導讀
端到端的深度神經(jīng)網(wǎng)絡(luò )雖然能夠自動(dòng)學(xué)習到一些可區分度好的特征,但是往往會(huì )擬合到一些非重要特征,導致模型會(huì )局部坍塌到一些不好的特征上面。本文通過(guò)一個(gè)簡(jiǎn)單的鳥(niǎo)類(lèi)分類(lèi)案例來(lái)總結了五個(gè)給模型加入先驗信息的方法。
模型加入先驗知識的必要性
端到端的深度神經(jīng)網(wǎng)絡(luò )是個(gè)黑盒子,雖然能夠自動(dòng)學(xué)習到一些可區分度好的特征,但是往往會(huì )擬合到一些非重要特征,導致模型會(huì )局部坍塌到一些不好的特征上面。常常一些人們想讓模型去學(xué)習的特征模型反而沒(méi)有學(xué)習到。
為了解決這個(gè)問(wèn)題,給模型加入人為設計的先驗信息會(huì )讓模型學(xué)習到一些關(guān)鍵的特征。下面就從幾個(gè)方面來(lái)談?wù)勅绾谓o模型加入先驗信息。
為了方便展示,我這邊用一個(gè)簡(jiǎn)單的分類(lèi)案例來(lái)展示如何把先驗知識加入到一個(gè)具體的task中。我們的task是在所有的鳥(niǎo)類(lèi)中識別出一種萌萌的鸚鵡,這中鸚鵡叫鸮(xiāo)鸚鵡,它長(cháng)成下面的樣子:
鸮(xiāo)鸚鵡
這種鳥(niǎo)有個(gè)特點(diǎn):
就是它可能出現在任何地方,但就是不可能在天上,因為它是世界上唯一一種不會(huì )飛的鸚鵡(不是唯一一種不會(huì )飛的鳥(niǎo))。
好,介紹完task的背景,咱們就可以分分鐘搭建一個(gè)端到端的分類(lèi)神經(jīng)網(wǎng)絡(luò ),可以選擇的網(wǎng)絡(luò )結構可以有很多,如resnet, mobilenet等等,loss往往是一個(gè)常用的分類(lèi)Loss,如交叉熵,高級一點(diǎn)的用個(gè)focal loss等等。確定好了最優(yōu)的數據(擾動(dòng)方式),網(wǎng)絡(luò )結構,優(yōu)化器,學(xué)習率等等這些之后,往往模型的精度也就達到了一個(gè)上限。
然后你測試模型發(fā)現,有些困難樣本始終分不開(kāi),或者是一些簡(jiǎn)單的樣本也容易分錯。這個(gè)時(shí)候如果你還想提升網(wǎng)絡(luò )的精度,可以通過(guò)給模型加入先驗的方式來(lái)進(jìn)一步提升模型的精度。
基于pretrain模型給模型加入先驗
給模型加入先驗,大家最容易想到的是把網(wǎng)絡(luò )的weight替換成一個(gè)在另外一個(gè)任務(wù)上pretrain好的模型weight。經(jīng)過(guò)的預訓練的模型(如ImageNet預訓練)往往已經(jīng)具備的識別到一些基本的圖片pattern的能力,如邊緣,紋理,顏色等等,而識別這些信息的能力是識別一副圖片的基礎。如下圖所示:
但這些先驗信息都是一些比較general的信息,我們是否可以加入一些更加high level的先驗信息呢。
基于輸入給模型加入先驗
假如你有這樣的一個(gè)先驗:
你覺(jué)得鸮鸚鵡的頭是一個(gè)區別其他它和鳥(niǎo)類(lèi)的重要部分,也就是說(shuō)相比于身體,它的頭部更能區分它和其他鳥(niǎo)類(lèi)。
這時(shí)怎么讓網(wǎng)絡(luò )更加關(guān)注鸮鸚鵡的頭部呢。這時(shí)你可以這樣做,把整個(gè)鸮鸚鵡和它的頭部作為一個(gè)網(wǎng)絡(luò )的兩路輸入,在網(wǎng)咯的后端再把兩路輸入的信息融合。以達到既關(guān)注局域,又關(guān)注整體的目的。一個(gè)簡(jiǎn)單的示意圖如下所示。
基于模型重現給模型加入先驗
接著(zhù)上面的設定來(lái),假如說(shuō)你覺(jué)得給模型兩路輸入太麻煩,而且增加的計算量讓你感覺(jué)很不爽。
這時(shí),你可以嘗試讓模型自己發(fā)現你設定的先驗知識。
假如說(shuō)你的模型可以自己輸出鳥(niǎo)類(lèi)頭部的位置,雖然這個(gè)鳥(niǎo)類(lèi)頭部的位置信息是你不需要的,但是輸出這樣的信息代表著(zhù)你的網(wǎng)絡(luò )能夠locate鳥(niǎo)類(lèi)頭部的位置,也就給鳥(niǎo)類(lèi)的頭部更加多的attention,也就相當于給把鳥(niǎo)類(lèi)頭部這個(gè)先驗信息給加上去了。
當然直接模仿detection那樣去回歸出位置來(lái)這個(gè)任務(wù)太heavy了,你可以通過(guò)一個(gè)生成網(wǎng)絡(luò )的支路來(lái)生成一個(gè)鳥(niǎo)類(lèi)頭部位置的Mask,一個(gè)簡(jiǎn)單的示意圖如下:
測試的時(shí)候不增加計算量
基于CAM圖激活限制給模型加入先驗
針對鸮鸚鵡的分類(lèi),我在上面的提到一個(gè)非常有意思的先驗信息:
那就是鸮鸚鵡是世界上唯一一種不會(huì )飛的鸚鵡。
這個(gè)信息從側面來(lái)說(shuō)就是,鸮鸚鵡所有地方都可能出現,就是不可能出現在天空中(當然也不可能出現在水中)。
也就是說(shuō)不但鸮鸚鵡本身是一個(gè)分類(lèi)的重點(diǎn),鸮鸚鵡出現的背景也是分類(lèi)的一個(gè)重要參考。假如說(shuō)背景是天空,那么就一定不是鸮鸚鵡,同樣的,假如說(shuō)背景是海水,那么也一定不是鸮鸚鵡,假如說(shuō)背景是北極,那么也一定不是鸮鸚鵡,等等。
也就是說(shuō),你不能通過(guò)背景來(lái)判斷一只未知的鳥(niǎo)是鸮鸚鵡,但是你能通過(guò)背景來(lái)判斷一只未知的鳥(niǎo)肯定不是鸮鸚鵡(是其他的鳥(niǎo)類(lèi))。
所以假如說(shuō)獲取了一張輸入圖片的激活圖(包含背景的),那么這張激活圖的鳥(niǎo)類(lèi)身體部分肯定包含了鸮鸚鵡和其他鳥(niǎo)類(lèi)的激活,但是鳥(niǎo)類(lèi)身體外的背景部分只可能包含其他鳥(niǎo)類(lèi)的激活。
所以具體的做法是基于激活圖,通過(guò)限制激活圖的激活區域,加入目標先驗。
CAM[1]激活圖是基于分類(lèi)網(wǎng)絡(luò )的倒數第二層卷積層的輸出的 feature_map 的線(xiàn)性加權,權重就是最后一層分類(lèi)層的權重,由于分類(lèi)層的權重編碼了類(lèi)別的信息,所以加權后的響應圖就有了基于不同類(lèi)別的區域相應。(具體的介紹可以看 https://zhuanlan.zhihu.com/p/51631163),具體的激活圖生成方式可以如下表示:
說(shuō)了這么多,下面就展示展示激活圖的樣子:
大家可以看到,上面一張是一只鸮鸚鵡的激活圖,下面是一只在天空飛翔的大雁的激活圖。
因為鸮鸚鵡的Label是0,其他鳥(niǎo)類(lèi)的Label是1,所以在激活圖上,只要是負值的激活區域都是鸮鸚鵡的激活,也就是Label為0的激活,只要是正值的激活都是其他鳥(niǎo)類(lèi)的激活,也就是Label為1的激活。
為了方便展示,我把負值的激活用冷色調來(lái)顯示,把正值的激活用暖色調來(lái)顯示,所以就是變成了上面兩幅激活圖的樣子。而右邊的數字是具體的激活矩陣(把激活矩陣進(jìn)行GAP就可以變成最終輸出的Logits)。
到這里不知道大家有沒(méi)有發(fā)現一個(gè)問(wèn)題,就是無(wú)論對于鸮鸚鵡還是大雁的圖片,它們的激活圖除了分布在鳥(niǎo)類(lèi)本身,也會(huì )有一部分分布在背景上。 對于大雁我們好理解,因為大雁是飛在天空中的,而鸮鸚鵡是不可能在天空中的,所以天空的正激活是非常合理的。但是對于鸮鸚鵡來(lái)說(shuō),其在鳥(niǎo)類(lèi)身體以外的負激活就不是太合理,因為,大雁或者是其他的鳥(niǎo)類(lèi),也可能在鸮鸚鵡的地面棲息環(huán)境中(但是鸮鸚鵡卻不可能在天空中)。
所以環(huán)境不能提供任何證據來(lái)證明這一次鳥(niǎo)類(lèi)是一只鸮鸚鵡,鸮鸚鵡的負激活只是在鳥(niǎo)類(lèi)的身體上是合理的。而其他鳥(niǎo)類(lèi)的正激活卻可以同時(shí)在鳥(niǎo)類(lèi)身體上又可能在鳥(niǎo)類(lèi)的背景上(如天空或者海洋)。
所以我們需要這樣建模這個(gè)問(wèn)題,就是在除鳥(niǎo)類(lèi)身體的背景上,不能出現鸮鸚鵡的激活,也就是說(shuō)不能出現負激活(Label為0的激活)。 所以下面的激活才是合理的:
從上面來(lái)看,在除鳥(niǎo)類(lèi)身體外的背景部分是不存在負激活的,雖然上面的背景部分有一些正的激活(其他鳥(niǎo)類(lèi)的激活),但是從右邊的激活矩陣來(lái)看,負激活的scale是占據絕對優(yōu)勢的,所以完全不會(huì )干擾對于鸮鸚鵡的判斷。
所以問(wèn)題來(lái)了,怎么從網(wǎng)絡(luò )設計方面來(lái)達到這個(gè)目的呢?
其實(shí)可以從Loss設計方面來(lái)達到這個(gè)效果。我們假設每一個(gè)鳥(niǎo)都有個(gè)對應的mask,mask內是鳥(niǎo)類(lèi)的身體部分,mask外是鳥(niǎo)類(lèi)的背景部分。那么我們需要做的就是抑制mask外的背景部分激活矩陣的負值,把那一部分負值給抑制到0即可。
鳥(niǎo)類(lèi)的激活矩陣和mask的關(guān)系如下圖(紅色的曲線(xiàn)代表鳥(niǎo)的邊界mask):
我們的Loss設計可以用下面的公式表示:
Loss_cam = -sum(where(bird_mask_outside<0))
具體的網(wǎng)絡(luò )的framework可以如下所示:
其中虛線(xiàn)部分只是訓練時(shí)候需要用到,inference的時(shí)候是不需要的,所以這種方法也是不會(huì )占用任何在inference前向時(shí)候的計算量。
基于輔助學(xué)習給模型加入先驗知識
到現在為止,咱們還只是把我們的鳥(niǎo)類(lèi)分類(lèi)的task當成一個(gè)二分類(lèi)來(lái)處理,即鸮鸚鵡是一類(lèi),其他的鳥(niǎo)類(lèi)是一類(lèi)。
但是我們知道,世界的鳥(niǎo)類(lèi)可不僅僅是兩類(lèi),除了鸮鸚鵡之外還有很多種類(lèi)的鳥(niǎo)類(lèi)。而不同鳥(niǎo)類(lèi)的特征或許有很大的差別,比如鴕鳥(niǎo)的特征就是脖子很長(cháng),大雁的特征就是翅膀很大。
假如只是把鸮鸚鵡當做一類(lèi),把其他的鳥(niǎo)類(lèi)當做一類(lèi)來(lái)學(xué)習的話(huà),那么模型很可能不能學(xué)到可以利用的區分非鸮鸚鵡的特征,或者是會(huì )坍塌到一些區分度不強的特征上面,從而沒(méi)有學(xué)到能夠很好的區分不同其他鳥(niǎo)類(lèi)的特征,而那些特征對去區別鸮鸚鵡和其他鳥(niǎo)類(lèi)或許是重要的。
所以我們有必要加入其他鳥(niǎo)類(lèi)存在不同類(lèi)別的先驗知識。而這里,我主要介紹基于輔助學(xué)習的方式去學(xué)習類(lèi)似的先驗知識。首先我要解釋一下什么是輔助學(xué)習,以及輔助學(xué)習和多任務(wù)學(xué)習的區別:
上圖的左側是多任務(wù)學(xué)習的例子,右側是輔助學(xué)習的例子。左側是個(gè)典型的face attribute的task,意思是輸入一張人臉,通過(guò)多個(gè)branch來(lái)輸出這一張人臉的年齡,性別,發(fā)型等等信息,各個(gè)branch的任務(wù)是獨立的,同時(shí)又共享同一個(gè)backbone。右邊是一個(gè)典型的輔助學(xué)習的task,意思是出入一張人臉,判斷這一張人臉的性別,同時(shí)另外開(kāi)一個(gè)(或幾個(gè))branch,通過(guò)這個(gè)branch來(lái)讓網(wǎng)絡(luò )學(xué)一些輔助信息,比如發(fā)型,皮膚等等,來(lái)幫助網(wǎng)絡(luò )主任務(wù)(分男女)的判別。
好,回到我們的鸮鸚鵡分類(lèi)的task,我們可能首先會(huì )想到下面的Pipeline:
這樣雖然可以把不同類(lèi)別的鳥(niǎo)類(lèi)的特征都學(xué)到,但是卻削弱了網(wǎng)絡(luò )對于鸮鸚鵡和其他鳥(niǎo)類(lèi)特征的分別。
經(jīng)過(guò)實(shí)驗發(fā)現,這種網(wǎng)絡(luò )架構不能很好的增加主任務(wù)的分類(lèi)精度。為了充分的學(xué)到鸮鸚鵡和其他鳥(niǎo)類(lèi)特征的分別,同時(shí)又能帶入不同種類(lèi)鳥(niǎo)類(lèi)類(lèi)別的先驗,我們引入輔助任務(wù):
在上面的Pipeline中,輔助任務(wù)相比如主任務(wù),把其他鳥(niǎo)類(lèi)做更加細致的分類(lèi)。這樣網(wǎng)絡(luò )就學(xué)到了區分不同其他鳥(niǎo)類(lèi)的能力。
但是從實(shí)驗效果來(lái)看這個(gè)Pipeline的精度并不高。經(jīng)過(guò)分析原因,發(fā)現在主任務(wù)和輔助任務(wù)里面都有鸮鸚鵡這一類(lèi),這樣當回傳梯度的時(shí)候,相當于把區分鸮鸚鵡和其他鳥(niǎo)類(lèi)的特征回傳了兩次梯度,而回傳兩次梯度明顯是沒(méi)用的,而且會(huì )干擾輔助任務(wù)學(xué)習不同其他鳥(niǎo)類(lèi)的特征。
所以我們可以把輔助任務(wù)的鸮鸚鵡類(lèi)去除,于是便形成了下面的pipeline:
經(jīng)過(guò)實(shí)驗發(fā)現,這種pipeline是有利于主任務(wù)精度提升的,網(wǎng)絡(luò )對于特征明顯的其他鳥(niǎo)類(lèi)的分類(lèi)能力得到了一定程度的提升,同時(shí)對于困難類(lèi)別的分類(lèi)能力也有一定程度的提升。
當然,輔助任務(wù)的branch可以不只是一類(lèi),你可以通過(guò)多個(gè)類(lèi)別來(lái)定義你的輔助任務(wù)的branch:
這時(shí)候你會(huì )想,上面的pipeline好是好,但是我沒(méi)有那么多的label啊。是的,上面的pipeline除了主任務(wù)的label標注,它還同時(shí)需要很多的輔助任務(wù)的label標注,而標注label是深度學(xué)習任務(wù)里面最讓人頭疼的問(wèn)題(之一)。
別怕,我下面介紹一個(gè)work,它基于meta-learning的方法,讓你不再為給輔助任務(wù)標注label而煩惱,它的framework如下:
這個(gè)framework采用基于maxl[2]的方案(https://github.com/lorenmt/maxl),輔助任務(wù)的數據和label不是由人為手工劃分,而是由一個(gè)label generator來(lái)產(chǎn)生,label generator的優(yōu)化目標是讓主網(wǎng)絡(luò )在主任務(wù)的task上的loss降低,主網(wǎng)絡(luò )的目標是在主任務(wù)和輔助任務(wù)上的loss同時(shí)降低。
但是這個(gè)framework有個(gè)缺點(diǎn),就是訓練時(shí)間會(huì )上升一個(gè)數量級,同時(shí)label generator會(huì )比較難優(yōu)化。感興趣的同學(xué)可以自己嘗試。但是不得不說(shuō),這篇文章有兩個(gè)結論倒是很有意思:
假設 primary 和 auxiliary task 是在同一個(gè) domain,那么 primary task 的 performance 會(huì )提高當且僅當 auxiliary task 的 complexity 高于 primary task。
假設 primary 和 auxiliary task 是在同一個(gè) domain,那么 primary task 的最終 performance 只依賴(lài)于 complexity 最高的 auxiliary task。
結語(yǔ)
先總結一下所有可以有效的加入先驗信息的框架:
你可以通過(guò)上述框架的選擇來(lái)加入自己的先驗信息。
給神經(jīng)網(wǎng)絡(luò )的黑盒子里面加入一些人為設定的先驗知識,這樣往往能給你的task帶來(lái)一定程度的提升,不過(guò)具體的task需要加入什么樣的先驗知識,需要如何加入先驗知識還需要自己探索。
來(lái)自我自己的博客:https://zhengtq.github.io/2020/07/30/pri-knowledge-1/
參考
^CAM https://arxiv.org/abs/1512.04150
^maxl https://arxiv.org/abs/1901.08933
本文僅做學(xué)術(shù)分享,如有侵權,請聯(lián)系刪文。
*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。