import torch

from config import config
from model import create_model
from utils import load_checkpoint


class Predictor:
    """Wrap a checkpointed model and expose a single ``predict`` call.

    The model is built with ``create_model()``, populated by
    ``load_checkpoint()``, placed on ``config.device`` and switched to
    evaluation mode once, at construction time.
    """

    def __init__(self) -> None:
        """Build the model, load its checkpoint and prepare it for inference."""
        self.model = create_model()
        load_checkpoint(self.model)  # load the best model
        # Inputs are sent to config.device in predict(); keep the model on the
        # same device so the forward pass cannot hit a device mismatch. This
        # changes nothing when the model is already there.
        # NOTE(review): assumes create_model() returns a torch.nn.Module.
        self.model.to(config.device)
        self.model.eval()

    def predict(self, input_data):
        """Run the model on ``input_data`` and return the argmax along dim 1.

        Args:
            input_data: Anything ``torch.as_tensor`` accepts (nested list,
                NumPy array, tensor). It is converted to ``float32`` and
                placed on ``config.device``.

        Returns:
            A NumPy array holding, for each entry along dim 0 of the model
            output, the index of the largest value along dim 1.
            NOTE(review): presumably output is (batch, num_classes), so these
            are predicted class ids -- confirm against the model definition.
        """
        with torch.no_grad():
            # as_tensor avoids the copy-construct warning torch.tensor() emits
            # when the caller already passes a tensor.
            input_tensor = torch.as_tensor(
                input_data, dtype=torch.float32, device=config.device
            )
            output = self.model(input_tensor)
        return output.argmax(dim=1).cpu().numpy()


# Usage example
if __name__ == "__main__":
    predictor = Predictor()
    sample_data = [...]  # input data (placeholder: replace with real values)
    print("Prediction:", predictor.predict(sample_data))