Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| import tempfile | |
| import shutil | |
| import re | |
| import json | |
| import datetime | |
| from pathlib import Path | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from safetensors.torch import load_file, save_file | |
| import torch | |
| # Optional ModelScope integration | |
| try: | |
| from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download | |
| from modelscope.hub.file_download import model_file_download as ms_file_download | |
| from modelscope.hub.api import HubApi as ModelScopeApi | |
| MODELScope_AVAILABLE = True | |
| except ImportError: | |
| MODELScope_AVAILABLE = False | |
# --- Conversion Function: Safetensors → FP8 Safetensors ---
def convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress=gr.Progress()):
    """Cast the floating-point tensors of a safetensors file to FP8 and save the result.

    Tensors whose dtype is float16, float32 or bfloat16 are cast to the
    requested FP8 dtype; every other tensor is copied through unchanged. The
    output is written to ``<output_dir>/<source basename>-fp8-<fp8_format>.safetensors``
    and carries the source file's ``__metadata__`` plus ``format`` and
    ``fp8_format`` entries.

    Args:
        safetensors_path: Path of the source ``.safetensors`` file.
        output_dir: Existing directory that receives the converted file.
        fp8_format: ``"e5m2"`` or ``"e4m3fn"``.
        progress: Gradio progress callback, called as ``progress(fraction, desc=...)``.

    Returns:
        ``(True, message)`` on success, or ``(False, error_text)`` when any
        step raises (including an unsupported ``fp8_format``).
    """
    def read_safetensors_metadata(path):
        """Return the ``__metadata__`` dict stored in the file header (empty if absent)."""
        with open(path, 'rb') as f:
            # The file starts with an 8-byte little-endian header length,
            # followed by that many bytes of JSON.
            header_size = int.from_bytes(f.read(8), 'little')
            header = json.loads(f.read(header_size).decode('utf-8'))
        return header.get('__metadata__') or {}

    progress(0.1, desc="Starting FP8 conversion...")
    try:
        fp8_dtypes = {"e5m2": torch.float8_e5m2, "e4m3fn": torch.float8_e4m3fn}
        if fp8_format not in fp8_dtypes:
            # Reject unknown names instead of silently falling back to one
            # dtype while labelling the output file with another.
            raise ValueError(
                f"Unsupported FP8 format: {fp8_format!r} (expected one of {sorted(fp8_dtypes)})"
            )
        fp8_dtype = fp8_dtypes[fp8_format]
        castable_dtypes = (torch.float16, torch.float32, torch.bfloat16)

        metadata = read_safetensors_metadata(safetensors_path)
        progress(0.3, desc="Loaded model metadata.")
        state_dict = load_file(safetensors_path)
        progress(0.5, desc="Loaded model weights.")

        sd_pruned = {}
        total = len(state_dict)
        # Iterate over a snapshot of the keys so each source tensor can be
        # popped (dropping its reference) as soon as it has been converted.
        for i, key in enumerate(list(state_dict)):
            progress(0.5 + 0.4 * (i / total), desc=f"Converting tensor {i+1}/{total} to FP8 ({fp8_format})...")
            tensor = state_dict.pop(key)
            sd_pruned[key] = tensor.to(fp8_dtype) if tensor.dtype in castable_dtypes else tensor

        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        output_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
        # Source metadata goes first so the values describing THIS file win
        # over any stale "format"/"fp8_format" entries it may contain.
        save_file(sd_pruned, output_path, metadata={**metadata, "format": "pt", "fp8_format": fp8_format})
        progress(0.9, desc="Saved FP8 safetensors file.")
        progress(1.0, desc="FP8 conversion complete!")
        return True, f"Model successfully pruned to FP8 ({fp8_format})."
    except Exception as e:
        # Callers expect a status tuple rather than an exception.
        return False, str(e)
| # --- Source download helper --- | |
def _repo_path_segments(url, host):
    """Return the non-empty path segments that follow ``host/`` in ``url``.

    Query strings and fragments are discarded. An empty list means ``url``
    has no path after the host.
    """
    _, _, path = url.partition(host + "/")
    path = path.split("?", 1)[0].split("#", 1)[0]
    return [segment for segment in path.split("/") if segment]


