# NOTE(review): removed scraped page-header residue ("Spaces: Sleeping") —
# an HTML-extraction artifact, not part of the module.
import io

import torch
def get_target_dtype_ref(target_dtype: "str | torch.dtype") -> torch.dtype:
    """Resolve a dtype name to the corresponding ``torch.dtype`` object.

    Args:
        target_dtype: Either a ``torch.dtype`` (returned unchanged) or one of
            the names ``"float16"``, ``"float32"``, ``"bfloat16"``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If ``target_dtype`` is a string outside the supported set.
    """
    if isinstance(target_dtype, torch.dtype):
        return target_dtype
    # Dict dispatch keeps the supported names in one place.
    dtype_map = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    try:
        return dtype_map[target_dtype]
    except KeyError:
        raise ValueError(f"Invalid target_dtype: {target_dtype}") from None
def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
    """Load an uploaded ``.ckpt`` and return its tensors cast to ``target_dtype``.

    Args:
        ckpt_upload: The uploaded checkpoint as a ``BytesIO``; raw ``bytes``
            are also accepted and wrapped.
        target_dtype: A dtype name or ``torch.dtype``, resolved via
            ``get_target_dtype_ref``.

    Returns:
        A flat ``{name: tensor}`` dict of CPU tensors cast to the target
        dtype. Checkpoints containing a ``string_to_param`` key (textual
        inversion embeddings) are returned as ``{'emb_params': tensor}``;
        if that entry is missing, an empty dict is returned.
    """
    if isinstance(ckpt_upload, bytes):
        ckpt_upload = io.BytesIO(ckpt_upload)
    target_dtype = get_target_dtype_ref(target_dtype)
    # SECURITY: torch.load unpickles arbitrary objects, and this input is an
    # untrusted upload. Prefer torch.load(..., weights_only=True)
    # (torch >= 1.13); left as-is here so legacy checkpoints that store
    # non-tensor objects keep loading. TODO(review): confirm and tighten.
    loaded_dict = torch.load(ckpt_upload, map_location="cpu")
    tensor_dict = {}
    # Textual-inversion embedding checkpoints keep their tensor under
    # loaded_dict['string_to_param']['*'].
    is_embedding = 'string_to_param' in loaded_dict
    if is_embedding:
        emb_tensor = loaded_dict.get('string_to_param', {}).get('*', None)
        if emb_tensor is not None:
            tensor_dict = {
                'emb_params': emb_tensor.to(dtype=target_dtype)
            }
    else:
        # Regular checkpoint: keep only the tensor entries, casting each one.
        for key, val in loaded_dict.items():
            if isinstance(val, torch.Tensor):
                tensor_dict[key] = val.to(dtype=target_dtype)
    return tensor_dict
# This module is import-only; direct execution just reports misuse.
if __name__ == '__main__':
    print('__main__ not allowed in modules')