diffusers_dare_merger / merge_gr.py
John6666's picture
Upload 9 files
3a9c7c3 verified
from pathlib import Path
import os
import shutil
import yaml
import gradio as gr
from hf_merge import process_repos, repo_list_generator
def list_sub(a, b):
    """Return the items of *a* that do not appear in *b*, keeping *a*'s order."""
    return list(filter(lambda item: item not in b, a))
def is_repo_name(s):
    """Match *s* against the ``owner/name`` repo-id shape.

    Returns the ``re.Match`` object on success and ``None`` otherwise, so the
    result is meant to be used for its truthiness. Neither part may contain
    ``/``, ``,`` or whitespace.
    """
    from re import fullmatch
    repo_pattern = r'^[^/,\s]+?/[^/,\s]+?$'
    return fullmatch(repo_pattern, s)
def is_valid_model_name(s):
    """Return True when *s* is a repo id or a path to a single weights file.

    Accepted single-file suffixes are ``.safetensors``, ``.bin`` and ``.sft``.
    """
    if is_repo_name(s):
        return True
    return Path(s).suffix in (".safetensors", ".bin", ".sft")
def is_repo_exists(repo_id):
    """Return whether *repo_id* exists on the Hugging Face Hub.

    No token is passed explicitly; ``HfApi`` uses whatever credentials it
    picks up by default.

    If the Hub query raises, this conservatively returns True ("assume it
    exists"), so callers refuse to create or overwrite rather than proceeding
    blindly.
    """
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        return bool(api.repo_exists(repo_id=repo_id))
    except Exception as e:
        print(f"Error: Failed to connect {repo_id}. ")
        # Report the underlying cause, like create_repo/upload_dir_to_repo do.
        print(e)
        return True  # for safe
def create_repo(new_repo_id):
    """Create a private Hub repo named *new_repo_id* using the HF_TOKEN env var.

    Returns the repo's web URL on success, or an empty string after printing
    the error if creation fails.
    """
    from huggingface_hub import HfApi
    import os
    token = os.environ.get("HF_TOKEN")
    hub = HfApi()
    try:
        hub.create_repo(repo_id=new_repo_id, token=token, private=True)
    except Exception as err:
        print(f"Error: Failed to create {new_repo_id}. ")
        print(err)
        return ""
    return f"https://huggingface.co/{new_repo_id}"
def upload_dir_to_repo(new_repo_id, folder, progress=gr.Progress(track_tqdm=True)):
    """Upload every top-level entry of *folder* into the Hub repo *new_repo_id*.

    Sub-directories are uploaded with ``upload_folder`` and plain files with
    ``upload_file``. Each entry keeps its own name as its path in the repo.

    Returns the repo URL on success, or an empty string after printing the
    error if any upload fails.
    """
    from huggingface_hub import HfApi
    import os
    hf_token = os.environ.get("HF_TOKEN")
    hub = HfApi()
    try:
        progress(0, desc="Start uploading...")
        for entry in Path(folder).glob("*"):
            shared = dict(repo_id=new_repo_id, path_in_repo=entry.name, token=hf_token)
            if entry.is_dir():
                hub.upload_folder(folder_path=str(entry), **shared)
            elif entry.is_file():
                hub.upload_file(path_or_fileobj=str(entry), **shared)
        progress(1, desc="Uploaded.")
    except Exception as err:
        print(f"Error: Failed to upload to {new_repo_id}. ")
        print(err)
        return ""
    return f"https://huggingface.co/{new_repo_id}"
# Working copies of the active repo list: a YAML config (preferred when
# present) and a plain-text list.
merge_yaml_path, merge_text_path = "./merge_yaml.yaml", "./merge_txt.txt"
def load_yaml_dict(yaml_path: str):
    """Load a merge config from *yaml_path*.

    Returns the parsed mapping only when it is a dict with a non-empty
    ``models`` entry. Otherwise it returns None; that covers read or parse
    errors too, which are printed.
    """
    try:
        with open(yaml_path, 'r', encoding='utf-8') as handle:
            loaded = yaml.safe_load(handle)
    except Exception as e:
        print(e)
        loaded = None
    usable = isinstance(loaded, dict) and bool(loaded.get('models'))
    return loaded if usable else None
