import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs_teacher = 10   # teacher training epochs
num_epochs_student = 20   # student (distillation) training epochs
batch_size = 64
learning_rate = 0.001
temperature = 5           # distillation temperature T
alpha = 0.3               # weight of the hard-label CE loss; the soft KD loss gets 1 - alpha

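# The combined objective used below follows Hinton et al. (2015), "Distilling the
# Knowledge in a Neural Network":
#   loss = alpha * CE(student_logits, labels)
#        + (1 - alpha) * T^2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))
# The T^2 factor compensates for the 1/T^2 scaling that the temperature applies to
# the soft-target gradients, keeping the two terms on comparable scales.
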
# Dataset preparation (CIFAR-10 as the example)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Assumed split:
# - Dataset A (teacher training) is the first 25,000 images of the CIFAR-10 training set.
# - Dataset B (distillation) is the remaining 25,000 images.
dataset_A = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
dataset_A = torch.utils.data.Subset(dataset_A, range(25000))

dataset_B = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
dataset_B = torch.utils.data.Subset(dataset_B, range(25000, 50000))

train_loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
train_loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)

# Teacher model: ResNet-18
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        # weights=None trains from scratch (replaces the deprecated pretrained=False)
        self.resnet = torchvision.models.resnet18(weights=None)
        self.resnet.fc = nn.Linear(512, 10)  # CIFAR-10 has 10 classes

    def forward(self, x):
        return self.resnet(x)

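# Note: torchvision's ResNet-18 stem (a 7x7 stride-2 conv followed by max-pooling)
# is sized for 224x224 ImageNet inputs, so it downsamples 32x32 CIFAR images very
# aggressively. A common CIFAR adaptation, shown here only as an optional sketch:
#
#   self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
#   self.resnet.maxpool = nn.Identity()
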
# Student model: a much smaller CNN
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32 * 8 * 8)
        x = self.classifier(x)
        return x

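# Shape check for the classifier input, assuming 32x32 CIFAR inputs:
# 32x32 -> MaxPool2d(2) -> 16x16 -> MaxPool2d(2) -> 8x8 with 32 channels
# (the padding=1 convolutions preserve spatial size), hence nn.Linear(32 * 8 * 8, 128).
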
# Train the teacher model
teacher = TeacherModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=learning_rate)

print("Training Teacher Model...")
|
|
|
|
|
for epoch in range(num_epochs_teacher):
|
|
|
|
|
teacher.train()
|
|
|
|
|
for images, labels in train_loader_A:
|
|
|
|
|
images = images.to(device)
|
|
|
|
|
labels = labels.to(device)
|
|
|
|
|
|
|
|
|
|
outputs = teacher(images)
|
|
|
|
|
loss = criterion(outputs, labels)
|
|
|
|
|
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
loss.backward()
|
|
|
|
|
optimizer.step()
|
|
|
|
|
|
|
|
|
|
print(f"Teacher Epoch [{epoch+1}/{num_epochs_teacher}]")
|
|
|
|
|
|
|
|
|
|
# Knowledge distillation: train the student
student = StudentModel().to(device)
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
criterion_hard = nn.CrossEntropyLoss()
# KLDivLoss expects log-probabilities as input and probabilities as target
criterion_soft = nn.KLDivLoss(reduction="batchmean")

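# Optional sanity check: distillation only pays off if the student is genuinely
# smaller, so compare raw parameter counts (ResNet-18 is on the order of 11M
# parameters; this two-conv CNN is a small fraction of that).
n_teacher = sum(p.numel() for p in teacher.parameters())
n_student = sum(p.numel() for p in student.parameters())
print(f"Teacher parameters: {n_teacher:,} | Student parameters: {n_student:,}")
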
print("\nDistilling Knowledge to Student...")
|
|
|
|
|
teacher.eval() # 设置教师模型为评估模式
|
|
|
|
|
|
|
|
|
|
for epoch in range(num_epochs_student):
    student.train()
    total_loss = 0

    for images, labels in train_loader_B:
        images = images.to(device)
        labels = labels.to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_logits = teacher(images)

        # Student predictions
        student_logits = student(images)

        # Hard-label loss against the ground-truth labels
        hard_loss = criterion_hard(student_logits, labels)

        # Soft-label loss with temperature scaling: KL(teacher || student)
        soft_loss = criterion_soft(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1)
        ) * (temperature ** 2)  # rescale to offset the 1/T^2 gradient shrinkage

        # Combined loss: alpha on hard labels, 1 - alpha on the distillation term
        loss = alpha * hard_loss + (1 - alpha) * soft_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader_B)
    print(f"Student Epoch [{epoch+1}/{num_epochs_student}], Loss: {avg_loss:.4f}")

print("Knowledge distillation complete!")