# TA_EC/test.py
# Timestamp: 2025-03-09 22:36:22 +08:00 (140 lines, 4.3 KiB, Python)
#
# Note from the hosting page: this file contains ambiguous Unicode characters
# that might be confused with other characters. If that is intentional, the
# warning can safely be ignored.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters
num_epochs_teacher = 10  # training epochs for the teacher model
num_epochs_student = 20  # distillation epochs for the student model
batch_size = 64
learning_rate = 0.001  # Adam learning rate, used for both teacher and student
temperature = 5  # distillation temperature T that softens both sets of logits
# Weight of the HARD-label (cross-entropy) term in the student objective:
#     loss = alpha * hard_loss + (1 - alpha) * soft_loss
# so the distillation (soft-label) term is weighted by 1 - alpha (= 0.7 here).
alpha = 0.3
# Dataset preparation (this example uses CIFAR-10).
# ToTensor gives values in [0, 1]; Normalize with mean 0.5 / std 0.5 per RGB
# channel maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Assumed split:
#   dataset A = the first 25000 images of the CIFAR-10 training set
#   dataset B = the last 25000 images of the CIFAR-10 training set
# Load the training set once and take two disjoint index ranges from the same
# object, instead of constructing (and reading into memory) two identical
# CIFAR10 datasets.
cifar10_train = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
split_index = 25000
dataset_A = torch.utils.data.Subset(cifar10_train, range(split_index))
dataset_B = torch.utils.data.Subset(
    cifar10_train, range(split_index, len(cifar10_train)))
# Dataset A feeds teacher training; dataset B feeds student distillation.
train_loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
train_loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)
# Teacher model definition (ResNet-18).
class TeacherModel(nn.Module):
    """Teacher network: a randomly initialised torchvision ResNet-18.

    The 1000-way ImageNet classification head is replaced by a fresh linear
    layer with ``num_classes`` outputs.

    Args:
        num_classes: number of output logits; defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # `weights=None` is the current form of the deprecated
        # `pretrained=False`: build the architecture with no pretrained weights.
        self.resnet = torchvision.models.resnet18(weights=None)
        # Size the new head from the existing layer instead of hard-coding 512.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        # NOTE(review): the stock ResNet-18 stem targets ImageNet-sized inputs;
        # CIFAR-specific variants often use a 3x3 stride-1 first conv and drop
        # the initial max-pool. Left unchanged here.

    def forward(self, x):
        """Return unnormalised class logits for a batch of images ``x``."""
        return self.resnet(x)
# Student model definition (a smaller CNN).
class StudentModel(nn.Module):
    """Compact CNN student: two conv/ReLU/max-pool stages and a 2-layer MLP.

    The classifier's input width of 32 * 8 * 8 assumes 3x32x32 inputs (two 2x
    max-pools take 32 -> 8), e.g. CIFAR-10 images. Produces 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 16 -> 32 channels, halving H and W per stage.
        conv_stages = [
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        ]
        self.features = nn.Sequential(*conv_stages)
        # Classification head over the flattened 32x8x8 feature map.
        head = [nn.Linear(32 * 8 * 8, 128), nn.ReLU(), nn.Linear(128, 10)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Map a batch of images to unnormalised class logits."""
        feature_map = self.features(x)
        # Collapse every dimension except the batch dimension.
        flat = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flat)
# Train the teacher model on dataset A with plain cross-entropy.
teacher = TeacherModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=learning_rate)
print("Training Teacher Model...")
for epoch in range(num_epochs_teacher):
    teacher.train()
    total_loss = 0
    for images, labels in train_loader_A:
        images = images.to(device)
        labels = labels.to(device)
        outputs = teacher(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean of the per-batch losses (same convention as the student
    # loop) so teacher training progress is visible, not just the epoch index.
    avg_loss = total_loss / len(train_loader_A)
    print(f"Teacher Epoch [{epoch+1}/{num_epochs_teacher}], Loss: {avg_loss:.4f}")
# Knowledge distillation: train the student on dataset B against both the
# ground-truth labels and the frozen teacher's temperature-softened outputs.
student = StudentModel().to(device)
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
criterion_hard = nn.CrossEntropyLoss()
# "batchmean" divides the summed KL divergence by the batch size.
criterion_soft = nn.KLDivLoss(reduction="batchmean")
print("\nDistilling Knowledge to Student...")
teacher.eval()  # the teacher stays in evaluation mode throughout distillation
for epoch in range(num_epochs_student):
    student.train()
    total_loss = 0
    for images, labels in train_loader_B:
        images, labels = images.to(device), labels.to(device)
        # Clearing gradients before the forward pass is equivalent to clearing
        # them just before backward(): the forward pass never touches .grad.
        optimizer.zero_grad()

        # Teacher predictions, computed without gradient tracking.
        with torch.no_grad():
            teacher_logits = teacher(images)
        student_logits = student(images)

        # Hard-label term: ordinary cross-entropy against the true labels.
        hard_loss = criterion_hard(student_logits, labels)

        # Soft-label term: KLDivLoss expects log-probabilities as input and
        # probabilities as target, both softened by the temperature. The
        # T**2 factor rescales the gradients of this term.
        student_log_probs = nn.functional.log_softmax(student_logits / temperature, dim=1)
        teacher_probs = nn.functional.softmax(teacher_logits / temperature, dim=1)
        soft_loss = criterion_soft(student_log_probs, teacher_probs) * (temperature ** 2)

        # alpha weights the hard term; the distillation term gets 1 - alpha.
        loss = alpha * hard_loss + (1 - alpha) * soft_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader_B)
    print(f"Student Epoch [{epoch+1}/{num_epochs_student}], Loss: {avg_loss:.4f}")
print("Knowledge distillation complete!")