rootlocalghost commited on
Commit
1c3237a
·
verified ·
1 Parent(s): abe82bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -230
app.py CHANGED
@@ -1,260 +1,141 @@
1
- import gradio as gr
2
- from huggingface_hub import HfApi, hf_hub_download, BucketFile, BucketFolder
3
- from safetensors.torch import load_file, save_file
4
  import os
5
- import re
6
- import time
7
- import shutil
8
- import tempfile
9
- import torch
10
  import gc
 
 
 
 
 
11
 
12
- RATE_LIMIT_PATTERN = re.compile(r"Retry after\s*(\d+)\s*seconds", re.IGNORECASE)
 
 
13
 
14
- def _is_rate_limit_error(exc: Exception) -> bool:
15
- message = str(exc).lower()
16
- return "429" in message or "too many requests" in message or "rate limit" in message
 
17
 
18
- def _get_retry_after(exc: Exception, default: int = 2) -> int:
19
- match = RATE_LIMIT_PATTERN.search(str(exc))
20
- if match:
21
- return int(match.group(1))
22
- return default
23
 
24
- def _retry_api_call(fn, *args, retries: int = 3, **kwargs):
25
- delay = 2
26
- for attempt in range(1, retries + 1):
27
- try:
28
- return fn(*args, **kwargs)
29
- except Exception as exc:
30
- if not _is_rate_limit_error(exc) or attempt == retries:
31
- raise
32
- wait = _get_retry_after(exc, delay)
33
- time.sleep(wait)
34
- delay = min(delay * 2, 60)
35
 
36
- def _format_bucket_uri(bucket_id: str) -> str:
37
- bucket_id = bucket_id.strip()
38
- if bucket_id.startswith("hf://buckets/"):
39
- return bucket_id
40
- if bucket_id.startswith("buckets/"):
41
- return f"hf://{bucket_id}"
42
- return f"hf://buckets/{bucket_id}"
43
 
44
- def _get_target_dtype(precision):
45
- if precision == "FP8":
46
- return torch.float8_e4m3fn
47
- elif precision == "FP16":
48
- return torch.float16
49
- elif precision == "BF16":
50
- return torch.bfloat16
51
- return None
52
 
53
- def _process_and_upload_file(local_path, file_path, target_repo, repo_type, api, hf_token, precision, target_folders):
54
- target_dtype = _get_target_dtype(precision)
55
- folders_to_check = [f.strip() for f in target_folders.split(",")] if target_folders else []
56
-
57
- # Check if this file is in a targeted folder (or if no target folders are specified)
58
- in_target_folder = any(f in file_path for f in folders_to_check) if folders_to_check else True
59
 
60
- if precision != "None" and file_path.endswith(".safetensors") and in_target_folder:
61
- # Quantize the file
62
- tensors = load_file(local_path)
63
- for k, v in tensors.items():
64
- if v.is_floating_point():
65
- tensors[k] = v.to(target_dtype)
66
-
67
- converted_path = local_path + ".converted"
68
- save_file(tensors, converted_path)
69
-
70
- # Flush memory
71
- del tensors
72
- gc.collect()
 
 
 
73
 
74
- api.upload_file(
75
- path_or_fileobj=converted_path,
76
- path_in_repo=file_path,
77
- repo_id=target_repo,
78
- repo_type=repo_type,
79
- commit_message=f"clone & quantize ({precision}) {file_path}",
80
- token=hf_token,
81
- )
82
- os.remove(converted_path)
83
- else:
84
- # Upload as-is
85
- api.upload_file(
86
- path_or_fileobj=local_path,
87
- path_in_repo=file_path,
88
- repo_id=target_repo,
89
- repo_type=repo_type,
90
- commit_message=f"clone {file_path}",
91
- token=hf_token,
92
- )
93
 
94
- def _stream_clone_repo(source_repo, target_repo, repo_type, api, hf_token, precision, target_folders):
95
- file_paths = api.list_repo_files(
96
- repo_id=source_repo,
97
- repo_type=repo_type,
98
- token=hf_token,
99
- )
100
- if not file_paths:
101
- raise ValueError("source repo is empty or could not be listed")
102
 
