<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>
"); //-->

博客專(zhuān)欄

EEPW首頁(yè) > 博客 > 知識蒸餾綜述:代碼整理(1)

知識蒸餾綜述:代碼整理(1)

發(fā)布人:計算機視覺(jué)工坊 時(shí)間:2022-01-16 來(lái)源:工程師 發(fā)布文章

作者 | PPRP 

來(lái)源 | GiantPandaCV

編輯 | 極市平臺

導讀

本文收集自RepDistiller中的蒸餾方法,盡可能簡(jiǎn)單解釋蒸餾用到的策略,并提供了實(shí)現源碼。

1. KD: Knowledge Distillation

全稱(chēng):Distilling the Knowledge in a Neural Network

鏈接:https://arxiv.org/pdf/1503.02531.pd3f

發(fā)表:NIPS14

最經(jīng)典的,也是明確提出知識蒸餾概念的工作,通過(guò)使用帶溫度的softmax函數來(lái)軟化教師網(wǎng)絡(luò )的邏輯層輸出作為學(xué)生網(wǎng)絡(luò )的監督信息,

使用KL divergence來(lái)衡量學(xué)生網(wǎng)絡(luò )與教師網(wǎng)絡(luò )的差異,具體流程如下圖所示(來(lái)自Knowledge Distillation A Survey)

1.jpg

對學(xué)生網(wǎng)絡(luò )來(lái)說(shuō),一部分監督信息來(lái)自hard label標簽,另一部分來(lái)自教師網(wǎng)絡(luò )提供的soft label。代碼實(shí)現:

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T
    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s/self.T, dim=1)
        p_t = F.softmax(y_t/self.T, dim=1)
        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
        return loss

核心就是一個(gè)kl_div函數,用于計算學(xué)生網(wǎng)絡(luò )和教師網(wǎng)絡(luò )的分布差異。

2. FitNet: Hints for thin deep nets

全稱(chēng):Fitnets: hints for thin deep nets

鏈接:https://arxiv.org/pdf/1412.6550.pdf

發(fā)表:ICLR 15 Poster

對中間層進(jìn)行蒸餾的開(kāi)山之作,通過(guò)將學(xué)生網(wǎng)絡(luò )的feature map擴展到與教師網(wǎng)絡(luò )的feature map相同尺寸以后,使用均方誤差MSE Loss來(lái)衡量?jì)烧卟町悺?/p>

2.jpg

實(shí)現如下:

class HintLoss(nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015"""
    def __init__(self):
        super(HintLoss, self).__init__()
        self.crit = nn.MSELoss()
    def forward(self, f_s, f_t):
        loss = self.crit(f_s, f_t)
        return loss

實(shí)現核心就是MSELoss。

3. AT: Attention Transfer

全稱(chēng):Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer

鏈接:https://arxiv.org/pdf/1612.03928.pdf

發(fā)表:ICLR16

為了提升學(xué)生模型性能提出使用注意力作為知識載體進(jìn)行遷移,文中提到了兩種注意力,一種是activation-based attention transfer,另一種是gradient-based attention transfer。實(shí)驗發(fā)現第一種方法既簡(jiǎn)單效果又好。

3.jpg

實(shí)現如下:

class Attention(nn.Module):
    """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
    via Attention Transfer
    code: https://github.com/szagoruyko/attention-transfer"""
    def __init__(self, p=2):
        super(Attention, self).__init__()
        self.p = p
    def forward(self, g_s, g_t):
        return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
    def at_loss(self, f_s, f_t):
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()
    def at(self, f):
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

首先使用avgpool將尺寸調整一致,然后使用MSE Loss來(lái)衡量?jì)烧卟罹唷?/p>

4. SP: Similarity-Preserving

全稱(chēng):Similarity-Preserving Knowledge Distillation

鏈接:https://arxiv.org/pdf/1907.09682.pdf

發(fā)表:ICCV19SP

歸屬于基于關(guān)系的知識蒸餾方法。文章思想是提出相似性保留的知識,使得教師網(wǎng)絡(luò )和學(xué)生網(wǎng)絡(luò )會(huì )對相同的樣本產(chǎn)生相似的激活??梢詮南聢D看出處理流程,教師網(wǎng)絡(luò )和學(xué)生網(wǎng)絡(luò )對應feature map通過(guò)計算內積,得到bsxbs的相似度矩陣,然后使用均方誤差來(lái)衡量?jì)蓚€(gè)相似度矩陣。

4.jpg

最終Loss為:

G代表的就是bsxbs的矩陣。實(shí)現如下:

class Similarity(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(Similarity, self).__init__()
    def forward(self, g_s, g_t):
        return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)
        G_s = torch.mm(f_s, torch.t(f_s))
        # G_s = G_s / G_s.norm(2)
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.mm(f_t, torch.t(f_t))
        # G_t = G_t / G_t.norm(2)
        G_t = torch.nn.functional.normalize(G_t)
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss

5. CC: Correlation Congruence

全稱(chēng):Correlation Congruence for Knowledge Distillation

鏈接:https://arxiv.org/pdf/1904.01802.pdf

發(fā)表:ICCV19

CC也歸屬于基于關(guān)系的知識蒸餾方法。不應該僅僅引導教師網(wǎng)絡(luò )和學(xué)生網(wǎng)絡(luò )單個(gè)樣本向量之間的差異,還應該學(xué)習兩個(gè)樣本之間的相關(guān)性,而這個(gè)相關(guān)性使用的是Correlation Congruence 教師網(wǎng)絡(luò )雨學(xué)生網(wǎng)絡(luò )相關(guān)性之間的歐氏距離。

整體Loss如下:

實(shí)現如下:

class Correlation(nn.Module):
    """Similarity-preserving loss. My origianl own reimplementation 
    based on the paper before emailing the original authors."""
    def __init__(self):
        super(Correlation, self).__init__()
    def forward(self, f_s, f_t):
        return self.similarity_loss(f_s, f_t)
    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)
        G_s = torch.mm(f_s, torch.t(f_s))
        G_s = G_s / G_s.norm(2)
        G_t = torch.mm(f_t, torch.t(f_t))
        G_t = G_t / G_t.norm(2)
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss


*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。



關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專(zhuān)區

關(guān)閉
国产精品自在自线亚洲|国产精品无圣光一区二区|国产日产欧洲无码视频|久久久一本精品99久久K精品66|欧美人与动牲交片免费播放
<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>