<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>
"); //-->

博客專(zhuān)欄

EEPW首頁(yè) > 博客 > 知識蒸餾綜述:代碼整理(3)

知識蒸餾綜述:代碼整理(3)

發(fā)布人:計算機視覺(jué)工坊 時(shí)間:2022-01-16 來(lái)源:工程師 發(fā)布文章

11. FSP: Flow of Solution Procedure

全稱(chēng):A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning

鏈接:https://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

發(fā)表:CVPR17

FSP認為教學(xué)生網(wǎng)絡(luò )不同層輸出的feature之間的關(guān)系比教學(xué)生網(wǎng)絡(luò )結果好

9.jpg

定義了FSP矩陣來(lái)定義網(wǎng)絡(luò )內部特征層之間的關(guān)系,是一個(gè)Gram矩陣反映老師教學(xué)生的過(guò)程。

10.jpg

使用的是L2 Loss進(jìn)行約束FSP矩陣。實(shí)現如下:

class FSP(nn.Module):
    """A Gift from Knowledge Distillation:
    Fast Optimization, Network Minimization and Transfer Learning"""
    def __init__(self, s_shapes, t_shapes):
        super(FSP, self).__init__()
        assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
        s_c = [s[1] for s in s_shapes]
        t_c = [t[1] for t in t_shapes]
        if np.any(np.asarray(s_c) != np.asarray(t_c)):
            raise ValueError('num of channels not equal (error in FSP)')
    def forward(self, g_s, g_t):
        s_fsp = self.compute_fsp(g_s)
        t_fsp = self.compute_fsp(g_t)
        loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
        return loss_group
    @staticmethod
    def compute_loss(s, t):
        return (s - t).pow(2).mean()
    @staticmethod
    def compute_fsp(g):
        fsp_list = []
        for i in range(len(g) - 1):
            bot, top = g[i], g[i + 1]
            b_H, t_H = bot.shape[2], top.shape[2]
            if b_H > t_H:
                bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
            elif b_H < t_H:
                top = F.adaptive_avg_pool2d(top, (b_H, b_H))
            else:
                pass
            bot = bot.unsqueeze(1)
            top = top.unsqueeze(2)
            bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
            top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)
            fsp = (bot * top).mean(-1)
            fsp_list.append(fsp)
        return fsp_list

12. NST: Neuron Selectivity Transfer

全稱(chēng):Like what you like: knowledge distill via neuron selectivity transfer

鏈接:https://arxiv.org/pdf/1707.01219.pdf

發(fā)表:CoRR17

使用新的損失函數最小化教師網(wǎng)絡(luò )與學(xué)生網(wǎng)絡(luò )之間的Maximum Mean Discrepancy(MMD), 文中選擇的是對其教師網(wǎng)絡(luò )與學(xué)生網(wǎng)絡(luò )之間神經(jīng)元選擇樣式的分布。

11.jpg

使用核技巧(對應下面poly kernel)并進(jìn)一步展開(kāi)以后可得:

實(shí)際上提供了Linear Kernel、Poly Kernel、Gaussian Kernel三種,這里實(shí)現只給了Poly這種,這是因為Poly這種方法可以與KD進(jìn)行互補,這樣整體效果會(huì )非常好。實(shí)現如下:

class NSTLoss(nn.Module):
    """like what you like: knowledge distill via neuron selectivity transfer"""
    def __init__(self):
        super(NSTLoss, self).__init__()
        pass
    def forward(self, g_s, g_t):
        return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
    def nst_loss(self, f_s, f_t):
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass
        f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
        f_s = F.normalize(f_s, dim=2)
        f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
        f_t = F.normalize(f_t, dim=2)
        # set full_loss as False to avoid unnecessary computation
        full_loss = True
        if full_loss:
            return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
                    - 2 * self.poly_kernel(f_s, f_t).mean())
        else:
            return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()
    def poly_kernel(self, a, b):
        a = a.unsqueeze(1)
        b = b.unsqueeze(2)
        res = (a * b).sum(-1).pow(2)
        return res

13. CRD: Contrastive Representation Distillation

全稱(chēng):Contrastive Representation Distillation

鏈接:https://arxiv.org/abs/1910.10699v2

發(fā)表:ICLR20

將對比學(xué)習引入知識蒸餾中,其目標修正為:學(xué)習一個(gè)表征,讓正樣本對的教師網(wǎng)絡(luò )與學(xué)生網(wǎng)絡(luò )盡可能接近,負樣本對教師網(wǎng)絡(luò )與學(xué)生網(wǎng)絡(luò )盡可能遠離。構建的對比學(xué)習問(wèn)題表示如下:

整體的蒸餾Loss表示如下:

實(shí)現如下:https://github.com/HobbitLong/RepDistiller

class ContrastLoss(nn.Module):
    """
    contrastive loss, corresponding to Eq (18)
    """
    def __init__(self, n_data):
        super(ContrastLoss, self).__init__()
        self.n_data = n_data
    def forward(self, x):
        bsz = x.shape[0]
        m = x.size(1) - 1
        # noise distribution
        Pn = 1 / float(self.n_data)
        # loss for positive pair
        P_pos = x.select(1, 0)
        log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
        # loss for K negative pair
        P_neg = x.narrow(1, 1, m)
        log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
        loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
        return loss
class CRDLoss(nn.Module):
    """CRD Loss function
    includes two symmetric parts:
    (a) using teacher as anchor, choose positive and negatives over the student side
    (b) using student as anchor, choose positive and negatives over the teacher side
    Args:
        opt.s_dim: the dimension of student's feature
        opt.t_dim: the dimension of teacher's feature
        opt.feat_dim: the dimension of the projection space
        opt.nce_k: number of negatives paired with each positive
        opt.nce_t: the temperature
        opt.nce_m: the momentum for updating the memory buffer
        opt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim
    """
    def __init__(self, opt):
        super(CRDLoss, self).__init__()
        self.embed_s = Embed(opt.s_dim, opt.feat_dim)
        self.embed_t = Embed(opt.t_dim, opt.feat_dim)
        self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
        self.criterion_t = ContrastLoss(opt.n_data)
        self.criterion_s = ContrastLoss(opt.n_data)
    def forward(self, f_s, f_t, idx, contrast_idx=None):
        """
        Args:
            f_s: the feature of student network, size [batch_size, s_dim]
            f_t: the feature of teacher network, size [batch_size, t_dim]
            idx: the indices of these positive samples in the dataset, size [batch_size]
            contrast_idx: the indices of negative samples, size [batch_size, nce_k]
        Returns:
            The contrastive loss
        """
        f_s = self.embed_s(f_s)
        f_t = self.embed_t(f_t)
        out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
        s_loss = self.criterion_s(out_s)
        t_loss = self.criterion_t(out_t)
        loss = s_loss + t_loss
        return loss

14. Overhaul

全稱(chēng):A Comprehensive Overhaul of Feature Distillation鏈接:http://openaccess.thecvf.com/content_ICCV_2019/papers/發(fā)表:CVPR19

teacher transform中提出使用margin RELU激活函數。

12.jpg

student transform中提出使用1x1卷積。

distillation feature postion選擇Pre-ReLU。

13.jpg

distance function部分提出了Partial L2 損失函數。

14.jpg

部分實(shí)現如下:

class OFD(nn.Module):
  '''
  A Comprehensive Overhaul of Feature Distillation
  http://openaccess.thecvf.com/content_ICCV_2019/papers/
  Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
  '''
  def __init__(self, in_channels, out_channels):
    super(OFD, self).__init__()
    self.connector = nn.Sequential(*[
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channels)
      ])
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
          nn.init.constant_(m.bias, 0)
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
  def forward(self, fm_s, fm_t):
    margin = self.get_margin(fm_t)
    fm_t = torch.max(fm_t, margin)
    fm_s = self.connector(fm_s)
    mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
    loss = torch.mean((fm_s - fm_t)**2 * mask)
    return loss
  def get_margin(self, fm, eps=1e-6):
    mask = (fm < 0.0).float()
    masked_fm = fm * mask
    margin = masked_fm.sum(dim=(0,2,3), keepdim=True) / (mask.sum(dim=(0,2,3), keepdim=True)+eps)
    return margin

參考文獻

https://blog.csdn.net/weixin_44579633/article/details/119350631

https://blog.csdn.net/winycg/article/details/105297089

https://blog.csdn.net/weixin_46239293/article/details/120289163

https://blog.csdn.net/DD_PP_JJ/article/details/121578722

https://blog.csdn.net/DD_PP_JJ/article/details/121714957

https://zhuanlan.zhihu.com/p/344881975

https://blog.csdn.net/weixin_44633882/article/details/108927033

https://blog.csdn.net/weixin_46239293/article/details/120266111

https://blog.csdn.net/weixin_43402775/article/details/109011296

https://blog.csdn.net/m0_37665984/article/details/103288582

https://blog.csdn.net/m0_37665984/article/details/103269740

本文僅做學(xué)術(shù)分享,如有侵權,請聯(lián)系刪文。

*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。



關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專(zhuān)區

關(guān)閉
国产精品自在自线亚洲|国产精品无圣光一区二区|国产日产欧洲无码视频|久久久一本精品99久久K精品66|欧美人与动牲交片免费播放
<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>