103
- with tempfile.TemporaryDirectory(prefix="hf_file_") as root_dir:
104
- for file_path in file_paths:
105
- if file_path.endswith("/"):
106
- continue
107
- try:
108
- downloaded_path = hf_hub_download(
109
- repo_id=source_repo,
110
- filename=file_path,
111
- repo_type=repo_type,
112
- local_dir=root_dir,
113
- local_dir_use_symlinks=False,
114
- token=hf_token,
115
  )
116
- if not os.path.isfile(downloaded_path):
117
- raise ValueError(f"Downloaded file not found: {downloaded_path}")
118
 
119
- _process_and_upload_file(downloaded_path, file_path, target_repo, repo_type, api, hf_token, precision, target_folders)
120
-
121
- finally:
122
- if os.path.exists(downloaded_path):
123
- os.remove(downloaded_path)
124
- gc.collect()
125
-
126
- def _upload_local_source(source_path, target_repo, repo_type, api, hf_token, precision, target_folders):
127
- if not os.path.isdir(source_path):
128
- raise ValueError("Local source path must be an existing directory.")
129
-
130
- if precision == "None":
131
- # Bulk upload if no quantization is needed
132
- api.upload_large_folder(
133
- repo_id=target_repo,
134
- folder_path=source_path,
135
- repo_type=repo_type,
136
- num_workers=1,
137
- print_report=False,
138
- )
139
- else:
140
- # File-by-file processing for local quantization
141
- for root, _, files in os.walk(source_path):
142
- for file in files:
143
- local_file_path = os.path.join(root, file)
144
- repo_file_path = os.path.relpath(local_file_path, source_path).replace("\\", "/")
145
- _process_and_upload_file(local_file_path, repo_file_path, target_repo, repo_type, api, hf_token, precision, target_folders)
146
- gc.collect()
147
 
148
- def _stream_clone_bucket(source_repo, target_repo, repo_type, api, hf_token, precision, target_folders):
149
- bucket_uri = _format_bucket_uri(source_repo)
150
- bucket_id = bucket_uri[len("hf://"):]
151
- items = api.list_bucket_tree(bucket_id=bucket_id, recursive=True, token=hf_token)
152
-
153
- with tempfile.TemporaryDirectory(prefix="hf_file_") as root_dir:
154
- for item in items:
155
- if isinstance(item, BucketFolder):
156
- continue
157
- if isinstance(item, BucketFile):
158
- local_path = os.path.join(root_dir, item.path)
159
- os.makedirs(os.path.dirname(local_path), exist_ok=True)
160
- try:
161
- api.download_bucket_files(
162
- bucket_id=bucket_id,
163
- files=[(item.path, local_path)],
164
- token=hf_token,
165
- )
166
- _process_and_upload_file(local_path, item.path, target_repo, repo_type, api, hf_token, precision, target_folders)
167
- finally:
168
- if os.path.exists(local_path):
169
- os.remove(local_path)
170
  gc.collect()
171
 
172
- def stealth_clone_hf_repo(hf_token_ui, source_repo, source_type, target_repo, repo_type, precision, target_folders):
173
- # Use UI token first, fallback to environment variable
174
- hf_token = hf_token_ui.strip() if hf_token_ui.strip() else os.environ.get("HF_TOKEN")
175
-
176
- if not hf_token:
177
- return "error: HF_TOKEN secret not found and no token provided in the UI."
178
 
179
- api = HfApi(token=hf_token)
180
- try:
181
- _retry_api_call(
182
- api.create_repo,
183
- repo_id=target_repo,
184
- repo_type=repo_type,
185
- exist_ok=True,
186
- )
187
-
188
- if source_type == "bucket":
189
- _stream_clone_bucket(source_repo, target_repo, repo_type, api, hf_token, precision, target_folders)
190
- elif source_type == "local":
191
- _upload_local_source(source_repo, target_repo, repo_type, api, hf_token, precision, target_folders)
192
- else:
193
- _stream_clone_repo(source_repo, target_repo, repo_type, api, hf_token, precision, target_folders)
194
 