def download_safetensors_file(
    source_type,
    repo_url,
    filename,
    hf_token=None,
    modelscope_token=None,
    progress=gr.Progress()
):
    """Download one file from a Hugging Face or ModelScope repo into a fresh temp dir.

    Args:
        source_type: ``"huggingface"`` or ``"modelscope"``.
        repo_url: For Hugging Face, a URL containing ``huggingface.co``; for
            ModelScope, either a URL containing ``modelscope.cn`` or a bare
            model id.
        filename: Path of the file inside the repository.
        hf_token: Token passed to ``hf_hub_download``.
        modelscope_token: Token passed to the ModelScope download when set.
        progress: Gradio progress callback (not used by this function).

    Returns:
        ``(local_file_path, temp_dir)``. The caller owns ``temp_dir`` and is
        responsible for deleting it.

    Raises:
        ValueError: Unknown ``source_type`` or a URL with no repository path.
        ImportError: ModelScope requested but not installed.
        Exception: Anything raised by the underlying download; the temp dir
            is removed before the exception propagates.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        clean_url = repo_url.strip().rstrip("/")
        if source_type == "huggingface":
            if "huggingface.co" not in clean_url:
                raise ValueError("Invalid Hugging Face URL")
            parts = _repo_path_segments(clean_url, "huggingface.co")
            # Browser URLs append /tree/<rev>, /blob/<rev>/... or
            # /resolve/<rev>/... after the repo id. Check the two-segment id
            # form first so a repo literally named "tree" is not truncated.
            for idx in (2, 1):
                if len(parts) > idx + 1 and parts[idx] in ("tree", "blob", "resolve"):
                    parts = parts[:idx]
                    break
            src_repo_id = "/".join(parts[:2])
            if not src_repo_id:
                raise ValueError("Invalid Hugging Face URL")
            safetensors_path = hf_hub_download(
                repo_id=src_repo_id,
                filename=filename,
                cache_dir=temp_dir,
                token=hf_token
            )
        elif source_type == "modelscope":
            if not MODELScope_AVAILABLE:
                raise ImportError("ModelScope not installed. Install with: pip install modelscope")
            if "modelscope.cn" in clean_url:
                parts = _repo_path_segments(clean_url, "modelscope.cn")
                if parts and parts[0] == "models":
                    parts = parts[1:]
                src_repo_id = "/".join(parts[:2])
                if not src_repo_id:
                    raise ValueError("Invalid ModelScope URL")
            else:
                src_repo_id = clean_url
            # Always download into temp_dir via cache_dir, with or without a
            # token, so the caller's cleanup removes the file. This replaces
            # setting the process-wide MODELSCOPE_CACHE environment variable.
            download_kwargs = {
                "model_id": src_repo_id,
                "file_path": filename,
                "cache_dir": temp_dir,
            }
            if modelscope_token:
                download_kwargs["token"] = modelscope_token
            safetensors_path = ms_file_download(**download_kwargs)
        else:
            raise ValueError("Unknown source type")
        return safetensors_path, temp_dir
    except Exception:
        shutil.rmtree(temp_dir, ignore_errors=True)
        # Bare raise keeps the original traceback intact.
        raise
| # --- Upload helper --- | |
def upload_to_target(
    target_type,
    new_repo_id,
    output_dir,
    fp8_format,
    hf_token=None,
    modelscope_token=None,
    private_repo=False,
    progress=gr.Progress()
):
    """Upload the contents of ``output_dir`` to a model repo and return its URL.

    Args:
        target_type: ``"huggingface"`` or ``"modelscope"``.
        new_repo_id: Destination repository id.
        output_dir: Local folder whose contents are uploaded.
        fp8_format: FP8 format name, used only in the commit message.
        hf_token: Hugging Face token; required for the Hugging Face target.
        modelscope_token: When set, used to log in to ModelScope before pushing.
        private_repo: Passed to ``create_repo`` on Hugging Face; not used for
            ModelScope.
        progress: Gradio progress callback (not used by this function).

    Returns:
        The web URL of the destination repository.

    Raises:
        ValueError: Unknown ``target_type`` or missing Hugging Face token.
        ImportError: ModelScope requested but not installed.
    """
    if target_type == "huggingface":
        if not hf_token:
            raise ValueError("Hugging Face token required")
        hub = HfApi(token=hf_token)
        # exist_ok=True lets repeated runs reuse an existing repository.
        hub.create_repo(
            repo_id=new_repo_id,
            private=private_repo,
            repo_type="model",
            exist_ok=True
        )
        hub.upload_folder(
            repo_id=new_repo_id,
            folder_path=output_dir,
            repo_type="model",
            token=hf_token,
            commit_message=f"Upload FP8 ({fp8_format}) model"
        )
        return f"https://huggingface.co/{new_repo_id}"

    if target_type != "modelscope":
        raise ValueError("Unknown target type")

    if not MODELScope_AVAILABLE:
        raise ImportError("ModelScope not installed")
    hub = ModelScopeApi()
    if modelscope_token:
        hub.login(modelscope_token)
    # No license or visibility is passed here, so push_model falls back to
    # whatever defaults it applies for those settings.
    hub.push_model(
        model_id=new_repo_id,
        model_dir=output_dir,
        commit_message=f"Upload FP8 ({fp8_format}) model"
    )
    return f"https://modelscope.cn/models/{new_repo_id}"
| # --- Main Processing Function --- | |
def _build_fp8_readme(repo_url, safetensors_filename, fp8_format):
    """Return the model-card text (YAML front matter + Markdown) for a converted model."""
    # The converter names its output from the file's basename, so do the same
    # here; otherwise a subfolder in the source path leaks into the name.
    base_name = os.path.splitext(os.path.basename(safetensors_filename))[0]
    fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
    converted_on = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Blank lines separate the fields so Markdown renders them as distinct
    # paragraphs instead of one run-on line.
    lines = [
        "---",
        "library_name: diffusers",
        "tags:",
        "- fp8",
        "- safetensors",
        "- pruned",
        "- diffusion",
        "- converted-by-gradio",
        f"- fp8-{fp8_format}",
        "---",
        "",
        f"# FP8 Pruned Model ({fp8_format.upper()})",
        "",
        f"Converted from: `{repo_url}`",
        "",
        f"File: `{safetensors_filename}` → `{fp8_filename}`",
        "",
        f"Quantization: **FP8 ({fp8_format.upper()})**",
        "",
        f"Converted on: {converted_on}",
        "",
        "> ⚠️ Requires PyTorch ≥ 2.1 and compatible hardware for FP8 acceleration.",
        "",
    ]
    return "\n".join(lines)


def process_and_upload_fp8(
    source_type,
    repo_url,
    safetensors_filename,
    fp8_format,
    target_type,
    new_repo_id,
    hf_token,
    modelscope_token,
    private_repo,
    progress=gr.Progress()
):
    """Download a safetensors file, convert it to FP8, and upload it with a README.

    Args:
        source_type: Platform to download from (``"huggingface"`` / ``"modelscope"``).
        repo_url: Source repository URL or id.
        safetensors_filename: Path of the file inside the source repository.
        fp8_format: FP8 format name passed to the converter.
        target_type: Platform to upload to (``"huggingface"`` / ``"modelscope"``).
        new_repo_id: Destination repository id in ``owner/name`` form.
        hf_token: Hugging Face token; required when either side is Hugging Face.
        modelscope_token: Optional ModelScope token.
        private_repo: Whether the destination repo is created private.
        progress: Gradio progress callback.

    Returns:
        A 3-tuple ``(html, status, "")``. ``html`` is a ``gr.HTML`` component
        on success and ``None`` on failure; ``status`` is the message to
        display; the third element is always an empty string. Failures are
        reported through ``status`` rather than raised.
    """
    required_fields = [repo_url, safetensors_filename, new_repo_id]
    if "huggingface" in (source_type, target_type):
        required_fields.append(hf_token)
    if not all(required_fields):
        return None, "❌ Error: Please fill in all required fields.", ""
    # fullmatch, unlike match with "$", also rejects a trailing newline.
    if not re.fullmatch(r"[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+", new_repo_id):
        return None, "❌ Invalid repository ID format. Use 'username/model-name'.", ""

    temp_dir = None
    output_dir = tempfile.mkdtemp()
    try:
        # Authenticate & download
        progress(0.05, desc="Authenticating and downloading...")
        safetensors_path, temp_dir = download_safetensors_file(
            source_type=source_type,
            repo_url=repo_url,
            filename=safetensors_filename,
            hf_token=hf_token,
            modelscope_token=modelscope_token,
            progress=progress
        )
        progress(0.25, desc="Download complete.")

        # Convert
        success, msg = convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress)
        if not success:
            return None, f"❌ Conversion failed: {msg}", ""

        # Write the README before uploading so it travels with the model in
        # the same folder upload on both platforms.
        readme_path = os.path.join(output_dir, "README.md")
        # The card contains non-ASCII characters, so pin the encoding.
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write(_build_fp8_readme(repo_url, safetensors_filename, fp8_format))

        # Upload
        progress(0.92, desc="Uploading model...")
        repo_url_final = upload_to_target(
            target_type=target_type,
            new_repo_id=new_repo_id,
            output_dir=output_dir,
            fp8_format=fp8_format,
            hf_token=hf_token,
            modelscope_token=modelscope_token,
            private_repo=private_repo,
            progress=progress
        )
        progress(1.0, desc="✅ Done!")

        # new_repo_id was validated against a restrictive pattern above, so it
        # is safe to interpolate into the markup. <br> keeps the three lines
        # apart because HTML collapses plain newlines.
        result_html = (
            "✅ Success!<br>\n"
            f'Your FP8 model is uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a><br>\n'
            f"Source: {source_type.title()} → Target: {target_type.title()}\n"
        )
        return gr.HTML(result_html), "✅ FP8 conversion and upload successful!", ""
    except Exception as e:
        # Top-level UI boundary: surface the failure as a status message.
        return None, f"❌ Error: {str(e)}", ""
    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)
| # --- Gradio UI --- | |
# Build the single-page UI: source/conversion settings in the left column,
# upload target in the right column, then the action button and result area.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost — confirm the button and result row placement.
with gr.Blocks(title="Safetensors → FP8 Pruner (HF + ModelScope)") as demo:
    gr.Markdown("# Safetensors to FP8 Pruner")
    gr.Markdown("Convert `.safetensors` models to **FP8** and upload to **Hugging Face** or **ModelScope**.")
    with gr.Row():
        with gr.Column():
            source_type = gr.Radio(
                choices=["huggingface", "modelscope"],
                value="huggingface",
                label="Source Platform"
            )
            repo_url = gr.Textbox(
                label="Source Repository URL",
                placeholder="e.g., https://huggingface.co/Yabo/FramePainter OR your-modelscope-id",
                info="Hugging Face URL or ModelScope model ID"
            )
            safetensors_filename = gr.Textbox(
                label="Safetensors Filename",
                placeholder="unet_diffusion_pytorch_model.safetensors"
            )
            fp8_format = gr.Radio(
                choices=["e4m3fn", "e5m2"],
                value="e5m2",
                label="FP8 Format",
                info="E5M2: wider range; E4M3FN: better near-zero precision"
            )
            hf_token = gr.Textbox(
                label="Hugging Face Token (if using HF)",
                type="password"
            )
            # Hidden when the optional modelscope package failed to import.
            modelscope_token = gr.Textbox(
                label="ModelScope Token (optional)",
                type="password",
                visible=MODELScope_AVAILABLE
            )
        with gr.Column():
            target_type = gr.Radio(
                choices=["huggingface", "modelscope"],
                value="huggingface",
                label="Target Platform"
            )
            new_repo_id = gr.Textbox(
                label="New Repository ID",
                placeholder="your-username/my-model-fp8"
            )
            private_repo = gr.Checkbox(label="Make Private (HF only)", value=False)
    convert_btn = gr.Button("Convert & Upload", variant="primary")
    with gr.Row():
        status_output = gr.Markdown()
        repo_link_output = gr.HTML()
    # The input order must match the positional parameters of
    # process_and_upload_fp8.
    # NOTE(review): the handler returns three values but only two outputs are
    # bound here — confirm how the installed Gradio version treats the extra.
    convert_btn.click(
        fn=process_and_upload_fp8,
        inputs=[
            source_type,
            repo_url,
            safetensors_filename,
            fp8_format,
            target_type,
            new_repo_id,
            hf_token,
            modelscope_token,
            private_repo
        ],
        outputs=[repo_link_output, status_output],
        show_progress=True
    )
    gr.Examples(
        examples=[
            ["huggingface", "https://huggingface.co/Yabo/FramePainter", "unet_diffusion_pytorch_model.safetensors", "e5m2", "huggingface"]
        ],
        inputs=[source_type, repo_url, safetensors_filename, fp8_format, target_type]
    )
demo.launch()