|
|
|
|
|
|
|
import safetensors.torch |
|
import safetensors |
|
import torch |
|
import os |
|
|
|
|
|
# Tensor-name substrings that must keep full precision (quantizing these
# degrades quality disproportionately: embeddings, norms, biases, VAE).
tensor_name_blacklist = [
    'embed',
    'layer_norm',
    'norm.weight',
    'attention_bias',
    'vae',
]

# Tensor-name substrings considered safe to quantize.
# BUG FIX: the original list had no commas, so implicit string concatenation
# collapsed the three entries into the single string
# 'double_layerssingle_layerst5xl', which would never match anything.
tensor_name_whitelist = [
    'double_layers',
    'single_layers',
    't5xl',
]
|
|
|
# Prompt for the source checkpoint and validate it before doing any work.
fn = input('Enter the path of an FP16 model file: ')

if not os.path.exists(fn):
    raise FileNotFoundError(f"'{fn}' does not exist")

if not os.path.isfile(fn):
    # BUG FIX: the exception was constructed but never raised, so a
    # directory path would pass validation and fail later on open.
    raise FileNotFoundError(f"'{fn}' is not a file")

# Derive the output filename: replace the first matching precision tag,
# otherwise append '-8-bit' before the .safetensors extension.
for _tag in ('fp16', 'f16', 'FP16', 'F16'):
    if _tag in fn:
        output_fn = fn.replace(_tag, '8-bit')
        break
else:
    output_fn = fn.replace('.safetensors', '-8-bit.safetensors')

# Refuse to clobber an existing file.
if os.path.exists(output_fn):
    raise FileExistsError(f"destination file {output_fn!r} already exists")
|
|
|
def maybe_reduce_precision_tensor(tensor: torch.Tensor, tensor_name: str) -> torch.Tensor:
    """
    Convert the given tensor to 8-bit if it is float16, otherwise do nothing
    """
    # torch.half is just an alias for torch.float16, so one comparison
    # covers both spellings.
    if tensor.dtype != torch.float16:
        print(f"SKIP: tensor {tensor_name}: {tensor.dtype}")
        return tensor

    print(f"CAST -- tensor {tensor_name}: {tensor.dtype} -> torch.int8")
    # Equivalent to .char(); fractional values are truncated toward zero.
    return tensor.to(torch.int8)
|
|
|
# Load every tensor from the source checkpoint into memory.
fp16_tensors: dict[str, torch.Tensor] = {}

# FIX: removed the dead `i` counter — it was initialized and incremented
# but never read anywhere in the script.
# NOTE(review): device is hard-coded to 'mps' (Apple Silicon only); on
# other platforms this should be 'cpu' — TODO confirm intended target.
with safetensors.safe_open(fn, framework="pt", device='mps') as f:
    for tensor_name in f.keys():
        print(f"LOAD: tensor {tensor_name}")
        fp16_tensors[tensor_name] = f.get_tensor(tensor_name)
|
|
|
# Build the output state dict: blacklisted tensors are copied through
# untouched, everything else is offered to the precision-reduction helper.
_8bit_tensors: dict[str, torch.Tensor] = {}
for tensor_name, source_tensor in fp16_tensors.items():
    # A tensor is protected when any blacklist substring occurs in its name.
    if any(pattern in tensor_name for pattern in tensor_name_blacklist):
        print(f'COPY: tensor {tensor_name} is blacklisted')
        _8bit_tensors[tensor_name] = source_tensor
    else:
        _8bit_tensors[tensor_name] = maybe_reduce_precision_tensor(
            tensor=source_tensor,
            tensor_name=tensor_name
        )
|
|
|
# Persist the converted state dict to the destination path chosen earlier.
safetensors.torch.save_file(tensors=_8bit_tensors, filename=output_fn)
print(f'saved 8-bit file to {output_fn}')
|
|