Spaces:
Paused
Paused
import gc | |
import os | |
from collections import OrderedDict | |
from typing import ForwardRef | |
import torch | |
from safetensors.torch import save_file, load_file | |
from jobs.process.BaseProcess import BaseProcess | |
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \ | |
add_base_model_info_to_meta | |
from toolkit.train_tools import get_torch_dtype | |
class ModRescaleLoraProcess(BaseProcess): | |
process_id: int | |
config: OrderedDict | |
progress_bar: ForwardRef('tqdm') = None | |
def __init__( | |
self, | |
process_id: int, | |
job, | |
config: OrderedDict | |
): | |
super().__init__(process_id, job, config) | |
self.process_id: int | |
self.config: OrderedDict | |
self.progress_bar: ForwardRef('tqdm') = None | |
self.input_path = self.get_conf('input_path', required=True) | |
self.output_path = self.get_conf('output_path', required=True) | |
self.replace_meta = self.get_conf('replace_meta', default=False) | |
self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype) | |
self.current_weight = self.get_conf('current_weight', required=True, as_type=float) | |
self.target_weight = self.get_conf('target_weight', required=True, as_type=float) | |
self.scale_target = self.get_conf('scale_target', default='up_down') # alpha or up_down | |
self.is_xl = self.get_conf('is_xl', default=False, as_type=bool) | |
self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool) | |
self.progress_bar = None | |
def run(self): | |
super().run() | |
source_state_dict = load_file(self.input_path) | |
source_meta = load_metadata_from_safetensors(self.input_path) | |
if self.replace_meta: | |
self.meta.update( | |
add_base_model_info_to_meta( | |
self.meta, | |
is_xl=self.is_xl, | |
is_v2=self.is_v2, | |
) | |
) | |
save_meta = get_meta_for_safetensors(self.meta, self.job.name) | |
else: | |
save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False) | |
# save | |
os.makedirs(os.path.dirname(self.output_path), exist_ok=True) | |
new_state_dict = OrderedDict() | |
for key in list(source_state_dict.keys()): | |
v = source_state_dict[key] | |
v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32')) | |
# all loras have an alpha, up weight and down weight | |
# - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha", | |
# - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight", | |
# - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight", | |
# we can rescale by adjusting the alpha or the up weights, or the up and down weights | |
# I assume doing both up and down would be best all around, but I'm not sure | |
# some locons also have mid weights, we will leave those alone for now, will work without them | |
# when adjusting alpha, it is used to calculate the multiplier in a lora module | |
# - scale = alpha / lora_dim | |
# - output = layer_out + lora_up_out * multiplier * scale | |
total_module_scale = torch.tensor(self.current_weight / self.target_weight) \ | |
.to("cpu", dtype=get_torch_dtype('fp32')) | |
num_modules_layers = 2 # up and down | |
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \ | |
.to("cpu", dtype=get_torch_dtype('fp32')) | |
# only update alpha | |
if self.scale_target == 'alpha' and key.endswith('.alpha'): | |
v = v * total_module_scale | |
if self.scale_target == 'up_down' and key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'): | |
# would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN | |
v = v * up_down_scale | |
v = v.detach().clone().to("cpu").to(self.save_dtype) | |
new_state_dict[key] = v | |
save_meta = add_model_hash_to_meta(new_state_dict, save_meta) | |
save_file(new_state_dict, self.output_path, save_meta) | |
# cleanup incase there are other jobs | |
del new_state_dict | |
del source_state_dict | |
del source_meta | |
torch.cuda.empty_cache() | |
gc.collect() | |
print(f"Saved to {self.output_path}") | |