rootlocalghost committed on
Commit
281d59c
·
verified ·
1 Parent(s): 98c9be7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -40
app.py CHANGED
@@ -6,78 +6,84 @@ import gradio as gr
6
  from huggingface_hub import HfApi, hf_hub_download
7
  from safetensors.torch import load_file, save_file
8
 
9
- SOURCE_REPO = "Tongyi-MAI/Z-Image-Turbo"
10
- TARGET_REPO = "rootlocalghost/Z-Image-Turbo-FP8"
11
  TEMP_DIR = "temp_processing_dir"
12
 
13
- def convert_and_upload(token):
14
  if not token:
15
  yield "❌ Error: Please provide a valid Hugging Face Write Token."
16
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  api = HfApi(token=token)
19
- yield f"🔄 Connecting to Hugging Face and verifying target repo: {TARGET_REPO}..."
20
 
21
- # Ensure the target repo exists, create it if it doesn't
22
  try:
23
- api.create_repo(repo_id=TARGET_REPO, exist_ok=True, private=False)
24
  except Exception as e:
25
  yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
26
  return
27
 
28
- yield "📋 Fetching file list from the source repository..."
29
  try:
30
- files = api.list_repo_files(SOURCE_REPO)
31
  except Exception as e:
32
  yield f"❌ Error fetching files: {str(e)}"
33
  return
34
 
35
- # Create a temporary directory for safe local processing
36
  os.makedirs(TEMP_DIR, exist_ok=True)
37
 
38
  for file in files:
39
  yield f"⏳ Processing {file}..."
40
 
41
  try:
42
- # Download file locally without using the central symlink cache
43
- # This is critical to prevent the 50GB Space disk from filling up
44
  local_path = hf_hub_download(
45
- repo_id=SOURCE_REPO,
46
  filename=file,
47
  local_dir=TEMP_DIR,
48
  local_dir_use_symlinks=False
49
  )
50
 
51
- # Check if it's a safetensor file inside the target directories
52
  if file.endswith(".safetensors") and ("text_encoder/" in file or "transformer/" in file):
53
- yield f"🧠 Quantizing {file} to FP8 (This may take a minute)..."
54
 
55
- # Load tensors into RAM
56
  tensors = load_file(local_path)
57
 
58
- # Cast all floating point tensors to FP8
59
- keys = list(tensors.keys())
60
- for k in keys:
61
- if tensors[k].is_floating_point():
62
- tensors[k] = tensors[k].to(torch.float8_e4m3fn)
 
63
 
64
- # Save the quantized tensors to a new temp file
65
  converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
66
  save_file(tensors, converted_path)
67
 
68
- # Wipe the tensors from RAM immediately to stay under the 16GB limit
69
  del tensors
70
  gc.collect()
71
 
72
- yield f"☁️ Uploading FP8 version of {file}..."
73
  api.upload_file(
74
  path_or_fileobj=converted_path,
75
  path_in_repo=file,
76
- repo_id=TARGET_REPO,
77
- commit_message=f"Upload FP8 quantized {file}"
78
  )
79
 
80
- # Clean up the converted file
81
  os.remove(converted_path)
82
 
83
  else:
@@ -85,40 +91,55 @@ def convert_and_upload(token):
85
  api.upload_file(
86
  path_or_fileobj=local_path,
87
  path_in_repo=file,
88
- repo_id=TARGET_REPO,
89
  commit_message=f"Copy {file} from original repo"
90
  )
91
 
92
- # Delete the downloaded original file to free up disk space
93
  if os.path.exists(local_path):
94
  os.remove(local_path)
95
 
96
- # Final sweep of memory before the next file
97
  gc.collect()
98
 
99
  except Exception as e:
100
  yield f"⚠️ Error processing {file}: {str(e)}\nSkipping to next file..."
101
 
102
- # Clean up the processing directory
103
  if os.path.exists(TEMP_DIR):
104
  shutil.rmtree(TEMP_DIR)
105
 