def repo_text_to_yaml_dict(text_path: str, default_p: float, default_lambda_val: float):
    """Convert a plain-text repo list into the ``{'models': {...}}`` config shape.

    Each entry from ``repo_list_generator`` is unpacked as
    ``(model, weight, density)``. Invalid model names are skipped. Entries are
    keyed by the last ``/`` component of the model name.
    """
    # Fully materialize the generator before validating, as before.
    entries = list(repo_list_generator(text_path, default_p, default_lambda_val))
    models = {}
    for model, weight, density in entries:
        if not is_valid_model_name(model):
            continue
        models[str(model.split("/")[-1])] = {
            'model': str(model),
            'parameters': {'weight': float(weight), 'density': float(density)},
        }
    return {'models': models}
def gen_repo_list(input_text: str, default_p: float, default_lambda_val: float):
    """Return a Markdown YAML preview of the active repo list.

    If a saved YAML config exists it is shown as-is and *input_text* is
    ignored. Otherwise *input_text* is saved as the text list and parsed.
    """
    if not Path(merge_yaml_path).exists():
        with open(merge_text_path, mode='w', encoding='utf-8') as out:
            out.write(input_text)
        config = repo_text_to_yaml_dict(merge_text_path, default_p, default_lambda_val)
    else:
        config = load_yaml_dict(merge_yaml_path)
    yaml_str = yaml.dump(config, allow_unicode=True)
    return f"""``` yaml
{yaml_str}
```"""
def upload_repo_list(filepath: str, default_p: float, default_lambda_val: float):
    """Install an uploaded repo list as the active one and return a YAML preview.

    ``.yml``/``.yaml`` files are validated with ``load_yaml_dict`` and, if
    usable, rewritten to ``merge_yaml_path``. Any other file is parsed as a
    plain-text list and copied to ``merge_text_path``.

    Returns the parsed config rendered as a fenced YAML Markdown block.
    """
    yaml_dict = {}
    if Path(filepath).suffix in [".yml", ".yaml"]:
        yaml_dict = load_yaml_dict(filepath)
        if yaml_dict is not None:
            with open(merge_yaml_path, mode='w', encoding='utf-8') as file:
                yaml.dump(yaml_dict, file, default_flow_style=False, allow_unicode=True)
    else:
        yaml_dict = repo_text_to_yaml_dict(filepath, default_p, default_lambda_val)
        shutil.copy(filepath, merge_text_path)
        # A YAML list left over from an earlier upload is preferred over the
        # text list by gen_repo_list and process_repos_gr. Remove it so the
        # list that was just uploaded is the one actually merged.
        Path(merge_yaml_path).unlink(missing_ok=True)
    yaml_str = yaml.dump(yaml_dict, allow_unicode=True)
    md = f"""``` yaml
{yaml_str}
```"""
    return md
def clear_repo_list():
    """Delete both saved repo-list files and return two blank Gradio updates."""
    for stale in (merge_text_path, merge_yaml_path):
        Path(stale).unlink(missing_ok=True)
    first_blank = gr.update(value="")
    second_blank = gr.update(value="")
    return first_blank, second_blank
def clear_output(output_dir: str):
    """Best-effort recursive removal of *output_dir*.

    Errors from ``rmtree`` are ignored, as before. The log message now
    reflects whether the path is actually gone afterwards, instead of always
    claiming success.
    """
    shutil.rmtree(output_dir, ignore_errors=True)
    if Path(output_dir).exists():
        print(f"Error: Failed to delete {output_dir}.")
    else:
        print(f"Directory {output_dir} deleted successfully.")
