• 工作总结
  • 工作计划
  • 读后感
  • 发言稿
  • 心得体会
  • 思想汇报
  • 述职报告
  • 作文大全
  • 教学设计
  • 不忘初心
  • 打黑除恶
  • 党课下载
  • 主题教育
  • 谈话记录
  • 申请书
  • 对照材料
  • 自查报告
  • 整改报告
  • 脱贫攻坚
  • 党建材料
  • 观后感
  • 评语
  • 口号
  • 规章制度
  • 事迹材料
  • 策划方案
  • 工作汇报
  • 讲话稿
  • 公文范文
  • 致辞稿
  • 调查报告
  • 学习强国
  • 疫情防控
  • 振兴乡镇
  • 工作要点
  • 治国理政
  • 十九届五中全会
  • 教育整顿
  • 党史学习
  • 建党100周
  • 当前位置: 蜗牛文摘网 > 实用文档 > 公文范文 > 基于拓扑一致性对抗互学习的知识蒸馏

    基于拓扑一致性对抗互学习的知识蒸馏

    时间:2023-04-07 19:10:05 来源:千叶帆 本文已影响

    赖 轩 曲延云 谢 源 裴玉龙

    图像分类是计算机视觉领域的一个经典任务,有广泛的应用需求,例如机场和车站闸口的人脸识别、智能交通中的车辆检测等,图像分类的应用在一定程度上减轻了工作人员的负担,提高了工作效率.图像分类的解决方法也为目标检测、图像分割、场景理解等视觉任务奠定了基础.近年来,由于GPU等硬件和深度学习技术的发展,深度神经网络(Deep neural network,DNN)[1]在各个领域取得了长足的进展,比如,在ImageNet 大规模视觉识别挑战赛ILSVRC 比赛库上的图像分类,基于深度学习的图像分类方法已经取得了与人类几乎相同甚至超越人类的识别性能.然而,这些用于图像分类的深度学习模型往往需要较高的存储空间和计算资源,使其难以有效的应用在手机等云端设备上.如何将模型压缩到可以适应云端设备要求,并使得性能达到应用需求,是当前计算机视觉研究领域一个活跃的研究主题.轻量级模型设计是当前主要的解决途径,到目前为止,模型压缩方法大致分为基于模型设计的方法[2]、基于量化的方法[3]、基于剪枝的方法[4]、基于权重共享的方法[5]、基于张量分解的方法[6]和基于知识蒸馏的方法[7]六类.

    本文主要关注知识蒸馏方法.知识蒸馏最初被用于模型压缩[8].不同于剪枝、张量分解等模型压缩方法,知识蒸馏(Knowledge distillation,KD)的方法,先固定一个分类性能好的大模型作为教师网络,然后训练一个轻量级模型作为学生网络学习教师网络蒸馏出来的知识,在不增加参数量的情况下提升小模型的性能.基于知识蒸馏的模型压缩方法,将教师网络输出的预测分布视为软标签,用于指导学生网络的预测分布,软标签反映了不同类别信息间的隐关联,为新网络的训练提供了更丰富的信息,通过最小化两个网络预测的Kullback-Leibler (KL)散度差异,来实现知识迁移.Romero 等[9]认为让小模型直接在输出端模拟大模型时会造成模型训练困难,从而尝试让小模型去学习大模型预测的中间部分,该方法提取出教师网络中间层的特征图,通过一个卷积转化特征图大小来指导学生网络对应层的特征图.Yim 等[10]使用FSP (Flow of solution procedure)矩阵计算卷积层之间的关系,让小模型去拟合大模型层与层之间的关系.Peng 等[11]和Park等[12]同时输入多个数据,在原知识蒸馏模型的基础上通过学习样本之间的相关性进一步提升学生网络性能.

    考虑到知识蒸馏的本质是知识的迁移,即将知识从一个模型迁移到另一个模型,Zhang 等[13]提出了深度互学习(Deep mutual learning,DML)方法,设计了一种蒸馏相关的相互学习策略,在训练的过程中,学生网络和教师网络可以相互学习,知识不仅从教师网络迁移到学生网络,也从学生网络迁移到教师网络.

    协同学习也是常见的迁移学习方法之一,多用于半监督学习.在协同学习中,不同的模型或者在不同分组的数据集上学习,或者通过不同视角的特征进行学习,例如识别同一组物体类别,但其中一种模型输入RGB 图像,而另一种模式输入深度图像.协同属性学习[14]就是通过属性矩阵的融合进行属性的挖掘,从而指导两个模型的分类.而深度互学习方法中所有模型在同一数据集上训练完成相同的任务.

    尽管现有的知识蒸馏的方法已经取得了长足的进展,但仍存在以下问题:1)现有的深度互学习方法仅关注教师网络和学生网络输出的类分布之间的差异,没有利用对抗训练来提升模型的判别能力;2)现有的深度互学习仅关注结果监督,忽视了过程监督.特别是没有考虑高维特征空间中拓扑关系的一致性.针对问题1),本文设计对抗互学习框架,生成器使用深度互学习框架,通过对抗训练,提高教师和学生网络的判别性;针对问题2),本文在教师网络和学生网络互学习模型中,增加过程监督,即对中间生成的特征图,设计了拓扑一致性度量方法,通过结果和过程同时控制,提高模型的判别能力.

    总之,本文提出了一种基于拓扑一致性的对抗互学习知识蒸馏方法(Topology-guided adversarial deep mutual learning,TADML),在生成对抗[15]网络架构下,设计知识蒸馏方法,教师网络和学生网络互相指导更新,不仅让教师网络的知识迁移到学生网络,也让学生网络的知识迁移到教师网络.本文的模型框架可以推广到多个网络的对抗互学习.TADML 由深度互学习网络构成的生成器和一个判别器组成.生成器的每个子网络都是分类网络.类似于知识蒸馏,任一子网络都可以看作是其余网络的教师网络,对其他网络训练更新,进行知识迁移.为方便计算,本文将所有子网络组视为一个大网络同时优化更新.每个被看作生成器的子网络,生成输入图像的特征.判别器更新时判断生成器的输出特征属于哪一个类别、来源于哪一个子网络,而生成器更新时尽量混淆判别器使其无法准确判断特征来源于哪一个生成器,进而拟合网络中隐含的信息.

    本节介绍如何通过对抗训练框架实现网络间的知识转移.首先概述TAMDL 网络结构,然后讨论所提的损失函数的构成,最后描述模型的训练过程.

    1.1 网络结构

    如图1 所示,给出了基于拓扑一致性的对抗互学习知识蒸馏(TADML)框架,该框架由生成器和判别器两部分组成:

    图1 本文方法框架Fig.1 The framework of the proposed method

    1)生成网络.该部分由两个或多个分类子网络组成,生成器中的分类网络执行相同的分类任务,可以选取不同的模型结构,彼此间无需共享参数.不失一般性,现有的深度分类模型都可作为生成器中的分类网络,例如ResNet和Wide-ResNet[16].由于所有的生成网络使用相同的数据集执行相同的分类任务,对于输入图像x,定义第i个网络的激活函数层Softmax 的类别分布概率值为fi(x,ωi),其中ωi是相应的分类模型网络参数.

    2)判别器.在TADML 架构中,将两个或多个分类网络看作生成器,而判别器只有一个.由于常见的判别器容易陷入过早收敛或难以训练两种极端情况,本文设计了一个能较好平衡判别器稳定性和辨别能力的判别器,相对于常见的多层感知器[17]更加稳定.如图2 所示,提出的判别器由三个全连接的层(128fc-256fc-128fc)组成,且判别器的第一层和与最后一层没有批标准化处理(Batch normalization,BN)与LeakyRelu 激活函数操作.与常见的判别器不同,本文所设计判别器的输出不是简单的真假(自然图像/伪造图像),而是判断输入来源于哪个网络且隶属于哪个类别.受到条件GAN (Conditional-GAN,C-GAN[18])在图像恢复领域中的启发,本文根据C-GAN 的对判别器的输入进行改造,在后续的消融实验部分对判别器的输入进行不同程度的约束.

    图2 判别器结构图Fig.2 The structure of discriminator

    1.2 损失函数

    所提方法考虑四种损失:标签监督损失LS,对抗损失Ladv,分布一致性损失Lb,拓扑一致性损失LT.标签监督损失LS是广泛用于图像分类中带注释数据分类任务的监督损失,这对提取知识起着至关重要的作用.分布一致性损失Lb是直接匹配所有分类子网络的输出的显式损失,而对抗性损失Ladv表示隐式损失,该损失将所有分类子网络的逻辑分布之间经过分类器判断的差异最小化.换句话说,对抗性损失提供了一些通过传统分布相似性度量而丢失的信息.拓扑一致性损失LT是样本实例间隐藏的高阶结构信息.

    在训练对抗生成抗网络时,为指导网络的学习,尽可能迁移分类网络之间的知识,总的损失函数定义为:

    式中,α和β分别表示四项损失所占的权重,在本文中分别设定为α=0.6,β=0.4.下面依次对这四个部分进行详细说明.

    1)标签监督损失.该损失为常用的监督分类交叉熵损失.对于给定的图像标签对 (x;l),优化模型参数使得预测类别与标签的交叉熵降至最低,以正确预测每个训练实例的真实标签:

    2)分布一致性损失.考虑到互学习模型中的知识迁移,与之前的蒸馏网络不同,本文没有固定一个预训练网络作为教师网络进行单向指导,所提方法中任意一个网络都接受其余网络的监督指导,最小化分类网络输出特征的类别分布差异,输出越相似则表示迁移效果越好.受到Knowledge squeezed adversarial network compression (KSANC)[19]的启发,本文考虑从结果导向和过程导向两个方面同时进行知识迁移.过程导向约束仅针对最后一个全连接层的输出.最终输出的逻辑分布作为结果导向,即各个网络之间只保留网络输出之间的实例级对齐.

    考虑到网络输出的类别分布的差异性度量,本文使用Jensen-Shannon (JS) 散度衡量输出分布的相似性:

    式中,fi表示由第i个网络预测的逻辑分布.KL 散度定义为:

    3)对抗性损失.在TADML 的模型中,采用对抗学习(GAN)的方法,将从每个网络中提取的知识转移到另一个网络中.在知识蒸馏中,学生网络通过模仿教师网络从而学习教师网络中的知识,直到最后学生网络的输出与教师网络相近则视为指导完成.TADML 网络整体框架分为生成器和判别器两个部分,多个分类网络构成生成器.对于一个输入的样本,经过生成网络得到多个类别概率,每一个分类网络都对应输出一个概率分布(也可以视为图像经过这个网络表征的特征编码).这些概率分布作为判别器的输入,判别器判断类别概率分布是由哪个分类网络产生.生成器与判别器交替迭代更新,固定判别器更新生成器时,尽量生成相似的特征编码,使得判别器无法分辨特征编码来自于生成器的哪一个子网络;而在固定生成器更新判别器时,尽量训练判别网络,使其可以轻易的分辨输入来源于生成器中哪个分类子网络.二者交替迭代直到动态平衡,则视为收敛.

    到目前为止,基于GAN 的方法已在很多领域取得了显著的效果,在TADML 方法中,每个分类子网络都被视为GAN 中的生成器,并提供逻辑分布作为另一个分类子网络的真实标签.相较于原始的GAN 网络只输出一个布尔值,即真或假,本文判别器判断其输入来源于哪个分类子网络:

    式中,gn(i) 是第i个元素为1,其余元素为0 的向量,表示生成器n个分类子网络的第i个分类网络的输出作为判别器的输入,Do(fj(x)) 表示判别器输出的n位向量,代表判别器预测输入来源于哪个网络,n为分类子网络数.

    此外,如果判别器仅仅区分输入来自生成器的哪个子网络,则缺少类别信息可能导致错误的关联.为此,引入辅助分类来预测输入所属类别.即本文所提的判别器不仅需要判断输入来源于哪个分类子网络,还需要判断输入属于哪一个类别标签,损失函数表示为:

    式中,gN(C) 表示真实的类别分布,DC(fi(x)) 表示判别器输出的类别分布,N是类别总数.

    鉴于GAN 网络的判别器容易在极少的迭代次数后收敛和过度拟合.本文设计了惩罚项作为对模型的正则化处理,定义如下:

    式中,µ权重参数设为0.7,ωD是判别器的网络参数,g(0) 表示元素全为0 的向量,负号表示该项仅在式(5)最大化步骤中更新,前一项迫使判别器的权重缓慢增长,后一项则是对抗性样本正则化.

    本文设计的对抗损失为:

    4)拓扑一致性损失.在过程导向的监督学习中,考虑样本组间的拓扑结构相似性,本文选择计算样本在高维空间嵌入特征的距离及其角度的一致性.对于输入的样本组{x1,x2,x3,···,xn},经过第i个分类网络的最后一层全连接输出的特征映射看作高维嵌入特征{hi(x1),hi(x2),hi(x3),···,hi(xn)},则两个网络间基于特征距离的拓扑一致性损失可以表示为:

    1.3 训练步骤

    在训练过程中,本文交替更新判别器和生成器.在更新生成器参数时,固定判别器不动,将生成器的所有分类网络视为一个整体,通过最小化式(1)同时更新生成器中所有的分类网络参数.在更新判别器参数时,所有的生成网络都是固定的,以提供稳定的输入,通过最大化式(8)更新.交替迭代更新,每输入一组数据交替一次,直至迭代次数满足终止条件.在测试阶段,本文仅考虑作为生成器的分类子网络,并将每个分类子网络视为一个完整的分类网络来对输入图像分别进行分类.

    2.1 数据集

    本文在3 个公开的分类数据集CIFAR10、CIFAR100和Tiny-ImageNet 上进行训练和测试,进一步在行人重识别数据集Market1501 上验证所提方法的有效性.其中,CIFAR100和CIFAR10 数据集都包含60 000 张32 × 32 像素大小的图像,分别由100 个类和10 个类组成,50 000 张用于训练,10 000 张用于验证.Tiny-ImageNet 源于ImageNet dataset (1 000 个类别),从中抽取200 个类别,每个类别有500 个训练图像,50 个验证图像和50个测试图像,且所有图片都被裁剪放缩为64 × 64像素大小.Market1501 是常用的行人重识别数据集,包含12 936 张训练图像(751 个不同的行人)和19 732 张测试图像(750 个不同的行人),图像大小为64 × 128 像素.

    2.2 实现细节

    本文算法使用Torch0.4 在NVIDIA GeForce GTX 1 080 GPU 上实现.对于所有分类数据集,均使用随机梯度下降法进行优化,将权重衰减设置为0.0001,动量设置为0.9.对于CIFARs 的实验,批量大小设置为64,生成网络和判别器的初始学习率分别设置为0.1和0.001,每隔80 次迭代两者都缩小为0.1 倍,总共训练了200 次迭代.对于Tiny-ImageNet 的实验,批量大小设置为128,总迭代次数为330 代,生成网络初始学习率设为0.1,每隔60 代学习率乘以0.2,判别网络初始学习率为0.001,每隔120 代乘以0.1.对于Market1501 的实验,采用与DML 相同的实验设置:使用Adam 优化器,学习率为0.0002,β1设为0.5,β2设为0.999,批量大小设置为16,图像输入大小为64 × 160 像素,共迭代100 000 次.尽管使用预训练模型能得到更高的精度,在实验中,所有网络都采用随机初始化的.由于训练前期网络变化较大,仅在总迭代次数过半的时候才加入拓扑一致性损失更新网络,且用上一次迭代时分类精度高的网络指导精度低的网络,而不是互相指导学习.

    2.3 消融实验

    关于损失函数的选择,本文尝试不同损失组合的效果.表1 展示了在CIFAR10和CIFAR100 上,将两个ResNet32 设置为生成器中的教师网络和学生网络,遵循相同的实验方案进行训练,并选择这两个子网络的平均精度作为最终结果.其中,LS表示标签损失,Lp(p=1,2) 表示两个网络输出分布之间的l1,l2范数损失,LJS表示两个网络输出分布的LJS散度相似性,Ladv表示本章提出的对抗损失.从表中可知,单独使用类别标签监督损失LS在所有组合中结果最差,增加任意一种知识迁移的损失都能增加预测的精度,LS+LJS+Ladv取得最高的平均分类精度,在CIFAR10和CIFAR100 上增幅分别为0.62%和2.28%在固定类别标签监督损失LS和对抗损失Ladv的情况下,对比增加L2和JS损失,前者增加LJS比增加L2使得分类性能有所提升,在两个数据集上的增幅分别为0.48%和0.78%.综上所述,在后续的实验中,单独使用LJS差异来计算Lb.

    表1 损失函数对分类精度的影响比较(%)Table 1 Comparison of classification performance with different loss function (%)

    进一步讨论判别器结构对TAMDL 性能的影响.在CIFAR100 上进行实验,在分类子网络固定为ResNet32 的情况下,讨论判别器采用不同的架构对最终网络的分类误差的影响.由表2 可以看出,不同结构的判别器对结果的影响不大.尝试了两层到四层不同容量的全连接层模型,且为了尽可能保留输入数据的差异性,仅在全连接层之间进行BN与LeakyReLU 操作.实验表明四层全连接层的效果普遍会略低于三层的效果,三层结构的判别器取得了略优的分类性能,128fc-256fc-128fc在CIFAR100 上取得了最好的分类性能,相比最差的四层结构的判别器128fc-256fc-256fc-128fc 分类精度仅提高了0.28.为此,在后续实验中,TAMDL采用三层结构的判别器.

    表2 判别器结构对分类精度的影响比较(%)Table 2 Comparison of classification performance with different discriminator structures (%)

    本节讨论判别器的输入对TAMDL 性能的影响.在2 个ResNet32 构成的网络上进行了实验.对比了不同的判别器的输入:1) Conv4 表示图像经过第4 组卷积得到的特征;2) FC 表示单张图像经过全连接层转化但未经Softmax 的特征;3) DAE 表示原始图像经过深度自编码器得到的压缩特征;4)Label 表示分类标签的热编码;5) Avgfc 表示一组图像经过全连接层转化但未经Softmax 的特征的平均值.表3 对比了针对不同判别器输入网络的最终结果,表中的结果是经过分类网络输出的平均值.由表3 可以看出,FC 得到的特征作为判别器的输入取得了最好的判别性能,增加的条件约束信息对最终结果没有正面的促进,如FC+Conv4 判别器的性能并没有提升,反而下降了0.44%.FC+Label 作为输入,判别器性能仅次于FC 作为输入得到的结果.

    表3 判别器输入对分类精度的影响比较(%)Table 3 Comparison of classification performance with different discriminator inputs (%)

    进一步讨论采样数量对TAMDL 分类性能的影响.在训练过程中通常采用从训练数据集中随机采样来训练网络.不加限制的随机采样器可能会导致所有样本都来自不同类别的情况.尽管它是对实例一致性的真实梯度的无偏估计,但是在本节提出的样本组间结构相似性损失计算中,过多的样本类别数容易导致组间关系过于复杂难以学习优化,且过少的样本类别数又容易导致类间相关性偏差较大.为了正确的传递样本组间的真实相关信息,采样策略十分重要.在批量输入大小固定为64 的情况下,对样本组中的类别数目进行了限定.表4 给出了在CIFAR100 数据集上,学生和教师网络为ResNet32和ResNet110 时的分类结果,其中每个样本组中类别总数为K且每类的样本数目为64/K,Random 表示不进行采样约束的互学习结果,Vanila表示原始网络精度.由表4 可知,当类别总数K取值过小时,网络无法正常训练或过早陷入过拟合状态.如K=2,TADML 取得最低的分类性能.当K取值刚好等于类别总数时,即每个类别样本仅出现一次,网络的性能与随机采样效果基本保持一致.在K=8,16,32 时,TAMDL 的性能均优于随机采样的方式,增幅分别为0.31%、0.72%、0.38%.由此可知,样本组的类别数在平衡类间内相关一致性中有很重要的作用,选取适当的类别数,后续实验采用K=16.

    表4 采样数量对分类精度的影响比较(%)Table 4 Comparison of classification performance with different sampling strategies (%)

    2.4 TAMDL 与DML 比较实验

    本节讨论TAMDL 与DML 的性能对比.为了说明TAMDL 的鲁棒性和优越性,实验设置不同结构的分类网络作为生成器,并与原始分类网络和深度互学习方法(DML)进行比较.对比实验的优化器参数设置与本文提出算法保持一致,DML 算法优化步骤按照原文的设置,使用KL 散度进行知识迁移并交替训练子网络.为了进一步说明本文所提两个损失模块的有效性,把仅加上对抗损失模块的网络(损失函数未加拓扑一致性损失度量)定义为ADML.实验部分列出了ADML 算法与同时使用对抗性损失模块、拓扑一致性损失模块的TADML算法的测试结果.由表5 可以看出,本文方法在ResNet32,ResNet110和Wide-ResNet (WRN)之间的几乎所有组合中,都比DML 表现更好,无论两个网络是同等大小,还是一大一小,大网络几乎都可以从小网络中进一步获益,从而达到更高的精度.换句话说,ADML 进一步提升了所有网络的能力.表5 中除第1 行外,第2~5 行所有的教师和学生网络结构模型,ADML 的性能都优于DML.学生网络(第1 列)的第2~5 行增幅分别为1.04%、0.49%、0.71%、1.03%,教师网络(第2 列)的第2~5 行增幅分别为0.1%、0.55%、0.74%、0.32%.当在CIFAR10 上重复相同的实验时,由于生成网络的输出过于简单导致基于GAN的优化难以收敛,提出的ADML 的性能几乎等于DML.

    由表5 可以看出,TADML 在所有的网络结构试验中几乎都达到了最优的结果,最优值用黑体标记,次优值用下划线标记.相对于DML,TADML在所有设置的网络结构中都优于DML,学生网络的增幅分别为1.21%、1.52%、0.93%、0.91%和1.52%,教师网络的增幅分别为1.24%、0.78%、1.16%、1.07%和1.01%.进一步可以发现,当2 个分类子网络大小不一致时,较大网络的提升效果远没有较小网络明显.

    表5 网络结构对分类精度的影响比较(%)Table 5 Comparison of classification performance with different network structures (%)

    将本文方法用于行人再识别,用平均识别精度mAP 进行度量.为公平比较起见,采用了与DML[13]在行人在识别实验中相同的网络设置,设置了2 组不同网络学生和教师的架构:网络1(InceptionV,MobileNetV1)、网络2 (MobileNetV1,MobileNetV1).对比DML、ADML和TADML,结果如表6所示.在行人重识别数据集上的性能进一步表明了,本文算法的有效性和优越性.ADML 相对于DML,2 组师生网络性能分别提升了0.26%和0.35%、0.47%和1.01%;TADML 相对于DML,两组师生网络性能分别提升了0.59%和1.04%、0.89%和1.39%.实验结果表明,ADML和TADML 方法在Market1501数据集上的mAP 普遍高于DML.

    表6 网络结构对行人重识别平均识别精度的影响比较(%)Table 6 Comparison of person re-identification mAP with different network structures (%)

    2.5 主流方法对比

    将本文TAMDL 方法与当前流行的方法进行比较,为比较公平,将模型压缩的性能作为比较指标,在三个常见的分类数据集CIFAR10、CIFAR100、Tiny-ImageNet 上进行比较.对比了9 种方法,分别为2 种广泛使用的基于量化的模型压缩方法:Quantization[20]、Binary Connect[21],4 种常见的知识蒸馏方法:解过程流方法(Flow of solution procedure,FSP)[10]、模拟浅层神经网络的SNN-MIMIC 方法[22]、KD[8]、用浅而宽的教师网络训练窄而深的学生网络的FitNet[9],3 种对抗训练的蒸馏方法:对抗网络压缩方法(Adversarial network compression,ANC[23]、用条件对抗学习加速训练学生网络的TSANC 方法[24]、用知识挤压进行对抗学习的KSANC 方法[19].其中Quantization[20]将网络权重的进行三值化,Binary Connect[21]在前向和后向传递期间对权重进行二值化.SNN-MIMIC[22]模拟学习L2损失,KD[8]通过KL 散度进行软目标的知识转移,Yim 等[10]使用FSP 矩阵进行蒸馏,FitNet[9]使用更深但更薄的网络尝试迁移模型中间层的知识.ANC[23]首次将生成对抗网络融入到知识蒸馏中对学生网络的逻辑分布层进行指导,TSANC[24]在此基础上对判别器的输入进行了条件约束,KSANC[19]进一步加入了网络中间层的监督指导.

    在对比实验中,教师网络使用ResNet164,学生网络使用ResNet20.其中Tiny-ImageNet 的实验结果由复现的代码运行得到,表中的其余结果均来自自文献[19],一些对比方法未给出实验结果,则标记为 “-”.如表7 所示,第1 行ResNet20 为学生网络的分类性能,第2 行ResNet164 为教师网络的性能.从第2 行至最后一行为在相同的教师和学生网络设置下,对比方法仅使用学生网络进行分类达到的分类性能.第1 列为对比方法,第2 列为模型大小.最优值使用黑色粗体标记,次优值使用下划线粗体标记.本文方法TAMDL 在3 个数据集上均取得了最高的分类精度,与最新的对比方法KSANC比较,在CIFAR10、CIFAR100和Tiny-ImagNet上增幅分别为0.37%、2.23%和0.34%.

    表7 本文算法与其他压缩算法的实验结果Table 7 Experimental results of the proposed algorithm and other compression algorithms

    由表7 可以看出,学生网络都没能达到教师网络的性能.对于CIFAR10,在相同规模下采用对抗学习后,学生网络的性能得到改善,ANC、TSANC、KSANC、AMDL、TAMDL 的增幅分别为0.5%、0.75%、1.26%、0.81%和2.63%.对于类别复杂的CIFAR100,增幅更为明显,以上5 种方法的增幅分别为0.92%、0.80%、1.95%、2.97%和4.81%.对于更为复杂的Tiny-ImageNet 数据集,以上五种方法的增幅分别为3.72%、3.75%、5.32%、4.55%和5.66%.比较实验表明,数据集越复杂,对抗训练的提升效果越明显,本文方法TAMDL 相对于其他对比方法优势越明显.

    2.6 模型复杂性分析

    本节以ResNet164/ResNet20 做为教师网络/学生网络为例,来分析TAMDL 模型的复杂性.在训练阶段,先固定判别器,此时优化生成器—两个分类网络ResNet164和ResNet20,两个模型的参数量分别为2.61 MB和0.27 MB,即生成器参数量为2.88 MB,耗时与传统互学习网络一致;优化判别器时,生成器固定不动,此时优化的是一个多层感知器—三个全连接层128-256-128,参数量为0.59 MB.在训练时生成器和判别器以1:1 的轮次交替迭代,在数据集CIFAR100 使用Pytorch0.4进行实验,生成器为ResNet164+ResNet20,判别网络为三个维度为128-256-128 的全连接层,批尺寸Batchsize 设为64,即每个训练轮次Epoch 将训练集划分为781个Batch,平均每训练轮次Epoch耗时82 s,其中每个Batch平均耗时0.1045 s,优化生成器反向传播耗时0.0694 s,优化判别器反向传播耗时0.0016 s.采用对抗训练,并没有带来太大的时间开销.

    本文提出了一种拓扑一致性指导的对抗互学习知识蒸馏方法.该方法在GAN 框架下,对轻量级的学生网络进行知识迁移,所提方法设计了样本组间拓扑一致性度量,依此设计的损失函数结合常规的实例级别的分布相似性,以及对抗损失及标号损失,作为训练模型的总损失.文中评估了不同损失函数和不同模型架构对分类精度的影响.在3 个公开的数据集上验证了本文方法TAMDL 的有效性.本文方法效果稳定且提升明显,而且在压缩模型的性能比较中,取得最好的结果.

    猜你喜欢类别损失分类胖胖损失了多少元数学小灵通·3-4年级(2021年5期)2021-07-16分类算一算数学小灵通(1-2年级)(2021年4期)2021-06-09分类讨论求坐标中学生数理化·七年级数学人教版(2019年4期)2019-05-20玉米抽穗前倒伏怎么办?怎么减少损失?今日农业(2019年15期)2019-01-03数据分析中的分类讨论中学生数理化·七年级数学人教版(2018年6期)2018-06-26壮字喃字同形字的三种类别及简要分析民族古籍研究(2018年1期)2018-05-21教你一招:数的分类初中生世界·七年级(2017年9期)2017-10-13服务类别新校长(2016年8期)2016-01-10一般自由碰撞的最大动能损失广西民族大学学报(自然科学版)(2015年3期)2015-12-07损失读者·校园版(2015年19期)2015-05-14
    相关热词搜索:拓扑蒸馏对抗

    • 名人名言
    • 伤感文章
    • 短文摘抄
    • 散文
    • 亲情
    • 感悟
    • 心灵鸡汤