同步代码
This commit is contained in:
parent
49b2110b5f
commit
049155a8b9
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
data
|
data
|
||||||
*.pth
|
*.pth
|
||||||
*.pyc
|
*.pyc
|
||||||
|
logs
|
25
config.py
Normal file
25
config.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
# 模型参数
|
||||||
|
input_dim = 784
|
||||||
|
hidden_dim = 256
|
||||||
|
output_dim = 10
|
||||||
|
|
||||||
|
# 训练参数
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
batch_size = 64
|
||||||
|
epochs = 100
|
||||||
|
learning_rate = 0.001
|
||||||
|
save_path = "checkpoints/best_model.pth"
|
||||||
|
|
||||||
|
# 日志参数
|
||||||
|
log_dir = "logs"
|
||||||
|
log_level = "INFO"
|
||||||
|
|
||||||
|
# 断点续训
|
||||||
|
resume = False
|
||||||
|
checkpoint_path = "checkpoints/last_checkpoint.pth"
|
||||||
|
output_path = "runs/"
|
||||||
|
|
||||||
|
config = Config()
|
56
data_loader.py
Normal file
56
data_loader.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
|
||||||
|
class ClassifyDataset(Dataset):
|
||||||
|
def __init__(self, data_dir,transforms = None):
|
||||||
|
self.data_dir = data_dir
|
||||||
|
# Assume the dataset is structured with subdirectories for each class
|
||||||
|
self.transform = transforms
|
||||||
|
self.dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
|
||||||
|
self.image_size = (3, 224, 224)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.dataset)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
try:
|
||||||
|
image, label = self.dataset[idx]
|
||||||
|
return image, label
|
||||||
|
except Exception as e:
|
||||||
|
black_image = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||||
|
return self.transform(Image.fromarray(black_image)), 0 # -1 作为默认标签
|
||||||
|
|
||||||
|
def create_data_loaders(data_dir,batch_size=64):
|
||||||
|
# Define transformations for training data augmentation and normalization
|
||||||
|
train_transforms = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
# Define transformations for validation and test data (only normalization)
|
||||||
|
valid_test_transforms = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
# Load the datasets with ImageFolder
|
||||||
|
train_dir = data_dir + '/train'
|
||||||
|
valid_dir = data_dir + '/val'
|
||||||
|
test_dir = data_dir + '/test'
|
||||||
|
|
||||||
|
train_data = ClassifyDataset(train_dir, transforms=train_transforms)
|
||||||
|
valid_data = ClassifyDataset(valid_dir, transforms=valid_test_transforms)
|
||||||
|
test_data = ClassifyDataset(test_dir, transforms=valid_test_transforms)
|
||||||
|
|
||||||
|
# Create the DataLoaders with batch size 64
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)
|
||||||
|
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
|
||||||
|
|
||||||
|
return train_loader, valid_loader, test_loader
|
35
logger.py
Normal file
35
logger.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from config import config
|
||||||
|
|
||||||
|
class Logger:
|
||||||
|
def __init__(self):
|
||||||
|
os.makedirs(config.log_dir, exist_ok=True)
|
||||||
|
log_file = os.path.join(
|
||||||
|
config.log_dir,
|
||||||
|
f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
self.logger.setLevel(config.log_level)
|
||||||
|
|
||||||
|
# 文件输出
|
||||||
|
file_handler = logging.FileHandler(log_file)
|
||||||
|
file_formatter = logging.Formatter(
|
||||||
|
"%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(file_formatter)
|
||||||
|
|
||||||
|
# 控制台输出
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_formatter = logging.Formatter("%(message)s")
|
||||||
|
console_handler.setFormatter(console_formatter)
|
||||||
|
|
||||||
|
self.logger.addHandler(file_handler)
|
||||||
|
self.logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
def log(self, level, message):
|
||||||
|
getattr(self.logger, level)(message)
|
||||||
|
|
||||||
|
logger = Logger()
|
25
main.py
Normal file
25
main.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from trainner import Trainer
|
||||||
|
from config import config
|
||||||
|
from logger import logger
|
||||||
|
from torch import optim, nn
|
||||||
|
from torchvision.datasets import MNIST
|
||||||
|
from torchvision.transforms import ToTensor
|
||||||
|
|
||||||
|
from model.repvit import *
|
||||||
|
from data_loader import *
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 初始化组件
|
||||||
|
model = repvit_m1_1(num_classes=10).to(config.device)
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
train_loader, valid_loader, test_loader = create_data_loaders('F:/dataset/02.TA_EC/datasets/EC27')
|
||||||
|
|
||||||
|
# 初始化训练器
|
||||||
|
trainer = Trainer(model, train_loader, valid_loader, optimizer, criterion)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -19,7 +19,7 @@ def _make_divisible(v, divisor, min_value=None):
|
|||||||
new_v += divisor
|
new_v += divisor
|
||||||
return new_v
|
return new_v
|
||||||
|
|
||||||
from timm.models.layers import SqueezeExcite
|
from timm.layers import SqueezeExcite
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
22
predict.py
Normal file
22
predict.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
from model import create_model
|
||||||
|
from config import config
|
||||||
|
from utils import load_checkpoint
|
||||||
|
|
||||||
|
class Predictor:
|
||||||
|
def __init__(self):
|
||||||
|
self.model = create_model()
|
||||||
|
load_checkpoint(self.model) # 加载最佳模型
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def predict(self, input_data):
|
||||||
|
with torch.no_grad():
|
||||||
|
input_tensor = torch.tensor(input_data).float().to(config.device)
|
||||||
|
output = self.model(input_tensor)
|
||||||
|
return output.argmax(dim=1).cpu().numpy()
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
|
||||||
|
predictor = Predictor()
|
||||||
|
sample_data = [...] # 输入数据
|
||||||
|
print("Prediction:", predictor.predict(sample_data))
|
81
trainner.py
Normal file
81
trainner.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from config import config
|
||||||
|
from logger import logger
|
||||||
|
from utils import save_checkpoint, load_checkpoint
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
def __init__(self, model, train_loader, val_loader, optimizer, criterion):
|
||||||
|
self.model = model
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.criterion = criterion
|
||||||
|
self.start_epoch = 0
|
||||||
|
|
||||||
|
if config.resume:
|
||||||
|
self.start_epoch = load_checkpoint(model, optimizer)
|
||||||
|
|
||||||
|
def train_epoch(self, epoch):
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
|
||||||
|
|
||||||
|
for batch_idx, (data, target) in enumerate(progress_bar):
|
||||||
|
data, target = data.to(config.device), target.to(config.device)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
output = self.model(data)
|
||||||
|
loss = self.criterion(output, target)
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
progress_bar.set_postfix({
|
||||||
|
"Loss": f"{loss.item():.4f}",
|
||||||
|
"Avg Loss": f"{total_loss/(batch_idx+1):.4f}"
|
||||||
|
})
|
||||||
|
|
||||||
|
# 每100个batch记录一次日志
|
||||||
|
if batch_idx % 100 == 0:
|
||||||
|
logger.log("info",
|
||||||
|
f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)}] "
|
||||||
|
f"Loss: {loss.item():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return total_loss / len(self.train_loader)
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
correct = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for data, target in self.val_loader:
|
||||||
|
data, target = data.to(config.device), target.to(config.device)
|
||||||
|
output = self.model(data)
|
||||||
|
total_loss += self.criterion(output, target).item()
|
||||||
|
pred = output.argmax(dim=1)
|
||||||
|
correct += pred.eq(target).sum().item()
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(self.val_loader)
|
||||||
|
accuracy = 100. * correct / len(self.val_loader.dataset)
|
||||||
|
logger.log("info",
|
||||||
|
f"Validation - Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%"
|
||||||
|
)
|
||||||
|
return avg_loss, accuracy
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
best_acc = 0.0
|
||||||
|
for epoch in range(self.start_epoch, config.epochs):
|
||||||
|
train_loss = self.train_epoch(epoch)
|
||||||
|
val_loss, val_acc = self.validate()
|
||||||
|
|
||||||
|
# 保存最佳模型
|
||||||
|
if val_acc > best_acc:
|
||||||
|
best_acc = val_acc
|
||||||
|
save_checkpoint(self.model, self.optimizer, epoch, is_best=True)
|
||||||
|
|
||||||
|
# 保存检查点
|
||||||
|
save_checkpoint(self.model, self.optimizer, epoch)
|
28
utils.py
Normal file
28
utils.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from os import path
|
||||||
|
import torch
|
||||||
|
from config import config
|
||||||
|
import logger
|
||||||
|
|
||||||
|
def save_checkpoint(model, optimizer, epoch, is_best=False):
|
||||||
|
checkpoint = {
|
||||||
|
"epoch": epoch,
|
||||||
|
"model_state": model.state_dict(),
|
||||||
|
"optimizer_state": optimizer.state_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_best:
|
||||||
|
torch.save(checkpoint, config.save_path)
|
||||||
|
torch.save(checkpoint, path.join(config.output_path,"checkpoints/last_checkpoint.pth"))
|
||||||
|
logger.log("info", f"Checkpoint saved at epoch {epoch}")
|
||||||
|
|
||||||
|
def load_checkpoint(model, optimizer=None):
|
||||||
|
try:
|
||||||
|
checkpoint = torch.load(config.checkpoint_path)
|
||||||
|
model.load_state_dict(checkpoint["model_state"])
|
||||||
|
if optimizer:
|
||||||
|
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
||||||
|
logger.log("info", f"Resuming training from epoch {checkpoint['epoch']}")
|
||||||
|
return checkpoint["epoch"]
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.log("warning", "No checkpoint found, starting from scratch")
|
||||||
|
return 0
|
Loading…
Reference in New Issue
Block a user