Spaces:
Running
Running
import io | |
import torch | |
def get_target_dtype_ref(target_dtype: "str | torch.dtype") -> torch.dtype:
    """Resolve a dtype name (or an existing ``torch.dtype``) to a ``torch.dtype``.

    Accepted names are ``"float16"``, ``"float32"`` and ``"bfloat16"``, plus the
    short aliases ``"fp16"``, ``"fp32"`` and ``"bf16"``. Name matching ignores
    letter case and surrounding whitespace. A ``torch.dtype`` is returned as-is.

    Raises:
        ValueError: if ``target_dtype`` is neither a ``torch.dtype`` nor one of
            the recognised names.
    """
    if isinstance(target_dtype, torch.dtype):
        return target_dtype
    dtype_by_name = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "float32": torch.float32,
        "fp32": torch.float32,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    # Only strings are normalised/looked up; anything else (None, numbers, ...)
    # falls through to the same ValueError the original raised.
    if isinstance(target_dtype, str):
        resolved = dtype_by_name.get(target_dtype.strip().lower())
        if resolved is not None:
            return resolved
    raise ValueError(f"Invalid target_dtype: {target_dtype}")
def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
    """Load an uploaded PyTorch checkpoint and return its tensors ready for safetensors.

    Args:
        ckpt_upload: In-memory upload whose ``getvalue()`` returns the raw
            checkpoint bytes. Its read position is not modified.
        target_dtype: A dtype name or ``torch.dtype`` accepted by
            ``get_target_dtype_ref``.

    Returns:
        A dict mapping each tensor's key to a contiguous CPU tensor.
        Floating-point tensors are cast to ``target_dtype``. Integer and bool
        tensors keep their original dtype. Non-tensor entries (e.g. step
        counters) are dropped. If the loaded object keeps its weights in a
        nested ``"state_dict"`` dict, that inner dict is converted instead.

    Raises:
        ValueError: if ``target_dtype`` is not recognised, or if the checkpoint
            does not load as a dict.
    """
    target_dtype = get_target_dtype_ref(target_dtype)
    # torch.load needs a path or a seekable file-like object; raw bytes are
    # rejected, so wrap the payload in a fresh buffer.
    buffer = io.BytesIO(ckpt_upload.getvalue())
    # SECURITY(review): torch.load unpickles arbitrary objects, so loading an
    # untrusted upload can execute code. Consider weights_only=True if every
    # supported checkpoint format allows it.
    checkpoint = torch.load(buffer, map_location="cpu")
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
        checkpoint = checkpoint["state_dict"]
    if not isinstance(checkpoint, dict):
        raise ValueError(
            f"Checkpoint did not load as a dict of tensors (got {type(checkpoint).__name__})"
        )

    tensor_dict = {}
    for key, val in checkpoint.items():
        # safetensors can only store tensors; skip bookkeeping values.
        if not isinstance(val, torch.Tensor):
            continue
        # Casting integer/bool buffers (e.g. index tensors) to a float dtype
        # would corrupt them, so only floating-point tensors are converted.
        if val.is_floating_point():
            val = val.to(dtype=target_dtype)
        # safetensors serialization rejects non-contiguous tensors.
        tensor_dict[key] = val.contiguous()
    return tensor_dict
# This file only provides helper functions; running it directly just reports that.
if __name__ == "__main__":
    refusal = '__main__ not allowed in modules'
    print(refusal)