# safetensors_convert_fp16_to_8bit.py
# Python 3.11.6
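# Casts the float16 tensors in a .safetensors model file to int8 via a plain dtype
# cast, copies blacklisted tensors (embeddings, norms, biases, VAE) through unchanged,
# and writes the result to a new .safetensors file.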
import safetensors.torch
import safetensors
import torch
import os
# list of substrings that cause a tensor to always be copied, not cast
tensor_name_blacklist = [
'embed',
'layer_norm',
'norm.weight',
'attention_bias',
'vae'
]
# list of substrings that would cause a tensor to be cast (defined for reference; not consulted below)
tensor_name_whitelist = [
    'double_layers',
    'single_layers',
    't5xl'
]
fn = input('Enter the path of an FP16 model file: ')
if not os.path.exists(fn):
raise FileNotFoundError(f"'{fn}' does not exist")
if not os.path.isfile(fn):
    raise FileNotFoundError(f"'{fn}' is not a file")
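# derive the output filename by swapping the precision tag (fp16/f16/FP16/F16) for '8-bit',
# or by appending '-8-bit' to the name if no tag is present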
if 'fp16' in fn:
output_fn = fn.replace('fp16', '8-bit')
elif 'f16' in fn:
output_fn = fn.replace('f16', '8-bit')
elif 'FP16' in fn:
output_fn = fn.replace('FP16', '8-bit')
elif 'F16' in fn:
output_fn = fn.replace('F16', '8-bit')
else:
output_fn = fn.replace('.safetensors', '-8-bit.safetensors')
if os.path.exists(output_fn):
raise FileExistsError(f"destination file {output_fn!r} already exists")
def maybe_reduce_precision_tensor(tensor: torch.Tensor, tensor_name: str) -> torch.Tensor:
"""
Convert the given tensor to 8-bit if it is float16, otherwise do nothing
"""
# do not cast tensors that are not fp16
if tensor.dtype not in [torch.float16, torch.half]:
print(f"SKIP: tensor {tensor_name}: {tensor.dtype}")
return tensor
# fp16 tensor -> 8-bit tensor
print(f"CAST -- tensor {tensor_name}: {tensor.dtype} -> torch.int8")
return tensor.char()
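# note: .char() is a plain dtype cast to torch.int8; fractional values are truncated toward zero,
# e.g. torch.tensor([0.5, -1.25, 3.0], dtype=torch.float16).char() gives tensor([0, -1, 3], dtype=torch.int8)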
i = 0  # counts tensors as they are loaded from the input file
fp16_tensors: dict[str, torch.Tensor] = {}
# change `device='mps'` to `device='cpu'` if you are not using Metal
with safetensors.safe_open(fn, framework="pt", device='mps') as f:
for tensor_name in f.keys():
print(f"LOAD: tensor {tensor_name}")
fp16_tensors[tensor_name] = f.get_tensor(tensor_name)
i += 1
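# second pass: cast each loaded tensor to int8 unless its name matches the blacklist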
_8bit_tensors: dict[str, torch.Tensor] = {}
for tensor_name in fp16_tensors.keys():
blacklist_flag = False
for string in tensor_name_blacklist:
if string in tensor_name:
blacklist_flag = True
break
if not blacklist_flag:
_8bit_tensors[tensor_name] = maybe_reduce_precision_tensor(
tensor=fp16_tensors[tensor_name],
tensor_name=tensor_name
)
else:
print(f'COPY: tensor {tensor_name} is blacklisted')
_8bit_tensors[tensor_name] = fp16_tensors[tensor_name]
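# write all tensors (cast and copied) to the output file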
safetensors.torch.save_file(
tensors=_8bit_tensors,
filename=output_fn
)
print(f'saved 8-bit file to {output_fn}')
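# optional sanity check: re-open the saved file and print each tensor's dtype
# (change `device='mps'` to `device='cpu'` if you are not using Metal)
with safetensors.safe_open(output_fn, framework="pt", device='mps') as f:
    for tensor_name in f.keys():
        print(f"CHECK: tensor {tensor_name}: {f.get_tensor(tensor_name).dtype}")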