def process_repos_gr(mode, p, lambda_val, skip_dirs: list[str], hf_user: str, hf_repo: str, hf_token: str,
                     is_upload=True, is_upload_sf=False, repo_exist_ok=False, files=None, repo_urls=None,
                     progress=gr.Progress(track_tqdm=True)):
    """Merge the repos from the saved repo list and optionally upload the result.

    Args:
        mode: "SDXL" also skips ``text_encoder``; "Single files" collects the
            merged file and its YAML into *files* and writes no README;
            "SD1.5" selects the SD converter for the optional safetensors export.
        p, lambda_val: Merge parameters passed to ``process_repos`` and
            recorded in the README.
        skip_dirs: Sub-directories for ``process_repos`` to skip.
        hf_user, hf_repo: Target repo is ``f"{hf_user}/{hf_repo}"``.
        hf_token: Exported as ``HF_TOKEN`` if that env var is not already set.
        is_upload: Clear ``output``, then upload the merged folder to the Hub.
        is_upload_sf: Also export a single ``.safetensors`` file (SDXL/SD1.5).
        repo_exist_ok: Allow uploading into an already existing repo.
        files, repo_urls: Current values of the file-list and repo-URL
            components. None means empty.

    Returns:
        A 3-tuple of ``gr.update`` values: (files, repo URL choices, result markdown).
    """
    # None defaults instead of [] so the appends below never leak between calls.
    files = [] if files is None else files
    repo_urls = [] if repo_urls is None else repo_urls

    def _unchanged():
        # Early-exit result; reads the current `files`/`repo_urls` bindings.
        return gr.update(value=files), gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)

    if is_upload and not hf_user:
        print(f"Invalid user name: {hf_user}")
        progress(1, desc=f"Invalid user name: {hf_user}")
        return _unchanged()
    if hf_token and not os.environ.get("HF_TOKEN"):
        os.environ['HF_TOKEN'] = hf_token
    output_dir = "output"
    base_model = "base_model"
    staging_model = "staging_model"
    # README, conversion and upload all operate on output_dir itself.
    output_model = output_dir
    if is_upload:
        clear_output(output_dir)
        files = []
    # A saved YAML config takes precedence over the plain-text list.
    if Path(merge_yaml_path).exists():
        repo_list_file = merge_yaml_path
    elif Path(merge_text_path).exists():
        repo_list_file = merge_text_path
    else:
        print("Repo list is not found.")
        progress(1, desc="Repo list is not found.")
        return _unchanged()
    new_repo_id = f"{hf_user}/{hf_repo}"
    if is_upload and not is_repo_name(new_repo_id):
        print(f"Invalid Repo name: {new_repo_id}")
        progress(1, desc=f"Invalid repo name: {new_repo_id}")
        return _unchanged()
    if is_upload and is_repo_exists(new_repo_id):
        print(f"Repo already exists: {new_repo_id}")
        if not repo_exist_ok:
            progress(1, desc=f"Repo already exists: {new_repo_id}")
            return _unchanged()
    try:
        progress(0, desc="Downloading Repos.")
        merge_skip_dirs = (skip_dirs + ["text_encoder"]) if mode == "SDXL" else skip_dirs
        output_file_path, output_yaml_path = process_repos(output_dir, base_model, staging_model,
                                                           repo_list_file, p, lambda_val, merge_skip_dirs, False)
        if mode == "Single files":
            files.append(output_file_path)
            files.append(output_yaml_path)
    except Exception as e:
        print(e)
        progress(1, desc=f"Error occured: {e}")
        # Stop here: whatever is in output_dir may be partial or stale, and
        # must not get a README, be converted, or be uploaded.
        return _unchanged()
    if not Path(output_model).exists():
        progress(1, desc="Merging failed.")
        return _unchanged()
    if mode != "Single files":
        save_readme_md(output_model, repo_list_file, p, lambda_val)
    if is_upload_sf:
        if mode == "SDXL":
            files.append(convert_output_to_safetensors(output_model, hf_repo))
        elif mode == "SD1.5":
            files.append(convert_output_to_safetensors_sd(output_model, hf_repo))
    repo_url = None
    if is_upload:
        if not is_repo_exists(new_repo_id):
            create_repo(new_repo_id)
        repo_url = upload_dir_to_repo(new_repo_id, output_model)
    if repo_url:
        repo_urls.append(repo_url)
    md = "Your new Repo:<br>"
    for u in repo_urls:
        md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
    return gr.update(value=files), gr.update(value=repo_urls, choices=repo_urls), gr.update(value=md)
