from pathlib import Path
import os
import re
import shutil
import yaml
import gradio as gr
from hf_merge import process_repos, repo_list_generator

def list_sub(a, b):
    return [e for e in a if e not in b]

def is_repo_name(s):
    return re.fullmatch(r'^[^/,\s]+?/[^/,\s]+?$', s)

def is_valid_model_name(s):
    return bool(is_repo_name(s) or Path(s).suffix in (".safetensors", ".bin", ".sft"))

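# Quick sanity check of the validator above (hypothetical inputs): a repo id or
# a weights-file path passes, a bare name does not.
# is_valid_model_name("author/model")        # True: a repo id
# is_valid_model_name("./model.safetensors") # True: a local weights file
# is_valid_model_name("model")               # False: neither
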
def is_repo_exists(repo_id):
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        return api.repo_exists(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to connect to {repo_id}.")
        print(e)
        return True  # err on the safe side and report the repo as existing

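# Note: on a connection failure, is_repo_exists() deliberately reports True so
# that callers such as process_repos_gr() below refuse to overwrite a target
# repo whose state could not be verified.
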
def create_repo(new_repo_id):
    from huggingface_hub import HfApi
    hf_token = os.environ.get("HF_TOKEN")
    api = HfApi()
    try:
        api.create_repo(repo_id=new_repo_id, token=hf_token, private=True)
        url = f"https://huggingface.co/{new_repo_id}"
    except Exception as e:
        print(f"Error: Failed to create {new_repo_id}.")
        print(e)
        return ""
    return url

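# Minimal usage sketch (hypothetical repo id): create_repo() reads the token
# from the HF_TOKEN environment variable rather than taking it as an argument.
# url = create_repo("your-user/your-merged-model")
# if url: print(f"Created {url}")
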
def upload_dir_to_repo(new_repo_id, folder, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import HfApi
    hf_token = os.environ.get("HF_TOKEN")
    api = HfApi()
    try:
        progress(0, desc="Start uploading...")
        for path in Path(folder).glob("*"):
            if path.is_dir():
                api.upload_folder(repo_id=new_repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
            elif path.is_file():
                api.upload_file(repo_id=new_repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
        progress(1, desc="Uploaded.")
        url = f"https://huggingface.co/{new_repo_id}"
    except Exception as e:
        print(f"Error: Failed to upload to {new_repo_id}.")
        print(e)
        return ""
    return url

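# Minimal usage sketch (hypothetical values): uploads every top-level file and
# directory under "output" into the repo, preserving names one level deep.
# url = upload_dir_to_repo("your-user/your-merged-model", "output")
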
merge_yaml_path = "./merge_yaml.yaml"
merge_text_path = "./merge_txt.txt"

def load_yaml_dict(yaml_path: str):
    yaml_dict = None
    try:
        with open(yaml_path, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)
    except Exception as e:
        print(e)
        data = None
    if isinstance(data, dict) and 'models' in data and data['models']:
        yaml_dict = data
    return yaml_dict

def repo_text_to_yaml_dict(text_path: str, default_p: float, default_lambda_val: float):
    yaml_dict = {'models': {}}
    repos = list(repo_list_generator(text_path, default_p, default_lambda_val))
    for model, weight, density in repos:
        if not is_valid_model_name(model): continue
        model_info = {'model': str(model),
                      'parameters': {'weight': float(weight), 'density': float(density)}}
        yaml_dict['models'][str(model.split("/")[-1])] = model_info
    return yaml_dict

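# For illustration, a single ("author/model-a", 0.5, 0.5) tuple from
# repo_list_generator (hypothetical repo id) serializes to roughly:
#   models:
#     model-a:
#       model: author/model-a
#       parameters:
#         density: 0.5
#         weight: 0.5
# The exact text-file syntax accepted here is defined by hf_merge.
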
def gen_repo_list(input_text: str, default_p: float, default_lambda_val: float):
    yaml_dict = {}
    if Path(merge_yaml_path).exists():
        yaml_dict = load_yaml_dict(merge_yaml_path)
    else:
        with open(merge_text_path, mode='w', encoding='utf-8') as file:
            file.write(input_text)
        yaml_dict = repo_text_to_yaml_dict(merge_text_path, default_p, default_lambda_val)
    yaml_str = yaml.dump(yaml_dict, allow_unicode=True)
    md = f"""``` yaml
{yaml_str}
```"""
    return md

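# The returned string is a fenced ``` yaml block, so a Gradio Markdown
# component renders the pending merge configuration as a readable preview.
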
def upload_repo_list(filepath: str, default_p: float, default_lambda_val: float):
    yaml_dict = {}
    if Path(filepath).suffix in [".yml", ".yaml"]:
        yaml_dict = load_yaml_dict(filepath)
        if yaml_dict is not None:
            with open(merge_yaml_path, mode='w', encoding='utf-8') as file:
                yaml.dump(yaml_dict, file, default_flow_style=False, allow_unicode=True)
    else:
        yaml_dict = repo_text_to_yaml_dict(filepath, default_p, default_lambda_val)
        shutil.copy(filepath, merge_text_path)
    yaml_str = yaml.dump(yaml_dict, allow_unicode=True)
    md = f"""``` yaml
{yaml_str}
```"""
    return md

def clear_repo_list():
    Path(merge_text_path).unlink(missing_ok=True)
    Path(merge_yaml_path).unlink(missing_ok=True)
    return gr.update(value=""), gr.update(value="")

def clear_output(output_dir: str):
    shutil.rmtree(output_dir, ignore_errors=True)
    print(f"Directory {output_dir} deleted successfully.")

def process_repos_gr(mode, p, lambda_val, skip_dirs: list[str], hf_user: str, hf_repo: str, hf_token: str,
                     is_upload=True, is_upload_sf=False, repo_exist_ok=False, files=None, repo_urls=None,
                     progress=gr.Progress(track_tqdm=True)):
    # avoid mutable default arguments, which would be shared across calls
    if not files: files = []
    if not repo_urls: repo_urls = []
    if is_upload and not hf_user:
        print(f"Invalid user name: {hf_user}")
        progress(1, desc=f"Invalid user name: {hf_user}")
        return gr.update(value=files), gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)
    if hf_token and not os.environ.get("HF_TOKEN"): os.environ['HF_TOKEN'] = hf_token
    output_dir = "output"
    base_model = "base_model"
    staging_model = "staging_model"
    output_model = output_dir  # the merged model tree is written to the output root
    repo_list_file = None
    if is_upload:
        clear_output(output_dir)
        files = []
    if Path(merge_yaml_path).exists(): repo_list_file = merge_yaml_path
    elif Path(merge_text_path).exists(): repo_list_file = merge_text_path
    if repo_list_file is None:
        print("Repo list is not found.")
        progress(1, desc="Repo list is not found.")
        return gr.update(value=files), gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)
    new_repo_id = f"{hf_user}/{hf_repo}"
    if is_upload and not is_repo_name(new_repo_id):
        print(f"Invalid repo name: {new_repo_id}")
        progress(1, desc=f"Invalid repo name: {new_repo_id}")
        return gr.update(value=files), gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)
    if is_upload and is_repo_exists(new_repo_id):
        print(f"Repo already exists: {new_repo_id}")
        if not repo_exist_ok:
            progress(1, desc=f"Repo already exists: {new_repo_id}")
            return gr.update(value=files), gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)
    try:
        progress(0, desc="Downloading repos.")
        # SDXL merges always skip the text encoder directories in addition to the user-selected ones
        merge_skip_dirs = skip_dirs + ["text_encoder"] if mode == "SDXL" else skip_dirs
        output_file_path, output_yaml_path = process_repos(output_dir, base_model, staging_model,
                                                           repo_list_file, p, lambda_val, merge_skip_dirs, False)
        if mode == "Single files":
            files.append(output_file_path)
            files.append(output_yaml_path)
    except Exception as e:
        print(e)
        progress(1, desc=f"Error occurred: {e}")
    repo_url = None
    if Path(output_model).exists():
        if mode != "Single files": save_readme_md(output_model, repo_list_file, p, lambda_val)
        if is_upload_sf:
            if mode == "SDXL": files.append(convert_output_to_safetensors(output_model, hf_repo))
            elif mode == "SD1.5": files.append(convert_output_to_safetensors_sd(output_model, hf_repo))
        if is_upload:
            if not is_repo_exists(new_repo_id): create_repo(new_repo_id)
            repo_url = upload_dir_to_repo(new_repo_id, output_model)
    else:
        progress(1, desc="Merging failed.")
        return gr.update(value=files), gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)
    if repo_url: repo_urls.append(repo_url)
    md = "Your new repo:<br>"
    for u in repo_urls:
        md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
    return gr.update(value=files), gr.update(value=repo_urls, choices=repo_urls), gr.update(value=md)

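# A minimal wiring sketch (hypothetical component names; the actual UI lives in
# the app script that imports this module):
# merge_btn.click(process_repos_gr,
#                 inputs=[mode, p, lambda_val, skip_dirs, hf_user, hf_repo, hf_token,
#                         is_upload, is_upload_sf, repo_exist_ok, files, repo_urls],
#                 outputs=[files_out, repo_urls_out, output_md])
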
from convert_repo_to_safetensors_gr import convert_diffusers_to_safetensors

def convert_output_to_safetensors(output_dir: str, repo_name: str, progress=gr.Progress(track_tqdm=True)):
    output_filename = f"{repo_name}.safetensors"
    convert_diffusers_to_safetensors(output_dir, Path(output_dir, output_filename))
    return output_filename

from convert_repo_to_safetensors_sd_gr import convert_diffusers_to_safetensors as convert_diffusers_to_safetensors_sd

def convert_output_to_safetensors_sd(output_dir: str, repo_name: str, progress=gr.Progress(track_tqdm=True)):
    output_filename = f"{repo_name}.safetensors"
    convert_diffusers_to_safetensors_sd(output_dir, Path(output_dir, output_filename))
    return output_filename

def save_readme_md(dir_path: str, yaml_path: str, default_p: float, default_lambda_val: float):
    yaml_dict = {}
    if Path(yaml_path).suffix in [".yml", ".yaml"]:
        yaml_dict = load_yaml_dict(yaml_path)
    else:
        yaml_dict = repo_text_to_yaml_dict(yaml_path, default_p, default_lambda_val)
    yaml_str = yaml.dump(yaml_dict, allow_unicode=True)
    md = f"""---
license: other
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-to-image
---
<br>Merged model.<br>
## 🧩 Configuration
``` yaml
{yaml_str}
```"""
    path = str(Path(dir_path, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)