195
- return f"success! cleanly cloned and processed {source_repo} to {target_repo} with no tags."
196
- except Exception as e:
197
- return f"error: {type(e).__name__}: {str(e)}"
198
 
199
- # Build UI
200
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
201
- gr.Markdown("## 🥷 Stealth Quantizing Cloner")
202
  gr.Markdown(
203
- "Clone repos, buckets, or local paths cleanly without the 'duplicated from' tag. "
204
- "Optionally intercept `.safetensors` files during the transfer and quantize them on the fly."
 
 
205
  )
206
-
207
- with gr.Row():
208
- hf_token_input = gr.Textbox(
209
- label="Hugging Face Token (Write Access)",
210
- type="password",
211
- placeholder="hf_... (Leave blank if using Space Secrets)"
212
- )
213
-
214
- with gr.Row():
215
- source_input = gr.Textbox(
216
- label="Source (Repo ID, Bucket ID, or Local Path)",
217
- placeholder="e.g. user/model, username/my-bucket, or /mnt/bucket"
218
- )
219
- source_type_input = gr.Radio(
220
- choices=["repo", "bucket", "local"],
221
- value="repo",
222
- label="Source Type"
223
- )
224
-
225
- with gr.Row():
226
- target_input = gr.Textbox(
227
- label="Target Repo ID",
228
- placeholder="e.g. your-username/new-model"
229
- )
230
- repo_type_input = gr.Radio(
231
- choices=["model", "dataset", "space"],
232
- value="model",
233
- label="Target Repository Type"
234
- )
235
 
236
  with gr.Row():
237
- precision_input = gr.Dropdown(
238
- choices=["None", "FP8", "FP16", "BF16"],
239
- value="None",
240
- label="On-the-Fly Quantization"
241
- )
242
- target_folders_input = gr.Textbox(
243
- label="Target Folders for Quantization (Comma separated)",
244
- placeholder="e.g. text_encoder, transformer",
245
- value="text_encoder, transformer"
246
- )
247
 
248
- clone_btn = gr.Button("Stealth Clone & Process", variant="primary")
249
- output = gr.Textbox(label="Status", lines=3)
250
-
251
- clone_btn.click(
252
- fn=stealth_clone_hf_repo,
253
- inputs=[
254
- hf_token_input, source_input, source_type_input, target_input,
255
- repo_type_input, precision_input, target_folders_input
256
- ],
257
- outputs=output
 
 
258
  )
259
 
260
  if __name__ == "__main__":
 
 
 
 
1
  import os
 
 
 
 
 
2
  import gc
3
+ import torch
4
+ import shutil
5
+ import gradio as gr
6
+ from huggingface_hub import HfApi, hf_hub_download
7
+ from safetensors.torch import load_file, save_file
8
 
9
+ SOURCE_REPO = "Tongyi-MAI/Z-Image-Turbo"
10
+ TARGET_REPO = "rootlocalghost/Z-Image-Turbo-FP8"
11
+ TEMP_DIR = "temp_processing_dir"
12
 
13
+ def convert_and_upload(token):
14
+ if not token:
15
+ yield " Error: Please provide a valid Hugging Face Write Token."
16
+ return
17
 
18
+ api = HfApi(token=token)
19
+ yield f"🔄 Connecting to Hugging Face and verifying target repo: {TARGET_REPO}..."
 
 
 
20
 
21
+ # Ensure the target repo exists, create it if it doesn't
22
+ try:
23
+ api.create_repo(repo_id=TARGET_REPO, exist_ok=True, private=False)
24
+ except Exception as e:
25
+ yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
26
+ return
 
 
 
 
 
27
 
28
+ yield "📋 Fetching file list from the source repository..."
29
+ try:
30
+ files = api.list_repo_files(SOURCE_REPO)
31
+ except Exception as e:
32
+ yield f"❌ Error fetching files: {str(e)}"
33
+ return
 
34
 
35
+ # Create a temporary directory for safe local processing
36
+ os.makedirs(TEMP_DIR, exist_ok=True)
 
 
 
 
 
 
37
 
38
+ for file in files:
39
+ yield f"⏳ Processing {file}..."
 
 
 
 
40
 
