22 lines
681 B
Python
22 lines
681 B
Python
import torch
|
|
from model import create_model
|
|
from config import config
|
|
from utils import load_checkpoint
|
|
|
|
class Predictor:
    """Wrap the project model for inference.

    On construction the model is built with ``create_model()``, weights are
    restored with ``load_checkpoint`` (the original note says this loads the
    best model), the model is moved to ``config.device`` and switched to eval
    mode.
    """

    def __init__(self):
        self.model = create_model()
        load_checkpoint(self.model)  # load the best model (per original note)
        # predict() sends inputs to config.device, so make sure the parameters
        # live there too; this is a no-op if create_model/load_checkpoint
        # already placed them on that device.
        self.model.to(config.device)
        self.model.eval()

    def predict(self, input_data):
        """Return the arg-max over dim 1 of the model output as a NumPy array.

        Args:
            input_data: array-like accepted by ``torch.as_tensor`` (nested
                list, NumPy array or tensor). Presumably shaped
                (batch, ...features) — TODO confirm against the model's
                expected input.

        Returns:
            NumPy array holding ``output.argmax(dim=1)``, moved to CPU.
        """
        with torch.no_grad():
            # as_tensor converts dtype and device in one step and avoids the
            # redundant copy (and the UserWarning torch.tensor emits) when
            # input_data is already a tensor/ndarray. Note: a float32 CPU input
            # may share memory with the resulting tensor.
            input_tensor = torch.as_tensor(
                input_data, dtype=torch.float32, device=config.device
            )
            output = self.model(input_tensor)
            return output.argmax(dim=1).cpu().numpy()
|
|
|
|
# Usage example: run this module directly to classify one batch of input.
if __name__ == "__main__":
    # `[...]` is a list holding the Ellipsis object, not real data; replace it
    # with actual model input before running this example.
    sample_data = [...]
    labels = Predictor().predict(sample_data)
    print("Prediction:", labels)