106
- yield "✅ All files processed and successfully uploaded to your repository!"
107
 
108
- # Build the Gradio Web Interface
 
 
 
 
 
109
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
110
- gr.Markdown("# 🚀 Z-Image-Turbo FP8 Quantizer & Uploader")
111
  gr.Markdown(
112
- f"This tool sequentially downloads files from `{SOURCE_REPO}`, quantizes the **text_encoder** and **transformer** "
113
- f"`.safetensors` files to FP8 (`float8_e4m3fn`), and uploads everything to `{TARGET_REPO}`.\n\n"
114
- "**Note:** Because we are using a free Space (2 vCPUs, 16GB RAM), this script is designed to process one file at a time "
115
- "and aggressively clear memory/disk caches. It will take some time, but it won't crash."
116
  )
117
 
118
  with gr.Row():
119
  with gr.Column(scale=2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  hf_token = gr.Textbox(
121
- label="Hugging Face Token (Needs Write Access)",
122
  type="password",
123
  placeholder="hf_..."
124
  )
@@ -127,14 +148,18 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
127
  with gr.Column(scale=3):
128
  output_log = gr.Textbox(
129
  label="Operation Logs",
130
- lines=15,
131
  interactive=False,
132
  max_lines=20
133
  )
134
 
 
 
 
 
135
  start_btn.click(
136
  fn=convert_and_upload,
137
- inputs=[hf_token],
138
  outputs=[output_log]
139
  )
140
 
 
6
  from huggingface_hub import HfApi, hf_hub_download
7
  from safetensors.torch import load_file, save_file
8
 
 
 
9
  TEMP_DIR = "temp_processing_dir"
10
 
11
+ def convert_and_upload(token, source_repo, target_repo, precision):
12
  if not token:
13
  yield "❌ Error: Please provide a valid Hugging Face Write Token."
14
  return
15
+ if not target_repo.strip():
16
+ yield "❌ Error: Please specify a Target Repository."
17
+ return
18
+
19
+ # Map precision string to PyTorch dtype
20
+ if precision == "FP8":
21
+ target_dtype = torch.float8_e4m3fn
22
+ elif precision == "FP16":
23
+ target_dtype = torch.float16
24
+ elif precision == "BF16":
25
+ target_dtype = torch.bfloat16
26
+ else:
27
+ target_dtype = None
28
 
29
  api = HfApi(token=token)
30
+ yield f"🔄 Connecting to Hugging Face and verifying target repo: {target_repo}..."
31
 
 
32
  try:
33
+ api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
34
  except Exception as e:
35
  yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
36
  return
37
 
38
+ yield f"📋 Fetching file list from {source_repo}..."
39
  try:
40
+ files = api.list_repo_files(source_repo)
41
  except Exception as e:
42
  yield f"❌ Error fetching files: {str(e)}"
43
  return
44
 
 
45
  os.makedirs(TEMP_DIR, exist_ok=True)
46
 
47
  for file in files:
48
  yield f"⏳ Processing {file}..."
49
 
50
  try:
51
+ # Download file locally, bypassing symlink cache to save space
 
52
  local_path = hf_hub_download(
53
+ repo_id=source_repo,
54
  filename=file,
55
  local_dir=TEMP_DIR,
56
  local_dir_use_symlinks=False
57
  )
58
 
59
+ # Check if it's a target safetensor file
60
  if file.endswith(".safetensors") and ("text_encoder/" in file or "transformer/" in file):
61
+ yield f"🧠 Quantizing {file} to {precision}..."
62
 
 
63
  tensors = load_file(local_path)
64
 
65
+ # Cast floating point tensors to the selected precision
66
+ if target_dtype:
67
+ keys = list(tensors.keys())
68
+ for k in keys:
69
+ if tensors[k].is_floating_point():
70
+ tensors[k] = tensors[k].to(target_dtype)
71
 
 
72
  converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
73
  save_file(tensors, converted_path)
74
 
75
+ # Wipe tensors from RAM
76
  del tensors
77
  gc.collect()
78
 
79
+ yield f"☁️ Uploading {precision} version of {file}..."
80
  api.upload_file(
81
  path_or_fileobj=converted_path,
82
  path_in_repo=file,
83
+ repo_id=target_repo,
84
+ commit_message=f"Upload {precision} quantized {file}"
85
  )
86
 
 
87
  os.remove(converted_path)
88
 
89
  else:
 
91
  api.upload_file(
92
  path_or_fileobj=local_path,
93
  path_in_repo=file,
94
+ repo_id=target_repo,
95
  commit_message=f"Copy {file} from original repo"
96
  )
97
 
98
+ # Cleanup original downloaded file
99
  if os.path.exists(local_path):
100
  os.remove(local_path)
101
 
 
102
  gc.collect()
103
 
104
  except Exception as e:
105
  yield f"⚠️ Error processing {file}: {str(e)}\nSkipping to next file..."
106
 
 
107
  if os.path.exists(TEMP_DIR):
108
  shutil.rmtree(TEMP_DIR)
109
 
110
+ yield f"✅ All files processed and successfully uploaded to {target_repo}!"
111
 
112
+ # Dynamic UI Update for Target Repo Name
113
def update_target_repo(source, precision, owner="rootlocalghost"):
    """Suggest a target repository id for the selected source and precision.

    Args:
        source: Source repository id, e.g. ``"Tongyi-MAI/Z-Image-Turbo"``.
        precision: Precision label used as the name suffix, e.g. ``"FP8"``.
        owner: Namespace that prefixes the suggested repository name.

    Returns:
        ``"<owner>/<model_name>-<precision>"``, where the model name is
        ``"Z-Image-Turbo"`` when ``source`` contains ``"Turbo"`` and
        ``"Z-Image-Base"`` otherwise. Returns an empty string when either
        ``source`` or ``precision`` is empty or ``None``.
    """
    # A missing selection would otherwise raise TypeError on the `in` test
    # (source is None) or produce a malformed id such as "...-None".
    if not source or not precision:
        return ""

    model_name = "Z-Image-Turbo" if "Turbo" in source else "Z-Image-Base"
    return f"{owner}/{model_name}-{precision}"
116
+
117
+ # Build the Gradio UI
118
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
119
+ gr.Markdown("# 🚀 Z-Image Quantizer & Uploader")
120
  gr.Markdown(
121
+ "Select your source model and desired precision. The tool will sequentially download, quantize the "
122
+ "**text_encoder** and **transformer** files, and upload everything to your target repository while keeping memory usage under 16GB."
 
 
123
  )
124
 
125
  with gr.Row():
126
  with gr.Column(scale=2):
127
+ source_repo = gr.Dropdown(
128
+ choices=["Tongyi-MAI/Z-Image", "Tongyi-MAI/Z-Image-Turbo"],
129
+ value="Tongyi-MAI/Z-Image-Turbo",
130
+ label="Source Repository"
131
+ )
132
+ precision = gr.Dropdown(
133
+ choices=["FP8", "FP16", "BF16"],
134
+ value="FP8",
135
+ label="Quantization Precision"
136
+ )
137
+ target_repo = gr.Textbox(
138
+ label="Target Repository",
139
+ value="rootlocalghost/Z-Image-Turbo-FP8"
140
+ )
141
  hf_token = gr.Textbox(
142
+ label="Hugging Face Token (Write Access)",
143
  type="password",
144
  placeholder="hf_..."
145
  )
 
148
  with gr.Column(scale=3):
149
  output_log = gr.Textbox(
150
  label="Operation Logs",
151
+ lines=17,
152
  interactive=False,
153
  max_lines=20
154
  )
155
 
156
+ # Automatically update the target repo name when inputs change
157
+ source_repo.change(fn=update_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
158
+ precision.change(fn=update_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
159
+
160
  start_btn.click(
161
  fn=convert_and_upload,
162
+ inputs=[hf_token, source_repo, target_repo, precision],
163
  outputs=[output_log]
164
  )
165