36 lines
1.6 KiB
Python
36 lines
1.6 KiB
Python
![]() |
import warnings
|
||
|
warnings.filterwarnings('ignore')
|
||
|
import torch
|
||
|
from ultralytics.nn.tasks import DetectionModel
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
model_PGI_weights_path = 'runs/train/yolov8n-PGI/weights/best.pt'
|
||
|
model_cfg_path = "ultralytics/cfg/models/v8/yolov8n.yaml"
|
||
|
layer_num, pgi_layer_num = 22, 38
|
||
|
|
||
|
device = torch.device("cpu")
|
||
|
model_PGI = torch.load(model_PGI_weights_path, map_location='cpu')
|
||
|
model_name_key = 'model' if model_PGI['model'] is not None else 'ema'
|
||
|
model_PGI_dict = model_PGI[model_name_key].model.state_dict()
|
||
|
model_PGI_head = model_PGI[model_name_key].model[-1]
|
||
|
model_name = model_PGI[model_name_key].names
|
||
|
model = DetectionModel(model_cfg_path, nc=model_PGI_head.nc)
|
||
|
model.names = model_name
|
||
|
model_dict = model.state_dict()
|
||
|
|
||
|
new_dict = {}
|
||
|
for name in model_PGI_dict:
|
||
|
layer_id = int(name.split('.')[0]) - 1
|
||
|
new_name = f'.'.join(['model', str(layer_id)] + name.split('.')[1:])
|
||
|
if new_name in model_dict and model_PGI_dict[name].size() == model_dict[new_name].size():
|
||
|
new_dict[new_name] = model_PGI_dict[name]
|
||
|
|
||
|
if (layer_id + 1) == pgi_layer_num:
|
||
|
new_name = f'.'.join(['model', str(layer_num)] + name.split('.')[1:])
|
||
|
if new_name in model_dict and model_PGI_dict[name].size() == model_dict[new_name].size():
|
||
|
new_dict[new_name] = model_PGI_dict[name]
|
||
|
|
||
|
print(len(new_dict), len(model_dict))
|
||
|
model.load_state_dict(new_dict)
|
||
|
model.eval()
|
||
|
torch.save({'model':model.half()}, f'{model_PGI_weights_path[:model_PGI_weights_path.rfind(".")]}_rep.pt')
|