import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs_teacher = 10   # teacher training epochs
num_epochs_student = 20   # student training epochs
batch_size = 64
learning_rate = 0.001
temperature = 5           # distillation temperature
alpha = 0.3               # weight on the hard-label loss ((1 - alpha) weights the soft loss)

# Dataset preparation (CIFAR-10 as an example)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Assumption:
# dataset A is the first 25,000 images of the CIFAR-10 training set,
# dataset B is the last 25,000 images.
dataset_A = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
dataset_A = torch.utils.data.Subset(dataset_A, range(25000))

dataset_B = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
dataset_B = torch.utils.data.Subset(dataset_B, range(25000, 50000))

train_loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
train_loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)

# Teacher model (ResNet-18)
class TeacherModel(nn.Module):
    def __init__(self):
        super().__init__()
        # weights=None trains from scratch (replaces the deprecated pretrained=False)
        self.resnet = torchvision.models.resnet18(weights=None)
        self.resnet.fc = nn.Linear(512, 10)  # CIFAR-10 has 10 classes

    def forward(self, x):
        return self.resnet(x)

# Student model (a smaller CNN)
class StudentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),  # 32x32 input pooled twice -> 8x8 feature maps
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Train the teacher model on dataset A
teacher = TeacherModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=learning_rate)

print("Training Teacher Model...")
for epoch in range(num_epochs_teacher):
    teacher.train()
    for images, labels in train_loader_A:
        images = images.to(device)
        labels = labels.to(device)

        outputs = teacher(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Teacher Epoch [{epoch+1}/{num_epochs_teacher}]")

# Train the student model on dataset B via knowledge distillation
student = StudentModel().to(device)
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
criterion_hard = nn.CrossEntropyLoss()
criterion_soft = nn.KLDivLoss(reduction="batchmean")

print("\nDistilling Knowledge to Student...")
teacher.eval()  # keep the teacher in evaluation mode

for epoch in range(num_epochs_student):
    student.train()
    total_loss = 0
    for images, labels in train_loader_B:
        images = images.to(device)
        labels = labels.to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_logits = teacher(images)

        # Student predictions
        student_logits = student(images)

        # Hard-label loss against the ground truth
        hard_loss = criterion_hard(student_logits, labels)

        # Soft-label loss with temperature scaling; the T^2 factor keeps
        # gradient magnitudes comparable across temperatures
        soft_loss = criterion_soft(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1)
        ) * (temperature ** 2)

        # Combined loss
        loss = alpha * hard_loss + (1 - alpha) * soft_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader_B)
    print(f"Student Epoch [{epoch+1}/{num_epochs_student}], Loss: {avg_loss:.4f}")

print("Knowledge distillation complete!")
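
# --- Evaluation sketch (an assumed addition, not part of the original script):
# a minimal check of the distilled student's accuracy on the CIFAR-10 test set,
# reusing the `transform`, `device`, and `batch_size` defined above. ---
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

student.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = student(images)
        _, predicted = torch.max(outputs, dim=1)  # predicted class per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Student test accuracy: {100 * correct / total:.2f}%")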