"""Convert an original-format SAM checkpoint to the Hugging Face SAM format.

The script loads ``medsam_vit_b.pth`` with the ``segment_anything`` ViT-B
builder, remaps its state dict with the ``replace_keys`` helper from the
Transformers SAM conversion module, loads the result into a ``SamModel``
built from a default ``SamConfig``, and saves both the model and a matching
``SamProcessor`` into the current working directory.

Run it from the directory that contains the checkpoint file; the outputs are
written next to it.
"""

from transformers import SamConfig, SamModel, SamProcessor, SamImageProcessor
from transformers.models.sam.convert_sam_original_to_hf_format import replace_keys

from segment_anything import sam_model_registry

# --- Load the original (segment_anything-format) weights --------------------
checkpoint = 'medsam_vit_b.pth'
pt_model = sam_model_registry['vit_b'](checkpoint)
pt_state_dict = pt_model.state_dict()

# --- Remap parameter names to the Hugging Face layout -----------------------
# NOTE(review): replace_keys is the helper used by the upstream conversion
# script; it may modify the dict passed in, so do not reuse pt_state_dict
# afterwards without checking.
hf_state_dict = replace_keys(pt_state_dict)

# --- Build the Hugging Face model and persist it ----------------------------
# A default SamConfig is used; load_state_dict runs in strict mode, so any
# key/shape mismatch with the checkpoint raises instead of passing silently.
hf_model = SamModel(config=SamConfig())
hf_model.load_state_dict(hf_state_dict)
hf_model.save_pretrained('./')

# --- Build and persist the matching processor -------------------------------
# Normalization is switched off, and mean/std are set to the identity values
# (0 and 1) so the saved config stays a no-op even if normalization is later
# re-enabled. resample=3 is PIL's BICUBIC filter.
# NOTE(review): presumably the checkpoint expects inputs that are already
# scaled by the caller -- confirm against the training pipeline.
hf_processor = SamProcessor(
    image_processor=SamImageProcessor(
        do_normalize=False,
        image_mean=[0, 0, 0],
        image_std=[1, 1, 1],
        resample=3,
    )
)

hf_processor.save_pretrained('./')
|