Sync code
parent 49b2110b5f
commit 049155a8b9
1 .gitignore vendored
@@ -1,3 +1,4 @@
 data
 *.pth
 *.pyc
+logs
25 config.py Normal file
@@ -0,0 +1,25 @@
import torch

class Config:
    # Model parameters
    input_dim = 784
    hidden_dim = 256
    output_dim = 10

    # Training parameters
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 64
    epochs = 100
    learning_rate = 0.001
    save_path = "checkpoints/best_model.pth"

    # Logging parameters
    log_dir = "logs"
    log_level = "INFO"

    # Resume from checkpoint
    resume = False
    checkpoint_path = "checkpoints/last_checkpoint.pth"
    output_path = "runs/"

config = Config()
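Because config.py ends by instantiating Config at module scope, every importer shares one settings object; a minimal usage sketch (reader-side, not part of the commit):

from config import config

print(config.device)     # "cuda" if a GPU is visible, otherwise "cpu"
config.batch_size = 32   # attribute writes are visible to all other importers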
56 data_loader.py Normal file
@@ -0,0 +1,56 @@
import os
from PIL import Image
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

class ClassifyDataset(Dataset):
    def __init__(self, data_dir, transforms=None):
        self.data_dir = data_dir
        # Assume the dataset is structured with subdirectories for each class
        self.transform = transforms
        self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
        self.image_size = (3, 224, 224)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            image, label = self.dataset[idx]
            return image, label
        except Exception:
            # Fall back to a black image with label 0 for unreadable samples
            black_image = np.zeros((224, 224, 3), dtype=np.uint8)
            return self.transform(Image.fromarray(black_image)), 0

def create_data_loaders(data_dir, batch_size=64):
    # Define transformations for training data augmentation and normalization
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Define transformations for validation and test data (only normalization)
    valid_test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Load the datasets with ImageFolder
    train_dir = os.path.join(data_dir, 'train')
    valid_dir = os.path.join(data_dir, 'val')
    test_dir = os.path.join(data_dir, 'test')

    train_data = ClassifyDataset(train_dir, transforms=train_transforms)
    valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
    test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)

    # Create the DataLoaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_data, batch_size=batch_size)
    test_loader = DataLoader(test_data, batch_size=batch_size)

    return train_loader, valid_loader, test_loader
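create_data_loaders assumes an ImageFolder-style layout under the dataset root: train/, val/ and test/ directories, each containing one subdirectory per class. A quick smoke-test sketch (the root path is a placeholder, not from the commit):

from data_loader import create_data_loaders

train_loader, valid_loader, test_loader = create_data_loaders("path/to/dataset_root", batch_size=8)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # expected: torch.Size([8, 3, 224, 224]) torch.Size([8])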
35 logger.py Normal file
@@ -0,0 +1,35 @@
import logging
import os
from datetime import datetime
from config import config

class Logger:
    def __init__(self):
        os.makedirs(config.log_dir, exist_ok=True)
        log_file = os.path.join(
            config.log_dir,
            f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        )

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(config.log_level)

        # File output
        file_handler = logging.FileHandler(log_file)
        file_formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(message)s"
        )
        file_handler.setFormatter(file_formatter)

        # Console output
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter("%(message)s")
        console_handler.setFormatter(console_formatter)

        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)

    def log(self, level, message):
        getattr(self.logger, level)(message)

logger = Logger()
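One caveat: logging.getLogger(name) returns a process-wide singleton per name, so constructing Logger more than once would attach duplicate handlers and repeat every log line. A minimal demonstration of the pitfall, independent of this code:

import logging

demo = logging.getLogger("demo")
demo.addHandler(logging.StreamHandler())
demo.addHandler(logging.StreamHandler())
demo.warning("this line prints twice")  # handlers accumulate on the shared logger object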
25 main.py Normal file
@@ -0,0 +1,25 @@
from trainner import Trainer
from config import config
from logger import logger
from torch import optim, nn

from model.repvit import *
from data_loader import *

def main():
    # Initialize components
    model = repvit_m1_1(num_classes=10).to(config.device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()

    train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27')

    # Initialize the trainer
    trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
    trainer.train()

if __name__ == "__main__":
    main()
model/repvit.py
@@ -19,7 +19,7 @@ def _make_divisible(v, divisor, min_value=None):
         new_v += divisor
     return new_v

-from timm.models.layers import SqueezeExcite
+from timm.layers import SqueezeExcite

 import torch
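This hunk tracks timm's relocation of its layer utilities: timm.models.layers is a deprecated alias of timm.layers in recent releases. If both old and new timm versions had to be supported, a hedged compatibility import would look like:

try:
    from timm.layers import SqueezeExcite         # timm >= 0.9
except ImportError:
    from timm.models.layers import SqueezeExcite  # older timm releases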
22 predict.py Normal file
@@ -0,0 +1,22 @@
import torch
from model import create_model
from config import config
from utils import load_checkpoint

class Predictor:
    def __init__(self):
        self.model = create_model()
        load_checkpoint(self.model)  # Restore weights from the saved checkpoint
        self.model.eval()

    def predict(self, input_data):
        with torch.no_grad():
            input_tensor = torch.tensor(input_data).float().to(config.device)
            output = self.model(input_tensor)
            return output.argmax(dim=1).cpu().numpy()

# Usage example
if __name__ == "__main__":
    predictor = Predictor()
    sample_data = [...]  # Input data (placeholder in the original commit)
    print("Prediction:", predictor.predict(sample_data))
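The sample_data placeholder is left unfilled in the commit. Given the 3x224x224 pipeline in data_loader.py, a plausible input would be a float batch of that shape; a sketch under that assumption (and assuming model.create_model exists as imported above):

import numpy as np
from predict import Predictor

predictor = Predictor()                                    # loads the checkpoint, sets eval mode
batch = np.random.rand(1, 3, 224, 224).astype("float32")   # shape assumed from data_loader.py
print("Prediction:", predictor.predict(batch))             # e.g. [7]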
81 trainner.py Normal file
@@ -0,0 +1,81 @@
import torch
from tqdm import tqdm
from config import config
from logger import logger
from utils import save_checkpoint, load_checkpoint

class Trainer:
    def __init__(self, model, train_loader, val_loader, optimizer, criterion):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.start_epoch = 0

        if config.resume:
            self.start_epoch = load_checkpoint(model, optimizer)

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0.0
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")

        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(config.device), target.to(config.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
            })

            # Log every 100 batches
            if batch_idx % 100 == 0:
                logger.log("info",
                    f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
                    f"Loss: {loss.item():.4f}"
                )

        return total_loss / len(self.train_loader)

    def validate(self):
        self.model.eval()
        total_loss = 0.0
        correct = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(config.device), target.to(config.device)
                output = self.model(data)
                total_loss += self.criterion(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()

        avg_loss = total_loss / len(self.val_loader)
        accuracy = 100. * correct / len(self.val_loader.dataset)
        logger.log("info",
            f"Validation - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%"
        )
        return avg_loss, accuracy

    def train(self):
        best_acc = 0.0
        for epoch in range(self.start_epoch, config.epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()

            # Save the best model
            if val_acc > best_acc:
                best_acc = val_acc
                save_checkpoint(self.model, self.optimizer, epoch, is_best=True)

            # Save the latest checkpoint
            save_checkpoint(self.model, self.optimizer, epoch)
28 utils.py Normal file
@@ -0,0 +1,28 @@
from os import path
import torch
from config import config
from logger import logger

def save_checkpoint(model, optimizer, epoch, is_best=False):
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict()
    }

    if is_best:
        torch.save(checkpoint, config.save_path)
    torch.save(checkpoint, path.join(config.output_path, "checkpoints/last_checkpoint.pth"))
    logger.log("info", f"Checkpoint saved at epoch {epoch}")

def load_checkpoint(model, optimizer=None):
    try:
        checkpoint = torch.load(config.checkpoint_path)
        model.load_state_dict(checkpoint["model_state"])
        if optimizer:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
        logger.log("info", f"Resuming training from epoch {checkpoint['epoch']}")
        return checkpoint["epoch"]
    except FileNotFoundError:
        logger.log("warning", "No checkpoint found, starting from scratch")
        return 0
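Note that load_checkpoint reads config.checkpoint_path ("checkpoints/last_checkpoint.pth") while save_checkpoint writes the rolling checkpoint under config.output_path ("runs/checkpoints/last_checkpoint.pth"); resuming only finds the file if the two locations agree. A minimal round-trip sketch, assuming matching paths and existing directories:

from torch import nn, optim
from config import config
from utils import save_checkpoint, load_checkpoint

model = nn.Linear(config.input_dim, config.output_dim)  # stand-in model for the sketch
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

save_checkpoint(model, optimizer, epoch=5)       # writes model/optimizer state to disk
start_epoch = load_checkpoint(model, optimizer)  # returns 5, or 0 if no file is found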