41
+ try:
42
+ # Download file locally without using the central symlink cache
43
+ # This is critical to prevent the 50GB Space disk from filling up
44
+ local_path = hf_hub_download(
45
+ repo_id=SOURCE_REPO,
46
+ filename=file,
47
+ local_dir=TEMP_DIR,
48
+ local_dir_use_symlinks=False
49
+ )
50
+
51
+ # Check if it's a safetensor file inside the target directories
52
+ if file.endswith(".safetensors") and ("text_encoder/" in file or "transformer/" in file):
53
+ yield f"🧠 Quantizing {file} to FP8 (This may take a minute)..."
54
+
55
+ # Load tensors into RAM
56
+ tensors = load_file(local_path)
57
 
58
+ # Cast all floating point tensors to FP8
59
+ keys = list(tensors.keys())
60
+ for k in keys:
61
+ if tensors[k].is_floating_point():
62
+ tensors[k] = tensors[k].to(torch.float8_e4m3fn)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ # Save the quantized tensors to a new temp file
65
+ converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
66
+ save_file(tensors, converted_path)
67
+
68
+ # Wipe the tensors from RAM immediately to stay under the 16GB limit
69
+ del tensors
70
+ gc.collect()
 
71
 
72
+ yield f"☁️ Uploading FP8 version of {file}..."
73
+ api.upload_file(
74
+ path_or_fileobj=converted_path,
75
+ path_in_repo=file,
76
+ repo_id=TARGET_REPO,
77
+ commit_message=f"Upload FP8 quantized {file}"
 
 
 
 
 
 
78
  )
 
 
79
 
80
+ # Clean up the converted file
81
+ os.remove(converted_path)
82
+
83
+ else:
84
+ yield f"☁️ Copying {file} as-is..."
85
+ api.upload_file(
86
+ path_or_fileobj=local_path,
87
+ path_in_repo=file,
88
+ repo_id=TARGET_REPO,
89
+ commit_message=f"Copy {file} from original repo"
90
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ # Delete the downloaded original file to free up disk space
93
+ if os.path.exists(local_path):
94
+ os.remove(local_path)
95
+
96
+ # Final sweep of memory before the next file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  gc.collect()
98
 
99
+ except Exception as e:
100
+ yield f"⚠️ Error processing {file}: {str(e)}\nSkipping to next file..."
 
 
 
 
101
 
102
+ # Clean up the processing directory
103
+ if os.path.exists(TEMP_DIR):
104
+ shutil.rmtree(TEMP_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ yield " All files processed and successfully uploaded to your repository!"
 
 
107
 
108
+ # Build the Gradio Web Interface
109
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
110
+ gr.Markdown("# 🚀 Z-Image-Turbo FP8 Quantizer & Uploader")
111
  gr.Markdown(
112
+ f"This tool sequentially downloads files from `{SOURCE_REPO}`, quantizes the **text_encoder** and **transformer** "
113
+ f"`.safetensors` files to FP8 (`float8_e4m3fn`), and uploads everything to `{TARGET_REPO}`.\n\n"
114
+ "**Note:** Because we are using a free Space (2 vCPUs, 16GB RAM), this script is designed to process one file at a time "
115
+ "and aggressively clear memory/disk caches. It will take some time, but it won't crash."
116
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  with gr.Row():
119
+ with gr.Column(scale=2):
120
+ hf_token = gr.Textbox(
121
+ label="Hugging Face Token (Needs Write Access)",
122
+ type="password",
123
+ placeholder="hf_..."
124
+ )
125
+ start_btn = gr.Button("Start Quantization & Upload", variant="primary")
 
 
 
126
 
127
+ with gr.Column(scale=3):
128
+ output_log = gr.Textbox(
129
+ label="Operation Logs",
130
+ lines=15,
131
+ interactive=False,
132
+ max_lines=20
133
+ )
134
+
135
+ start_btn.click(
136
+ fn=convert_and_upload,
137
+ inputs=[hf_token],
138
+ outputs=[output_log]
139
  )
140
 
141
  if __name__ == "__main__":