elia / checkpoints /test.py
yxchng
add files
a166479
raw
history blame
No virus
348 Bytes
import torch
# Substrings identifying the state-dict entries to keep in the exported file.
KEEP_KEYWORDS = ("image_model", "language_model", "classifier")


def filter_state_dict(state_dict, keywords=KEEP_KEYWORDS):
    """Return the entries of *state_dict* whose key contains any of *keywords*.

    Args:
        state_dict: Mapping to filter; keys must support the ``in`` substring
            test (presumably parameter-name strings).
        keywords: Iterable of substrings; an entry is kept when at least one
            of them occurs in its key. Defaults to ``KEEP_KEYWORDS``.

    Returns:
        A new plain ``dict`` containing the matching entries, in the same
        iteration order as *state_dict*. Values are not copied.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if any(word in key for word in keywords)
    }


def main(src="model_best_refcoco_0508.pth", dst="gradio.pth"):
    """Load the checkpoint at *src*, print its model keys, and save a filtered copy to *dst*.

    The checkpoint is expected to be a mapping with a ``"model"`` entry that
    holds the state dict to filter.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Newer torch versions default to weights_only=True, which may
    # reject checkpoints that contain non-tensor objects — confirm for this file.
    checkpoint = torch.load(src, map_location="cpu")
    state_dict = checkpoint["model"]
    print(state_dict.keys())
    torch.save(filter_state_dict(state_dict), dst)


if __name__ == "__main__":
    main()