# lora_tools/tools/torch_tools.py
import io
import torch


def get_target_dtype_ref(target_dtype: str) -> torch.dtype:
    """Resolve a dtype name ('float16', 'float32', 'bfloat16') to its torch.dtype."""
    if isinstance(target_dtype, torch.dtype):
        # Already a torch.dtype, pass it through unchanged
        return target_dtype
    if target_dtype == "float16":
        return torch.float16
    elif target_dtype == "float32":
        return torch.float32
    elif target_dtype == "bfloat16":
        return torch.bfloat16
    else:
        raise ValueError(f"Invalid target_dtype: {target_dtype}")


def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
    """Load a pickle-based checkpoint upload and return a flat dict of tensors cast to target_dtype."""
    if isinstance(ckpt_upload, bytes):
        ckpt_upload = io.BytesIO(ckpt_upload)
    target_dtype = get_target_dtype_ref(target_dtype)

    # Load the checkpoint
    loaded_dict = torch.load(ckpt_upload, map_location="cpu")

    tensor_dict = {}
    is_embedding = 'string_to_param' in loaded_dict
    if is_embedding:
        # Embedding checkpoints (e.g. textual inversion) store their tensor under string_to_param['*']
        emb_tensor = loaded_dict.get('string_to_param', {}).get('*', None)
        if emb_tensor is not None:
            emb_tensor = emb_tensor.to(dtype=target_dtype)
            tensor_dict = {
                'emb_params': emb_tensor
            }
    else:
        # Convert weights in a checkpoint to a dictionary of tensors
        for key, val in loaded_dict.items():
            if isinstance(val, torch.Tensor):
                tensor_dict[key] = val.to(dtype=target_dtype)

    return tensor_dict
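
# Example usage (a sketch, not part of the original module): the returned dict can be
# written out with the `safetensors` package, assuming it is installed and a local
# "model.ckpt" file exists (both are assumptions, not provided here).
#
#   from safetensors.torch import save_file
#
#   with open("model.ckpt", "rb") as fp:
#       tensors = convert_ckpt_to_safetensors(fp.read(), "float16")
#   save_file(tensors, "model.safetensors")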

if __name__ == '__main__':
    print('__main__ not allowed in modules')