from convert_repo_to_safetensors_gr import convert_diffusers_to_safetensors
def convert_output_to_safetensors(output_dir: str, repo_name: str, progress=gr.Progress(track_tqdm=True)):
    """Convert the diffusers folder *output_dir* into ``<repo_name>.safetensors``.

    The destination path ``output_dir/<repo_name>.safetensors`` is passed as
    the converter's second argument; presumably that is the output file path
    (TODO confirm against convert_repo_to_safetensors_gr).

    Returns the full path to that file, not the bare filename, so callers that
    list it (e.g. a Gradio file component) point at a file that exists.
    """
    output_path = Path(output_dir, f"{repo_name}.safetensors")
    convert_diffusers_to_safetensors(output_dir, output_path)
    return str(output_path)
from convert_repo_to_safetensors_sd_gr import convert_diffusers_to_safetensors as convert_diffusers_to_safetensors_sd
def convert_output_to_safetensors_sd(output_dir: str, repo_name: str, progress=gr.Progress(track_tqdm=True)):
    """SD1.5 variant: convert *output_dir* into ``<repo_name>.safetensors``.

    Uses the SD converter from convert_repo_to_safetensors_sd_gr. The second
    argument is assumed to be the destination file path (TODO confirm).

    Returns the full path to the written file, not the bare filename.
    """
    output_path = Path(output_dir, f"{repo_name}.safetensors")
    convert_diffusers_to_safetensors_sd(output_dir, output_path)
    return str(output_path)
def upload_repo_list(filepath: str, default_p: float, default_lambda_val: float):
    """Install an uploaded repo list as the active one and return a YAML preview.

    ``.yml``/``.yaml`` files are validated with ``load_yaml_dict`` and, if
    usable, rewritten to ``merge_yaml_path``. Any other file is parsed as a
    plain-text list and copied to ``merge_text_path``.

    Returns the parsed config rendered as a fenced YAML Markdown block.
    """
    yaml_dict = {}
    if Path(filepath).suffix in [".yml", ".yaml"]:
        yaml_dict = load_yaml_dict(filepath)
        if yaml_dict is not None:
            with open(merge_yaml_path, mode='w', encoding='utf-8') as file:
                yaml.dump(yaml_dict, file, default_flow_style=False, allow_unicode=True)
    else:
        yaml_dict = repo_text_to_yaml_dict(filepath, default_p, default_lambda_val)
        shutil.copy(filepath, merge_text_path)
        # A YAML list left over from an earlier upload is preferred over the
        # text list by gen_repo_list and process_repos_gr. Remove it so the
        # list that was just uploaded is the one actually merged.
        Path(merge_yaml_path).unlink(missing_ok=True)
    yaml_str = yaml.dump(yaml_dict, allow_unicode=True)
    md = f"""``` yaml
{yaml_str}
```"""
    return md
def save_readme_md(dir: str, yaml_path: str, default_p: float, default_lambda_val: float):
    """Write a model-card ``README.md`` into *dir* that embeds the merge config.

    *yaml_path* may be a YAML config (``.yml``/``.yaml``) or a plain-text repo
    list. The text form is converted with the given default parameters.
    """
    is_yaml_config = Path(yaml_path).suffix in [".yml", ".yaml"]
    if is_yaml_config:
        config = load_yaml_dict(yaml_path)
    else:
        config = repo_text_to_yaml_dict(yaml_path, default_p, default_lambda_val)
    yaml_str = yaml.dump(config, allow_unicode=True)
    md = f"""---
license: other
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-to-image
---
<br>Merged model.<br>
## 🧩 Configuration
``` yaml
{yaml_str}
```"""
    Path(dir, "README.md").write_text(md, encoding="utf-8")