
[29] Knowledge distillation: testing distillation and using learnable parameters to assist in training the student model -- This post mainly looks at the practical effect of knowledge distillation, including the changes brought by different temperature settings. It also covers learnable parameters, with a small experiment that fits a straight line by adjusting two parameters. Finally, the effect of knowledge distillation on the student network is compared against the effect of adding learnable parameters to the distillation process.


For the theory behind knowledge distillation, see my earlier post: Knowledge Distillation | Algorithm Principles and Other Extensions of Knowledge Distillation.


1. Temperature Control


In my previous post, Knowledge Distillation | Algorithm Principles and Other Extensions of Knowledge Distillation, I introduced the temperature control used in knowledge distillation. Here I show what different temperatures do to the logits, with a concrete visualization of the differences between temperature settings.

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
# set the logits output for each class
logits = np.array([-5, 2, 7, 9])
labels = ['cat', 'dog', 'donkey', 'horse']
  • Plain softmax (T = 1)

For the plain softmax function, the probability assigned to class $i$ is: $q_i = \dfrac{\exp(z_i)}{\sum_j \exp(z_j)}$

# compute the plain softmax and plot it
# the plain softmax is simply the T = 1 case
softmax_1 = np.exp(logits) / sum(np.exp(logits))
plt.plot(softmax_1, label="softmax_1")
plt.legend()
plt.show()

(Figure: softmax output over the four classes at T = 1)

  • Distillation softmax (T = k)

The temperature used in knowledge distillation is usually greater than 1. The larger T is, the smoother the output class probabilities become; the smaller T is, the larger the gaps between class probabilities and the sharper the curve. For the distillation softmax, the probability assigned to class $i$ is: $q_i = \dfrac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$

# vary the temperature to show how the final output probabilities change
T = 0.6
softmax_06 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_06, label="softmax_06")

T = 0.8
softmax_08 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_08, label="softmax_08")

# keep softmax_1 for comparison across different T values
softmax_1 = np.exp(logits) / sum(np.exp(logits))
plt.plot(softmax_1, label="softmax_1")

T = 3
softmax_3 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_3, label="softmax_3")

T = 5
softmax_5 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_5, label="softmax_5")

T = 10
softmax_10 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_10, label="softmax_10")

T = 100
softmax_100 = np.exp(logits/T) / sum(np.exp(logits/T))
plt.plot(softmax_100, label="softmax_100")

plt.xticks(np.arange(4), labels=labels)
plt.legend()
plt.show()

(Figure: softmax outputs over the four classes for T = 0.6, 0.8, 1, 3, 5, 10, 100)
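To make the smoothing effect concrete, here is a minimal sketch (reusing the same logits as above) that prints the softened distributions at T = 1 and T = 10; the values in the comments are approximate:

import numpy as np

logits = np.array([-5, 2, 7, 9], dtype=float)

def softened(logits, T):
    # temperature-scaled softmax
    z = np.exp(logits / T)
    return z / z.sum()

print("T=1 :", np.round(softened(logits, 1), 4))    # roughly [0.     0.0008 0.1191 0.8801]
print("T=10:", np.round(softened(logits, 10), 4))   # roughly [0.0963 0.1938 0.3196 0.3903]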

2. Learnable Parameters


Here I set up a small experiment: two points determine a line. The test points are (-1, -1) and (3, 2), and the line they determine should be y = 0.75x - 0.25. A line is determined by its slope k and intercept b, y = kx + b, so I define two learnable parameters k and b, initialize them randomly, and iterate until they reach roughly the right values, using the mean squared error as the loss function.
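As a quick check, the slope and intercept follow directly from the two points:

$k = \dfrac{y_2 - y_1}{x_2 - x_1} = \dfrac{2 - (-1)}{3 - (-1)} = 0.75, \qquad b = y_2 - k x_2 = 2 - 0.75 \times 3 = -0.25$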

During training of the learnable parameters I also visualize the fit, to see how the line converges step by step. The code is as follows:

import random

import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
# set random seeds so the results are reproducible
def SetSeed(SEED=42):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

SetSeed()
# define a plotting function: every few epochs, draw the data points and the current line
def Drawing(points, params, figsize=(8, 4)):
    # points: [[x1,y1],[x2,y2]...]
    # params: [k, b]
    k = params[0].item()
    b = params[1].item()
    x = np.linspace(-5.,5.,100)
    y = k*x+b
    
    plt.figure(figsize=figsize)
    # use points to draw the data points
    plt.scatter(points[:, 0], points[:, 1], marker="^", c="blue")
    # use params to draw the currently fitted line
    plt.plot(x, y, c="red")
    # figure formatting
    plt.title("line fit")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()
    plt.close()
# training settings
epoch = 40
learning_rate = 3e-2

# prepare the data
points = torch.tensor([[-1., -1.], [3., 2.]], dtype=torch.float32)
targets = points[:, 1]
inputs = points[:, 0]
print("inputs:",inputs, "targets:",targets)

# define two learnable parameters
k = nn.Parameter(torch.randn(1), requires_grad=True)
b = nn.Parameter(torch.randn(1), requires_grad=True)
params = [k, b]
print(params)

# optimizer and loss function
optimizer = optim.Adam(params, lr=learning_rate)
criterion = nn.MSELoss()

# train the two parameters
loss_lists = []
for i in range(epoch):
    optimizer.zero_grad()
    outputs = inputs*k + b
    loss = criterion(outputs, targets)
    loss_lists.append(loss.item())
    
    loss.backward()
    optimizer.step()
    
    if (i+1) % 4 == 0:
        Drawing(points, [k,b])
#         print("outputs:",outputs)
#         print("k:", k)
#         print("b:", b)

# inspect the parameters after training
print("k:", k)
print("b:", b)

Output:

# output printed before the figures
inputs: tensor([-1.,  3.]) targets: tensor([-1.,  2.])
[Parameter containing:
tensor([-0.0223], requires_grad=True), Parameter containing:
tensor([0.3827], requires_grad=True)]

# output printed after the figures
k: Parameter containing:
tensor([0.7866], requires_grad=True)
b: Parameter containing:
tensor([-0.3185], requires_grad=True)

(Figures: the fitted line plotted against the two data points every 4 epochs, gradually moving toward them)

3. Knowledge Distillation


My idea here is to build two neural networks, one large and one small, and compare how the small network performs with and without knowledge distillation. Note: the large network here could also be replaced by a CNN model, as sketched below.
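For illustration only, a minimal sketch of such a CNN-based teacher (ConvTeacher is a hypothetical name and is not used in the runs below):

import torch.nn as nn

class ConvTeacher(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvTeacher, self).__init__()
        # two conv blocks shrink the 28x28 MNIST image to 7x7 feature maps
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 28 -> 14
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 14 -> 7
        )
        self.classifier = nn.Linear(64 * 7 * 7, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.view(x.size(0), -1))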

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

import tqdm
from thop import profile
# define the teacher and student networks
class Teacher(nn.Module):
    def __init__(self, num_classes=10):
        super(Teacher, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, 1200),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(1200, num_classes)
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)


class Student(nn.Module):
    def __init__(self, num_classes=10):
        super(Student, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 20),
            nn.ReLU(),
#             nn.Linear(20, 20),
#             nn.ReLU(),
            nn.Linear(20, num_classes)
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)


# measure the parameter counts and FLOPs of the two networks
def test_flops_params():
    x = torch.rand([1, 1, 28, 28])
#     x = x.reshape(1, 1, -1)
    T_network = Teacher()
    S_network = Student()
    t_flops, t_params = profile(T_network, inputs=(x, ))
    print('t_flops:{}e-3G'.format(t_flops / 1000000), 't_params:{}M'.format(t_params / 1000000))
    s_flops, s_params = profile(S_network, inputs=(x, ))
    print('s_flops:{}e-3G'.format(s_flops / 1000000), 's_params:{}M'.format(s_params / 1000000))

test_flops_params()

Output: the teacher network's parameter count and FLOP count are each over a hundred times those of the student network.

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.Teacher'>. Treat it as zero Macs and zero Params.
t_flops:2.3928e-3G t_params:2.39521M
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.Student'>. Treat it as zero Macs and zero Params.
s_flops:0.01588e-3G s_params:0.01591M
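As a quick sanity check of that claim, dividing the figures printed above gives the rough ratios:

# rough ratios computed from the thop output above
print(2.39521 / 0.01591)   # parameter ratio, roughly 150x
print(2.3928 / 0.01588)    # MAC/FLOP ratio, roughly 150x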
# hyperparameters
epoch_size = 5
batch_size = 128
learning_rate = 1e-4

# download the training set
train_data = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# download the test set
test_data = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1307], std=[0.1307])
]), download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# define the models
device = torch.device('cuda')
t_model = Teacher().to(device)
s_model = Student().to(device)

# define the loss function (each model gets its own optimizer below)
criterion = nn.CrossEntropyLoss().to(device)

# training procedure
def train_one_epoch(model, criterion, optimizer, dataloader):
    
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        outputs = model(image)
        loss = criterion(outputs, targets)
        train_loss += loss.item()
        # backprop and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # count correct predictions
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)

# evaluation procedure
def validate(model, criterion, dataloader):
    
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    test_info = "Test ==> loss:{}, acc:{} ({}/{})"\
          .format(test_loss/len(dataloader), correct/total, correct, total)
    print(test_info)
  • Examining the teacher network's performance
# optimizer for the teacher network
t_optimizer = optim.Adam(t_model.parameters(), lr=learning_rate)
# train the teacher network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(t_model, criterion, t_optimizer, train_loader)
    validate(t_model, criterion, test_loader)
# after training the teacher model, save its weights so they can be reloaded later when needed
torch.save(t_model.state_dict(), "t_model.mdl")

Output from training the teacher network:

[Epoch:0]
batch:0/469, loss:2.451192855834961, acc:0.046875 (6/128)
batch:93/469, loss:0.9547713903036523, acc:0.7139295212765957 (8590/12032)
batch:186/469, loss:0.6837506957232633, acc:0.7931149732620321 (18984/23936)
batch:279/469, loss:0.5656270824904953, acc:0.8287388392857142 (29702/35840)
batch:372/469, loss:0.49507016175234286, acc:0.8506618632707775 (40614/47744)
batch:465/469, loss:0.44825181595245656, acc:0.864538626609442 (51568/59648)
Test ==> loss:0.1785062526977515, acc:0.9461 (9461/10000)
[Epoch:1]
batch:0/469, loss:0.2513591945171356, acc:0.9375 (120/128)
batch:93/469, loss:0.22307029120782587, acc:0.9340093085106383 (11238/12032)
batch:186/469, loss:0.2112935145988184, acc:0.9371657754010695 (22432/23936)
batch:279/469, loss:0.2037411120853254, acc:0.9394252232142857 (33669/35840)
batch:372/469, loss:0.19692633960186318, acc:0.9414167225201072 (44947/47744)
batch:465/469, loss:0.19179708627352388, acc:0.9429821620171673 (56247/59648)
Test ==> loss:0.11782180916376506, acc:0.9622 (9622/10000)
[Epoch:2]
batch:0/469, loss:0.15417274832725525, acc:0.9453125 (121/128)
batch:93/469, loss:0.14990339277589576, acc:0.9549534574468085 (11490/12032)
batch:186/469, loss:0.14525708044196833, acc:0.9562583556149733 (22889/23936)
batch:279/469, loss:0.14779337254752006, acc:0.9551060267857143 (34231/35840)
batch:372/469, loss:0.1445239940433496, acc:0.9564133713136729 (45663/47744)
batch:465/469, loss:0.14156085961846068, acc:0.9572659603004292 (57099/59648)
Test ==> loss:0.09912304252480404, acc:0.9695 (9695/10000)
[Epoch:3]
batch:0/469, loss:0.09023044258356094, acc:0.984375 (126/128)
batch:93/469, loss:0.11060939039638702, acc:0.9670877659574468 (11636/12032)
batch:186/469, loss:0.11260852741605458, acc:0.9668699866310161 (23143/23936)
batch:279/469, loss:0.11275576776159661, acc:0.9667410714285715 (34648/35840)
batch:372/469, loss:0.11253649257023597, acc:0.9668440013404825 (46161/47744)
batch:465/469, loss:0.11281515193839314, acc:0.9665873122317596 (57655/59648)
Test ==> loss:0.0813662743231258, acc:0.9734 (9734/10000)
[Epoch:4]
batch:0/469, loss:0.10590803623199463, acc:0.9765625 (125/128)
batch:93/469, loss:0.0938354417523171, acc:0.9718251329787234 (11693/12032)
batch:186/469, loss:0.09741261341474591, acc:0.9707971256684492 (23237/23936)
batch:279/469, loss:0.0959280665631273, acc:0.9712332589285714 (34809/35840)
batch:372/469, loss:0.09434855888140745, acc:0.9716823056300268 (46392/47744)
batch:465/469, loss:0.09377776978481481, acc:0.9719521190987125 (57975/59648)
Test ==> loss:0.07517792291562014, acc:0.975 (9750/10000)

The above successfully trains a teacher model with a somewhat larger parameter count. So what happens if we simply train the student network on its own?

Training the student network, and the resulting output:

# optimizer for the student network
s_optimizer = optim.Adam(s_model.parameters(), lr=learning_rate)
# train the student network
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch(s_model, criterion, s_optimizer, train_loader)
    validate(s_model, criterion, test_loader)
[Epoch:0]
batch:0/469, loss:2.435654878616333, acc:0.0859375 (11/128)
batch:93/469, loss:1.8272213606124228, acc:0.4320146276595745 (5198/12032)
batch:186/469, loss:1.431586041170008, acc:0.59295621657754 (14193/23936)
batch:279/469, loss:1.1955812790564129, acc:0.6700892857142857 (24016/35840)
batch:372/469, loss:1.0443579082354784, acc:0.7167392761394102 (34220/47744)
batch:465/469, loss:0.9376831678233944, acc:0.7478205472103004 (44606/59648)
Test ==> loss:0.46302026283891895, acc:0.8882 (8882/10000)
[Epoch:1]
batch:0/469, loss:0.5130173563957214, acc:0.890625 (114/128)
batch:93/469, loss:0.45758569145456274, acc:0.8769946808510638 (10552/12032)
batch:186/469, loss:0.44705012376933173, acc:0.8820604946524064 (21113/23936)
batch:279/469, loss:0.43086895591446334, acc:0.8864676339285714 (31771/35840)
batch:372/469, loss:0.41966651623754014, acc:0.8885723860589813 (42424/47744)
batch:465/469, loss:0.4076770568993982, acc:0.8913123658798283 (53165/59648)
Test ==> loss:0.3379575525280796, acc:0.9091 (9091/10000)
[Epoch:2]
batch:0/469, loss:0.46021485328674316, acc:0.8671875 (111/128)
batch:93/469, loss:0.35463421569859727, acc:0.8988530585106383 (10815/12032)
batch:186/469, loss:0.3480192156717739, acc:0.9024481951871658 (21601/23936)
batch:279/469, loss:0.34022124212767396, acc:0.9047154017857143 (32425/35840)
batch:372/469, loss:0.33286028996549405, acc:0.9072553619302949 (43316/47744)
batch:465/469, loss:0.32942312831147036, acc:0.9081109173819742 (54167/59648)
Test ==> loss:0.291702199491519, acc:0.9215 (9215/10000)
[Epoch:3]
batch:0/469, loss:0.2687709629535675, acc:0.9453125 (121/128)
batch:93/469, loss:0.29896670643319473, acc:0.9164727393617021 (11027/12032)
batch:186/469, loss:0.3032062678413595, acc:0.9152322860962567 (21907/23936)
batch:279/469, loss:0.2976516788559301, acc:0.9162946428571429 (32840/35840)
batch:372/469, loss:0.2963846751735933, acc:0.9160941689008043 (43738/47744)
batch:465/469, loss:0.29447377907999595, acc:0.9167784334763949 (54684/59648)
Test ==> loss:0.2693275752701337, acc:0.9252 (9252/10000)
[Epoch:4]
batch:0/469, loss:0.21400471031665802, acc:0.9296875 (119/128)
batch:93/469, loss:0.2811283932087269, acc:0.922124335106383 (11095/12032)
batch:186/469, loss:0.2739176594796665, acc:0.9235461229946524 (22106/23936)
batch:279/469, loss:0.27122129941625256, acc:0.9234933035714286 (33098/35840)
batch:372/469, loss:0.2737213251737743, acc:0.9226290214477212 (44050/47744)
batch:465/469, loss:0.27158979208172646, acc:0.9234173819742489 (55080/59648)
Test ==> loss:0.25467539342898354, acc:0.9293 (9293/10000)

From the results above, the student network's best test accuracy is 0.9293, while the teacher network's is 0.975, so the teacher is clearly better than the student. Next, let's apply knowledge distillation and see how much the teacher can improve the student.
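For reference, the combined objective in the training loop below is a weighted sum of a hard term and a soft term, roughly:

$\mathcal{L} = \alpha \cdot \mathcal{L}_{\text{hard}} + (1 - \alpha) \cdot \mathcal{L}_{\text{soft}}, \qquad \mathcal{L}_{\text{hard}} = \text{CE}(z_s, y), \quad \mathcal{L}_{\text{soft}} = \text{KL}\big(\sigma(z_t / T)\,\|\,\sigma(z_s / T)\big)$

where $z_s$ and $z_t$ are the student and teacher logits, $\sigma$ is softmax, $T$ is the temperature, and $\alpha$ is the weighting coefficient. (In Hinton et al.'s original formulation the soft term is additionally scaled by $T^2$ so its gradients stay on the same scale as the hard term; the code below does not include that factor.)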

# training procedure with knowledge distillation
def train_one_epoch_kd(s_model, t_model, hard_loss, soft_loss, optimizer, dataloader):
    
    s_model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        
        # teacher network predictions (no gradients needed)
        with torch.no_grad():
            teacher_preds = t_model(image)
        # student model predictions
        student_preds = s_model(image)
        
        # loss against the ground-truth labels: hard loss
        student_loss = hard_loss(student_preds, targets)
        # loss against the teacher's outputs: soft loss
        distillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        
        # total loss: weighted sum of the hard loss and the soft loss
        loss = alpha * student_loss + (1-alpha) * distillation_loss
        train_loss += loss.item()
        
        # backprop and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # count correct predictions
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)
# knowledge distillation hyperparameters
temp = 7      # distillation temperature
alpha = 0.3   # weighting coefficient

# prepare a fresh student model / optimizer / loss functions
kd_model = Student().to(device)
hard_loss = nn.CrossEntropyLoss()                # applies softmax internally
soft_loss = nn.KLDivLoss(reduction="batchmean")  # does not apply softmax (so we can apply our own temperature)

# optimizer for the distilled student network
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)

# train the student network with knowledge distillation
for epoch in range(epoch_size):
    print("[Epoch:{}]".format(epoch))
    train_one_epoch_kd(kd_model, t_model, hard_loss, soft_loss, kd_optimizer, train_loader)
    validate(kd_model, criterion, test_loader)
[Epoch:0]
batch:0/469, loss:-0.7217433452606201, acc:0.0625 (8/128)
batch:93/469, loss:-0.8849591688906893, acc:0.36826795212765956 (4431/12032)
batch:186/469, loss:-0.9996793939468057, acc:0.5336313502673797 (12773/23936)
batch:279/469, loss:-1.0756173231772015, acc:0.6284319196428572 (22523/35840)
batch:372/469, loss:-1.1279368988630278, acc:0.6846724195710456 (32689/47744)
batch:465/469, loss:-1.164531718252043, acc:0.7206444474248928 (42985/59648)
Test ==> loss:0.4428194357624537, acc:0.8843 (8843/10000)
[Epoch:1]
batch:0/469, loss:-1.308181643486023, acc:0.8828125 (113/128)
batch:93/469, loss:-1.333051804532396, acc:0.8811502659574468 (10602/12032)
batch:186/469, loss:-1.3385243995941896, acc:0.8837316176470589 (21153/23936)
batch:279/469, loss:-1.342562198638916, acc:0.8844029017857142 (31697/35840)
batch:372/469, loss:-1.3486905705193732, acc:0.8871062332439679 (42354/47744)
batch:465/469, loss:-1.3531442113188714, acc:0.8896190987124464 (53064/59648)
Test ==> loss:0.32884856549244895, acc:0.907 (9070/10000)
[Epoch:2]
batch:0/469, loss:-1.3262287378311157, acc:0.859375 (110/128)
batch:93/469, loss:-1.3759282218649032, acc:0.901595744680851 (10848/12032)
batch:186/469, loss:-1.3805451450500896, acc:0.9045788770053476 (21652/23936)
batch:279/469, loss:-1.381890092577253, acc:0.903515625 (32382/35840)
batch:372/469, loss:-1.3839336115936811, acc:0.9044277815013405 (43181/47744)
batch:465/469, loss:-1.3856471659287874, acc:0.9056296942060086 (54019/59648)
Test ==> loss:0.30868444205084933, acc:0.9142 (9142/10000)
[Epoch:3]
batch:0/469, loss:-1.3871097564697266, acc:0.90625 (116/128)
batch:93/469, loss:-1.398324375456952, acc:0.9108211436170213 (10959/12032)
batch:186/469, loss:-1.4016364086120523, acc:0.9119318181818182 (21828/23936)
batch:279/469, loss:-1.4019242635795048, acc:0.9117466517857142 (32677/35840)
batch:372/469, loss:-1.4023852696687222, acc:0.9129314678284183 (43587/47744)
batch:465/469, loss:-1.403283029666786, acc:0.9131907188841202 (54470/59648)
Test ==> loss:0.2895406449708757, acc:0.9191 (9191/10000)
[Epoch:4]
batch:0/469, loss:-1.4002737998962402, acc:0.8828125 (113/128)
batch:93/469, loss:-1.4150411626125903, acc:0.9187998670212766 (11055/12032)
batch:186/469, loss:-1.415715930933621, acc:0.9188669786096256 (21994/23936)
batch:279/469, loss:-1.4160895236900874, acc:0.9184709821428572 (32918/35840)
batch:372/469, loss:-1.4162878402116792, acc:0.9184400134048257 (43850/47744)
batch:465/469, loss:-1.4160627477158805, acc:0.9182202253218884 (54770/59648)
Test ==> loss:0.2826786540165732, acc:0.921 (9210/10000)

Unfortunately, in this short test, knowledge distillation did not show its advantage...
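One implementation detail worth checking when reproducing this experiment: nn.KLDivLoss expects its first argument to be log-probabilities, so a commonly used form of the soft loss looks like the following minimal sketch (variable names follow the training function above, with the optional T^2 scaling from Hinton et al.):

import torch.nn as nn
import torch.nn.functional as F

soft_loss = nn.KLDivLoss(reduction="batchmean")

def soft_target_loss(student_preds, teacher_preds, temp):
    # KLDivLoss expects log-probabilities as its input and probabilities as its target;
    # the temp**2 factor keeps the soft-loss gradients on the same scale as the hard loss
    return soft_loss(
        F.log_softmax(student_preds / temp, dim=1),
        F.softmax(teacher_preds / temp, dim=1),
    ) * (temp ** 2)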

4. Learnable Parameters for Knowledge Distillation


The experiment above used a fixed weight alpha to combine the distillation loss and the student loss. Could we instead use learnable parameters to choose these weights online? The experiment below explores that idea.

Since we want learnable parameters to adjust the weights online, the training function needs to be redefined:

# training procedure with a learnable loss weight
def train_one_epoch_kd(s_model, t_model, hard_loss, soft_loss, optimizer, parms_optimizer, dataloader):
    
    s_model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (image, targets) in enumerate(dataloader):
        image, targets = image.to(device), targets.to(device)
        
        # teacher network predictions (no gradients needed)
        with torch.no_grad():
            teacher_preds = t_model(image)
        # student model predictions
        student_preds = s_model(image)
        
        # loss against the ground-truth labels: hard loss
        student_loss = hard_loss(student_preds, targets)
        # loss against the teacher's outputs: soft loss
        distillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        
        # total loss: weighted sum of the hard loss and the soft loss
        # after testing: one learnable parameter is enough; using two actually works worse
        # loss = alpha * student_loss + gama * distillation_loss
        loss = alpha * student_loss + (1 - alpha) * distillation_loss
        train_loss += loss.item()
        
        # backprop and update
        optimizer.zero_grad()		
        parms_optimizer.zero_grad()	
        loss.backward()
        optimizer.step()		# update the network parameters
        parms_optimizer.step()	# update the learnable weight parameter
        
        # count correct predictions
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        if batch_idx % int(len(dataloader)/5) == 0:
            train_info = "batch:{}/{}, loss:{}, acc:{} ({}/{})"\
                  .format(batch_idx, len(dataloader), train_loss / (batch_idx + 1),
                          correct / total, correct, total)
            print(train_info)

The core code for knowledge distillation with a learnable weight parameter:

# knowledge distillation hyperparameters
temp = 7      # distillation temperature
alpha = nn.Parameter(torch.tensor(0.3), requires_grad=True)   # weighting coefficient
gama  = nn.Parameter(torch.tensor(0.7), requires_grad=True)   # weighting coefficient
# params = [alpha, gama]     # two learnable parameters
params = [alpha, ]           # one learnable parameter
# print(params)
#  [Parameter containing:
#    tensor(0.3000, requires_grad=True)]
    
# prepare a fresh student model / optimizers / loss functions
t_model = Teacher().to(device)	
t_model.load_state_dict(torch.load("./t_model.mdl"))   # load the trained Teacher model
kd_model = Student().to(device)					 # the Student model to be trained
hard_loss = nn.CrossEntropyLoss()                # applies softmax internally
soft_loss = nn.KLDivLoss(reduction="batchmean")  # does not apply softmax (so we can apply our own temperature)

# optimizer for the distilled student network
# kd_optimizer = optim.Adam([
#     {'params': kd_model.parameters()},
#     {'params': alpha, 'lr':3e-4}
# ], lr=learning_rate)

# since alpha is a non-leaf Tensor, the code above is wrong, so two separate optimizers are needed
kd_optimizer = optim.Adam(kd_model.parameters(), lr=learning_rate)
parms_optimizer = optim.Adam(params, lr=3e-5)

# train the student network with knowledge distillation
for epoch in range(epoch_size):
    print("[Epoch:{}]\n".format(epoch), "[params:alpha:{},gama:{}]".format(alpha, gama))
    train_one_epoch_kd(kd_model, t_model, hard_loss, soft_loss, kd_optimizer, parms_optimizer, train_loader)
    validate(kd_model, criterion, test_loader)

Output:

[Epoch:0]
 [params:alpha:0.30000001192092896,gama:0.699999988079071]
batch:0/469, loss:-0.6614086031913757, acc:0.0859375 (11/128)
batch:93/469, loss:-0.8821925444805876, acc:0.3480718085106383 (4188/12032)
batch:186/469, loss:-0.9977140796375784, acc:0.5088987299465241 (12181/23936)
batch:279/469, loss:-1.073494029045105, acc:0.6039341517857143 (21645/35840)
batch:372/469, loss:-1.1280986780135944, acc:0.6661151139410187 (31803/47744)
batch:465/469, loss:-1.1685462095195132, acc:0.7057571083690987 (42097/59648)
Test ==> loss:0.46250298619270325, acc:0.8812 (8812/10000)
[Epoch:1]
 [params:alpha:0.2877715528011322,gama:0.699999988079071]
batch:0/469, loss:-1.3155369758605957, acc:0.8203125 (105/128)
batch:93/469, loss:-1.3573756332093097, acc:0.8779089095744681 (10563/12032)
batch:186/469, loss:-1.367437780859636, acc:0.8814756016042781 (21099/23936)
batch:279/469, loss:-1.3752499222755432, acc:0.8830357142857143 (31648/35840)
batch:372/469, loss:-1.3843675500266355, acc:0.8863103217158177 (42316/47744)
batch:465/469, loss:-1.391783864713022, acc:0.8887640826180258 (53013/59648)
Test ==> loss:0.3379697990191134, acc:0.9072 (9072/10000)
[Epoch:2]
 [params:alpha:0.2754247784614563,gama:0.699999988079071]
batch:0/469, loss:-1.4017305374145508, acc:0.9140625 (117/128)
batch:93/469, loss:-1.4349181017977126, acc:0.9048371010638298 (10887/12032)
batch:186/469, loss:-1.440221704901221, acc:0.9034508689839572 (21625/23936)
batch:279/469, loss:-1.4445673780781882, acc:0.9034040178571429 (32378/35840)
batch:372/469, loss:-1.450090474801153, acc:0.9046791219839142 (43193/47744)
batch:465/469, loss:-1.456252665990412, acc:0.905730284334764 (54025/59648)
Test ==> loss:0.2985522945093203, acc:0.9149 (9149/10000)
[Epoch:3]
 [params:alpha:0.2624278962612152,gama:0.699999988079071]
batch:0/469, loss:-1.4728816747665405, acc:0.9140625 (117/128)
batch:93/469, loss:-1.4891292845949213, acc:0.9151429521276596 (11011/12032)
batch:186/469, loss:-1.4935716301362145, acc:0.9136864973262032 (21870/23936)
batch:279/469, loss:-1.4979333805186408, acc:0.9130301339285715 (32723/35840)
batch:372/469, loss:-1.5017808774840735, acc:0.9121565013404825 (43550/47744)
batch:465/469, loss:-1.5070314509674203, acc:0.9135930793991416 (54494/59648)
Test ==> loss:0.28384389652858805, acc:0.9218 (9218/10000)
[Epoch:4]
 [params:alpha:0.2490001767873764,gama:0.699999988079071]
batch:0/469, loss:-1.536271572113037, acc:0.921875 (118/128)
batch:93/469, loss:-1.536556749901873, acc:0.914311835106383 (11001/12032)
batch:186/469, loss:-1.5403413033102924, acc:0.915817179144385 (21921/23936)
batch:279/469, loss:-1.5454376697540284, acc:0.9179129464285715 (32898/35840)
batch:372/469, loss:-1.5496390562594415, acc:0.9183352882037533 (43845/47744)
batch:465/469, loss:-1.5540172580485692, acc:0.9185387607296137 (54789/59648)
Test ==> loss:0.29046644185540044, acc:0.9234 (9234/10000)

I compared training with two learnable parameters against training with a single learnable parameter. In my tests, a single learnable parameter performed better over the first 5 epochs than two, and it also converged faster; likewise, adjusting the weight online gave better results than fixing the distillation weight by hand.

Unfortunately, even with online weight adjustment, distillation training still did not beat simply training the small network directly. My guess is that the main reason is that MNIST is too simple a dataset: it is only a 10-way digit classification task, so a small model trained directly already does reasonably well. I will try a harder dataset when I get the chance.