TA_EC/predict.py
2025-03-10 00:31:37 +08:00

22 lines
681 B
Python

import torch
from model import create_model
from config import config
from utils import load_checkpoint
class Predictor:
    """Run inference with a model restored from a saved checkpoint.

    The model is built by ``create_model()`` and handed to
    ``load_checkpoint()``, which presumably fills in the weights in place
    (the original comment says it loads the best model; its return value
    is not used). The model is then moved to ``config.device`` and put in
    eval mode once, so :meth:`predict` can be called repeatedly.
    """

    def __init__(self):
        self.model = create_model()
        load_checkpoint(self.model)  # load the best model weights
        # predict() places its inputs on this device. Move the model there
        # as well so parameters and inputs always match; ``Module.to`` is a
        # no-op when the parameters already live on that device.
        self.device = config.device
        self.model.to(self.device)
        self.model.eval()

    def predict(self, input_data):
        """Return the arg-max class index for each sample in *input_data*.

        Args:
            input_data: anything ``torch.as_tensor`` accepts (nested lists,
                a NumPy array, or a tensor). It is converted to float32 on
                ``self.device`` and fed to the model unchanged, so it must
                already carry whatever batch dimension the model expects
                (assumed to be (batch, ...) -- TODO confirm).

        Returns:
            NumPy array of ``output.argmax(dim=1)``, i.e. the index of the
            largest value along dimension 1 of the model output.
        """
        with torch.no_grad():
            # as_tensor accepts tensors without the copy-construct warning
            # torch.tensor() emits; copy=True keeps the original's
            # always-copy semantics so the model never aliases caller memory.
            input_tensor = torch.as_tensor(input_data).to(
                device=self.device, dtype=torch.float32, copy=True
            )
            output = self.model(input_tensor)
        return output.argmax(dim=1).cpu().numpy()
# Usage example
if __name__ == "__main__":
    # Placeholder input: replace ``[...]`` with real samples before running.
    sample_data = [...]
    predictions = Predictor().predict(sample_data)
    print("Prediction:", predictions)