liuhuijie commited on
Commit
424c702
·
1 Parent(s): 399083d
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: CoTyle
3
  emoji: 🎨
4
- colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.49.1
 
1
  ---
2
  title: CoTyle
3
  emoji: 🎨
4
+ colorFrom: #c8c8c8
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.49.1
app-ori.py DELETED
@@ -1,310 +0,0 @@
1
- print('v4')
2
- import os
3
- import torch
4
- from PIL import Image
5
- from io import BytesIO
6
- import json
7
- from huggingface_hub import login, hf_hub_download
8
- import spaces
9
- import gradio as gr
10
- token=os.environ.get("HF_TOKEN")
11
-
12
- login(token=os.environ.get("HF_TOKEN"))
13
- REPO_ID = "Kwai-Kolors/cotyle"
14
- # Use GPU if available
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- weight_type = torch.bfloat16 if device == "cuda" else torch.float32
17
-
18
- # Predefined suggested prompts (already in English)
19
- SUGGESTED_PROMPTS = [
20
- "An artist sits outdoors, engrossed in their work, brush in hand, capturing the scene with focused intensity. On the canvas, trees and buildings blend seamlessly with the real-world surroundings. Symbols from different cultures, along with animals, plants, and abstract lines, float around them. As the brush touches the canvas, the paint transforms into points of light that scatter, while sheets of paper and flower petals flutter in the air, creating a sense of movement. The atmosphere is a high-detail fusion of art and reality.",
21
- "Seagulls soar along the seaside under the setting sun, as a couple in wedding attire holds hands.",
22
- "A cute, chubby werewolf holds a balloon and candy, looking adorably mischievous. The background features a full moon on a night sky.",
23
- "A classical beauty, dressed in a dreamy, light pink flowing gown with wide sleeves, adorned with countless tiny wind crystals.",
24
- "The train sped swiftly across a large bridge.",
25
- "In front of the door stands an apple tree with two apples glistening with dewdrops. A beautiful little bird with vibrant feathers perches on a branch, displaying intricate textures and clear details.",
26
- ]
27
-
28
- CUSTOM_OPTION = "✍️ Enter custom prompt..."
29
-
30
- # Lazy load models to avoid slow startup
31
- def load_models():
32
- global pipeline, style_generator, unitok, processor, code_freq
33
- if 'pipeline' in globals():
34
- return # Already loaded
35
-
36
- from models.pipe import CoTylePipeline
37
- from models.vlm_unitok import UniTok
38
- from models.model import StyleGenerator, Qwen2_5_VLForConditionalGeneration_Quant, Qwen2_5_VL_Quant
39
- from models.utils import set_seed, patched_from_model_config
40
- from transformers import Qwen2VLProcessor
41
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
42
- from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
43
- from transformers.generation.configuration_utils import GenerationConfig
44
- _original_from_model_config = GenerationConfig.from_model_config
45
- GenerationConfig.from_model_config = classmethod(patched_from_model_config)
46
-
47
- model_path = "Kwai-Kolors/cotyle"
48
- unitok_config = {
49
- 'unitok_embed_dim': 3584,
50
- 'unitok_vocab_width': 64,
51
- 'unitok_vocab_size': 1024,
52
- 'unitok_e_temp': 0.01,
53
- 'unitok_num_codebooks': 1,
54
- 'unitok_le': 0.0
55
- }
56
-
57
- # Load Style Generator
58
- style_generator_path = hf_hub_download(
59
- repo_id=model_path,
60
- filename='prior',
61
- token=token,
62
- )
63
-
64
- from transformers import AutoConfig
65
- config = AutoConfig.from_pretrained(f"{style_generator_path}/config.json")
66
- style_generator = StyleGenerator._from_config(config)
67
- state_dict = torch.load(f"{style_generator_path}/prior.pth", map_location='cpu')
68
- style_generator.load_state_dict(state_dict)
69
- style_generator.to(device, dtype=weight_type)
70
-
71
- # Load UniTok
72
-
73
- codebook_path = hf_hub_download(
74
- repo_id=model_path,
75
- filename='codebook',
76
- token=token,
77
- )
78
-
79
- unitok = UniTok(unitok_config)
80
- unitok_state_dict = torch.load(f"{codebook_path}/model.pth", map_location='cpu')
81
- unitok.load_state_dict(unitok_state_dict)
82
- unitok.to(device, dtype=weight_type)
83
-
84
- # Load Pipeline (without text encoder initially)
85
- pipeline = CoTylePipeline.from_pretrained(
86
- model_path,
87
- torch_dtype=weight_type,
88
- text_encoder=None,
89
- processor=None,
90
- safety_checker=None,
91
- requires_safety_checker=False
92
- )
93
-
94
- # Load Qwen2.5-VL Text-Visual Encoder
95
- from transformers import Qwen2_5_VLForConditionalGeneration
96
- qwen_text_visual_encoder = Qwen2_5_VLForConditionalGeneration_Quant.from_pretrained(
97
- model_path,
98
- subfolder='text_encoder',
99
- ).to(device, dtype=weight_type)
100
- qwen_text_visual_encoder = Qwen2_5_VL_Quant(unitok, qwen_text_visual_encoder)
101
- qwen_text_visual_encoder.to(device, dtype=weight_type)
102
-
103
- pipeline.text_encoder = qwen_text_visual_encoder
104
-
105
- # Load Processor
106
- processor = Qwen2VLProcessor.from_pretrained(
107
- model_path,
108
- subfolder='processor',
109
- min_pixels=64 * 28 * 28,
110
- max_pixels=256 * 28 * 28
111
- )
112
- pipeline.processor = processor
113
-
114
- pipeline.to(device, dtype=weight_type)
115
- pipeline.set_progress_bar_config(disable=True)
116
-
117
- # Load code frequency
118
- with open(f'{model_path}/freq.json', 'r') as f:
119
- code_freq = json.load(f)
120
-
121
- print("✅ All models loaded successfully!")
122
-
123
-
124
- def get_final_prompt(dropdown_val, text_val):
125
- if dropdown_val == CUSTOM_OPTION:
126
- return text_val.strip()
127
- return dropdown_val.strip() if dropdown_val else ""
128
-
129
- @spaces.GPU
130
- def generate_images(style_code: int, seed: int, num_prompts: int, *args):
131
- load_models()
132
- from models.utils import set_seed
133
-
134
- prompts = []
135
- for i in range(num_prompts):
136
- dropdown_val = args[i * 2] if i * 2 < len(args) else ""
137
- text_val = args[i * 2 + 1] if i * 2 + 1 < len(args) else ""
138
- final_prompt = get_final_prompt(dropdown_val, text_val)
139
- if final_prompt:
140
- prompts.append(final_prompt)
141
-
142
- if not prompts:
143
- raise gr.Error("Please enter at least one valid prompt!")
144
-
145
- # Step 1: Generate style codebook tokens
146
- set_seed(style_code)
147
- style_generator_inputs = {
148
- 'input_ids': torch.randint(low=0, high=1024, size=(1, 1)).to(device),
149
- 'attention_mask': torch.ones((1, 1)).to(device),
150
- }
151
-
152
- with torch.no_grad():
153
- generated_ids = style_generator.generate(
154
- **style_generator_inputs,
155
- max_new_tokens=195,
156
- temperature=1.0,
157
- top_k=200,
158
- top_p=0.95,
159
- do_sample=True,
160
- repetition_penalty=50.0,
161
- code_freq=code_freq,
162
- code_freq_threshold=90000,
163
- k=0.0001,
164
- )
165
-
166
- # Step 2: Generate images
167
- placeholder_image = Image.new("RGB", (392, 392), (0, 0, 0))
168
- results = []
169
-
170
- for i, prompt in enumerate(prompts):
171
- set_seed(seed)
172
-
173
- inputs = {
174
- "image": [placeholder_image],
175
- "prompt": prompt,
176
- "generator": torch.Generator(device=device).manual_seed(seed),
177
- "true_cfg_scale": 6.0,
178
- "negative_prompt": "ugly, monster, grotesque, deformed, mutated, anatomically incorrect, distorted face, disfigured limbs, unnatural posture, blurry, low quality",
179
- "num_inference_steps": 40,
180
- "guidance_scale": 1.0,
181
- "num_images_per_prompt": 1,
182
- "codebook_id": generated_ids,
183
- }
184
-
185
- with torch.inference_mode():
186
- output = pipeline(**inputs)
187
- results.append(output.images[0])
188
-
189
- return results
190
-
191
-
192
- # Gradio Interface
193
- with gr.Blocks(theme=gr.themes.Soft(), css="""
194
- .prompt-hint {
195
- font-size: 0.9em;
196
- color: #666;
197
- margin-top: -8px;
198
- margin-bottom: 12px;
199
- }
200
- """) as demo:
201
- gr.Markdown(
202
- """
203
- <div align="center">
204
-
205
- ## 🎨 CoTyle: Unlocking Code-to-Style Image Generation with Discrete Style Space
206
-
207
- Enter a `style code` and multiple prompts to generate stylized images.
208
-
209
- <p align="center">
210
- <a href="xxx"><img alt="Project Page" src="https://img.shields.io/badge/Project%20Page-Homepage-yellow"></a>
211
- <a href="xxx"><img alt="GitHub" src="https://img.shields.io/badge/GitHub-Code-f8f0f0.svg"></a>
212
- <a href="xxx"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-Paper-da282a.svg"></a>
213
- <a href="xxxK"><img alt="Hugging Face Demo" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-fd8b02"></a>
214
- </p>
215
-
216
- </div>
217
- """
218
- )
219
-
220
- with gr.Row():
221
- with gr.Column():
222
- style_code = gr.Number(label="Style Code", value=1234567, step=1)
223
-
224
- num_prompts = gr.Slider(
225
- minimum=1,
226
- maximum=6,
227
- value=4,
228
- step=1,
229
- label="Number of Prompts (You can choose how many prompt images to generate at once)"
230
- )
231
-
232
- all_dropdowns = []
233
- all_texts = []
234
- prompt_rows = []
235
-
236
- with gr.Column():
237
- for i in range(6):
238
- with gr.Row(visible=(i < 4)) as row:
239
- choices = [""] + SUGGESTED_PROMPTS + [CUSTOM_OPTION]
240
- dropdown = gr.Dropdown(
241
- choices=choices,
242
- value=SUGGESTED_PROMPTS[i] if i < len(SUGGESTED_PROMPTS) else "",
243
- label=f"Prompt {i+1}",
244
- interactive=True
245
- )
246
- text = gr.Textbox(
247
- label=f"Custom Prompt {i+1}",
248
- lines=2,
249
- visible=False
250
- )
251
-
252
- def update_text_visibility(dropdown_val):
253
- return gr.update(visible=(dropdown_val == CUSTOM_OPTION))
254
-
255
- dropdown.change(
256
- fn=update_text_visibility,
257
- inputs=dropdown,
258
- outputs=text
259
- )
260
-
261
- all_dropdowns.append(dropdown)
262
- all_texts.append(text)
263
- prompt_rows.append(row)
264
-
265
- seed = gr.Number(label="Seed", value=42, step=1)
266
- run_btn = gr.Button("✨ Generate All Images", variant="primary", size="lg")
267
-
268
- with gr.Column():
269
- gallery = gr.Gallery(
270
- label="Generated Results",
271
- show_label=True,
272
- columns=2,
273
- rows=2,
274
- object_fit="contain",
275
- height="auto"
276
- )
277
-
278
- # Update visibility of prompt rows
279
- def update_rows_visibility(n):
280
- return [gr.update(visible=(i < n)) for i in range(6)]
281
-
282
- num_prompts.change(
283
- fn=update_rows_visibility,
284
- inputs=num_prompts,
285
- outputs=prompt_rows
286
- )
287
-
288
- # Build input list: [style_code, seed, num_prompts, d1, t1, d2, t2, ...]
289
- input_components = [style_code, seed, num_prompts]
290
- for d, t in zip(all_dropdowns, all_texts):
291
- input_components.extend([d, t])
292
-
293
- run_btn.click(
294
- fn=generate_images,
295
- inputs=input_components,
296
- outputs=gallery
297
- )
298
-
299
- gr.Markdown("""
300
- > **Tips**:
301
- > - Adjust the **Number of Prompts** slider to add or remove input rows.
302
- > - Select **"✍️ Enter custom prompt..."** to type your own prompts.
303
- > - All images share the same `style_code`.
304
- """)
305
-
306
- # Launch
307
- if __name__ == "__main__":
308
- import sys
309
- sys.path.append(".")
310
- demo.queue.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -10,7 +10,8 @@ from PIL import Image
10
  from huggingface_hub import snapshot_download
11
  import gc
12
  import psutil
13
- import os
 
14
  try:
15
  import pynvml
16
  pynvml.nvmlInit()
@@ -24,7 +25,7 @@ try:
24
  except Exception as e:
25
  print("无法获取 GPU 信息:", e)
26
 
27
- REPO_ID = "Kwai-Kolors/cotyle"
28
  HF_TOKEN = os.getenv("HF_TOKEN")
29
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -38,7 +39,71 @@ SUGGESTED_PROMPTS = [
38
  "The train sped swiftly across a large bridge.",
39
  "In front of the door stands an apple tree with two apples glistening with dewdrops. A beautiful little bird with vibrant feathers perches on a branch, displaying intricate textures and clear details.",
40
  ]
41
- CUSTOM_OPTION = "✍️ Enter custom prompt..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def check_memory_usage(tag):
44
  process = psutil.Process(os.getpid())
@@ -58,18 +123,17 @@ def load_models():
58
  repo_id=REPO_ID,
59
  token=HF_TOKEN,
60
  allow_patterns=[
61
- "prior/**", # 递归下载 prior/ 目录下所有文件
62
- "codebook/**", # 递归下载 codebook/ 目录下所有文件
63
  "tokenizer/**",
64
- "processor/**", # 递归下载 processor/ 目录下所有文件
65
- "text_encoder/**", # 递归下载 text_encoder/ 目录下所有文件
66
- "freq.json", # 明确指定单个文件(可选,也可用 *.json)
67
  "processor/**",
 
 
68
  "transformer/**",
69
  "vae/**",
70
- "*.json", # 所有 .json 文件(包括 config.json 等)
71
- "*.pth", # 所有 .pth 文件
72
- "*.safetensors", # 所有 .safetensors 文件
73
  ],
74
  resume_download=True,
75
  )
@@ -110,7 +174,6 @@ def load_models():
110
  processor=None,
111
  safety_checker=None,
112
  requires_safety_checker=False,
113
-
114
  )
115
  check_memory_usage('before qwen')
116
  qwen_text_visual_encoder = Qwen2_5_VLForConditionalGeneration_Quant.from_pretrained(
@@ -135,11 +198,6 @@ def load_models():
135
  code_freq = json.load(f)
136
  print('='*10, " All models loaded successfully!")
137
 
138
- def get_final_prompt(dropdown_val, text_val):
139
- if dropdown_val == CUSTOM_OPTION:
140
- return (text_val or "").strip()
141
- return (dropdown_val or "").strip()
142
-
143
  @spaces.GPU
144
  def generate_images(style_code, seed, num_prompts, *args):
145
  try:
@@ -154,17 +212,22 @@ def generate_images(style_code, seed, num_prompts, *args):
154
  num_prompts = int(num_prompts)
155
  except Exception:
156
  num_prompts = 1
 
157
  load_models()
158
  from models.utils import set_seed
 
159
  prompts = []
160
  for i in range(num_prompts):
161
- dropdown_val = args[i * 2] if i * 2 < len(args) else ""
162
- text_val = args[i * 2 + 1] if i * 2 + 1 < len(args) else ""
163
- final_prompt = get_final_prompt(dropdown_val, text_val)
164
- if final_prompt:
165
- prompts.append(final_prompt)
 
 
166
  if not prompts:
167
  raise gr.Error("Please enter at least one valid prompt!")
 
168
  set_seed(style_code)
169
  style_generator_inputs = {
170
  "input_ids": torch.randint(low=0, high=1024, size=(1, 1)).to(device),
@@ -187,6 +250,7 @@ def generate_images(style_code, seed, num_prompts, *args):
187
  placeholder_image = Image.new("RGB", (392, 392), (0, 0, 0))
188
  results = []
189
  for i, prompt in enumerate(prompts):
 
190
  set_seed(seed)
191
  inputs = {
192
  "image": [placeholder_image],
@@ -199,22 +263,57 @@ def generate_images(style_code, seed, num_prompts, *args):
199
  "num_images_per_prompt": 1,
200
  "codebook_id": generated_ids,
201
  }
202
- print('='*10, 'before infer')
203
  with torch.inference_mode():
204
  output = pipeline(**inputs)
205
- print('='*10, 'after inference')
206
  results.append(output.images[0])
207
- # output.images[0].save('tmp.png')
208
 
209
  if torch.cuda.is_available():
210
  torch.cuda.empty_cache()
211
  gc.collect()
212
  del output
213
 
 
214
  return results
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  with gr.Blocks(
217
- theme=gr.themes.Soft(),
 
 
218
  css="""
219
  .prompt-hint {
220
  font-size: 0.9em;
@@ -222,69 +321,120 @@ with gr.Blocks(
222
  margin-top: -8px;
223
  margin-bottom: 12px;
224
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  """
226
  ) as demo:
227
- gr.Markdown(
228
  """
229
- <div align="center">
230
-
231
- ## 🎨 CoTyle: Unlocking Code-to-Style Image Generation with Discrete Style Space
232
-
233
  <div style="display: flex; justify-content: center; gap: 10px; flex-wrap: wrap; margin: 15px 0;">
234
  <a href="xxx"><img alt="Project Page" src="https://img.shields.io/badge/Project%20Page-Homepage-yellow"></a>
235
  <a href="xxx"><img alt="GitHub" src="https://img.shields.io/badge/GitHub-Code-f8f0f0.svg"></a>
236
  <a href="xxx"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-Paper-da282a.svg"></a>
237
- <a href="xxxK"><img alt="Hugging Face Demo" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-fd8b02"></a>
238
  </div>
239
-
240
  </div>
241
  """
242
  )
243
 
244
  with gr.Row():
245
  with gr.Column():
246
- style_code = gr.Number(label="Style Code", value=1234567, step=1)
 
 
 
 
 
 
 
 
247
 
248
  num_prompts = gr.Slider(
249
  minimum=1,
250
  maximum=6,
251
- value=1,
252
  step=1,
253
  label="Number of Prompts (You can choose how many prompt images to generate at once)",
254
  )
255
 
256
- all_dropdowns = []
257
- all_texts = []
258
-
259
- with gr.Column():
260
- for i in range(6):
261
- choices = [""] + SUGGESTED_PROMPTS + [CUSTOM_OPTION]
262
- dropdown = gr.Dropdown(
263
- choices=choices,
264
- value=SUGGESTED_PROMPTS[i] if i < len(SUGGESTED_PROMPTS) else "",
265
- label=f"Prompt {i+1}",
266
- interactive=True,
267
- visible=(i < 1),
268
- )
269
- text = gr.Textbox(
270
- label=f"Custom Prompt {i+1}",
271
- lines=2,
272
- visible=False,
273
- )
274
 
275
- def update_text_visibility(dropdown_val):
276
- return gr.update(visible=(dropdown_val == CUSTOM_OPTION))
277
 
278
- dropdown.change(
279
- fn=update_text_visibility,
280
- inputs=dropdown,
281
- outputs=text,
282
- )
283
-
284
- all_dropdowns.append(dropdown)
285
- all_texts.append(text)
286
 
287
- seed = gr.Number(label="Seed", value=42, step=1)
288
  run_btn = gr.Button("✨ Generate All Images", variant="primary", size="lg")
289
 
290
  with gr.Column():
@@ -292,28 +442,32 @@ with gr.Blocks(
292
  label="Generated Results",
293
  show_label=True,
294
  columns=2,
295
- rows=2,
296
  object_fit="contain",
297
- height="auto",
298
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
- def update_components_visibility(n):
301
- updates = []
302
- for i in range(6):
303
- updates.append(gr.update(visible=(i < n)))
304
- for i in range(6):
305
- updates.append(gr.update(visible=False))
306
- return updates
307
-
308
  num_prompts.change(
309
- fn=update_components_visibility,
310
  inputs=num_prompts,
311
- outputs=(all_dropdowns + all_texts),
 
312
  )
313
 
314
- input_components = [style_code, seed, num_prompts]
315
- for d, t in zip(all_dropdowns, all_texts):
316
- input_components.extend([d, t])
317
 
318
  run_btn.click(
319
  fn=generate_images,
@@ -321,16 +475,102 @@ with gr.Blocks(
321
  outputs=gallery,
322
  )
323
 
324
- gr.Markdown(
325
- """
326
- > <strong>Tips</strong>:
327
- > - Adjust the <strong>Number of Prompts</strong> slider to add or remove input rows.
328
- > - Select <strong>"✍️ Enter custom prompt..."</strong> to type your own prompts.
329
- > - All images share the same `style_code`.
330
- """
331
- )
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
  if __name__ == "__main__":
335
  load_models()
336
- demo.queue().launch(max_threads=1, share=True)
 
 
 
 
 
 
 
 
10
  from huggingface_hub import snapshot_download
11
  import gc
12
  import psutil
13
+ from functools import partial
14
+
15
  try:
16
  import pynvml
17
  pynvml.nvmlInit()
 
25
  except Exception as e:
26
  print("无法获取 GPU 信息:", e)
27
 
28
+ REPO_ID = "Kwai-Kolors/Kolors-CoTyle"
29
  HF_TOKEN = os.getenv("HF_TOKEN")
30
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
39
  "The train sped swiftly across a large bridge.",
40
  "In front of the door stands an apple tree with two apples glistening with dewdrops. A beautiful little bird with vibrant feathers perches on a branch, displaying intricate textures and clear details.",
41
  ]
42
+
43
+ # 预设模板配置
44
+ PRESET_TEMPLATES = [
45
+ {
46
+ "name": "--sref 1234567",
47
+ "image_path": "assets/1234567.jpg",
48
+ "style_code": 1234567,
49
+ "seed": 42,
50
+ "prompts": [
51
+ "An artist sits outdoors, engrossed in their work, brush in hand, capturing the scene with focused intensity. On the canvas, trees and buildings blend seamlessly with the real-world surroundings. Symbols from different cultures, along with animals, plants, and abstract lines, float around them. As the brush touches the canvas, the paint transforms into points of light that scatter, while sheets of paper and flower petals flutter in the air, creating a sense of movement. The atmosphere is a high-detail fusion of art and reality.",
52
+ "Seagulls soar along the seaside under the setting sun, as a couple in wedding attire holds hands.",
53
+ "A cute, chubby werewolf holds a balloon and candy, looking adorably mischievous. The background features a full moon on a night sky.",
54
+ "A classical beauty, dressed in a dreamy, light pink flowing gown with wide sleeves, adorned with countless tiny wind crystals.",
55
+ ]
56
+ },
57
+ {
58
+ "name": "--sref 666666666",
59
+ "image_path": "assets/666666666.jpg",
60
+ "style_code": 666666666,
61
+ "seed": 42,
62
+ "prompts": [
63
+ "A chubby, white, curly-furred baby lamb in anime style, with a pink nose and short mouth, stands on grass looking directly at the camera.",
64
+ "A boy with a backpack stands on a mountain peak, bathed in sunlight, with continuous mountain ranges in the background.",
65
+ "Aerial view: distant wind turbines, mountains, a river, heavy snowfall, and four or five people in orange work uniforms and white safety helmets marching in a line through the snow.",
66
+ "A beautiful Chinese woman in ancient red silk attire rides a white horse, holding a red tassel spear, facing an enemy army of thousands; ethereal clouds swirl around her, and behind her stand countless celestial soldiers clad in white armor; documentary photography style."
67
+ ]
68
+ },
69
+ {
70
+ "name": "--sref 886",
71
+ "image_path": "assets/886.jpg",
72
+ "style_code": 886,
73
+ "seed": 42,
74
+ "prompts": [
75
+ "A lovely crystal snake spirit, slender and nimble, wears an exquisite crystal crown atop its head. Its scales are translucent, shimmering like crystal, its eyes are bright and round, and its expression is lively. Its body coils naturally, its tail gracefully curved, its overall posture harmonious and beautiful.",
76
+ "Seagulls soar along the seaside under the setting sun, as a couple in wedding attire holds hands.",
77
+ "A cute, chubby werewolf holds a balloon and candy, looking adorably mischievous. The background features a full moon on a night sky.",
78
+ "The train sped swiftly across a large bridge."
79
+ ]
80
+ },
81
+ {
82
+ "name": "--sref 10241024",
83
+ "image_path": "assets/10241024.jpg",
84
+ "style_code": 10241024,
85
+ "seed": 42,
86
+ "prompts": [
87
+ "An elegant tabby cat steps gracefully through the doorway, its soft paws landing silently on the floor. Its amber eyes scan the surroundings with keen alertness, taking in every detail of the room.",
88
+ "Mickey Mouse appears in the 1920s gangster world, dressed in a long trench coat and a fedora, holding an old-fashioned revolver. The backdrop is a dimly lit Chicago alleyway, where shadows stretch across the cobblestones and the air is thick with the intrigue of the era.",
89
+ "A motorcycle speeds down the highway, the rider clad in black leather, with a biker girl seated behind him. The setting sun glints off the metallic fuel tank, while the rear wheel kicks up a trail of dust. In the background, the desolate road stretches endlessly towards the horizon, framed by the vast wilderness.",
90
+ "A classical beauty, dressed in a dreamy, light pink flowing gown with wide sleeves, adorned with countless tiny wind crystals."
91
+ ]
92
+ },
93
+ {
94
+ "name": "--sref 4396",
95
+ "image_path": "assets/4396.jpg",
96
+ "style_code": 4396,
97
+ "seed": 42,
98
+ "prompts": [
99
+ "A boy and a girl are walking along the lakeside, surrounded by vibrant flowers, lush grass, and verdant trees.",
100
+ "A hazy full moon hangs high in the night sky, with the bustling streets of an ancient town below, adorned with a variety of lanterns that are vibrant and bright.",
101
+ "A cartoon bear with a wide, round mouth and neatly arranged teeth, illustration, mascot, chubby.",
102
+ "A real-life depiction of a warrior goddess is strikingly beautiful, adorned in metallic armor. She has long legs and sports enormous wings, adding to her majestic presence. A crown sits atop her head, and she wields a weapon, poised in a dynamic battle stance."
103
+ ]
104
+ },
105
+ ]
106
+
107
 
108
  def check_memory_usage(tag):
109
  process = psutil.Process(os.getpid())
 
123
  repo_id=REPO_ID,
124
  token=HF_TOKEN,
125
  allow_patterns=[
126
+ "prior/**",
127
+ "codebook/**",
128
  "tokenizer/**",
 
 
 
129
  "processor/**",
130
+ "text_encoder/**",
131
+ "freq.json",
132
  "transformer/**",
133
  "vae/**",
134
+ "*.json",
135
+ "*.pth",
136
+ "*.safetensors",
137
  ],
138
  resume_download=True,
139
  )
 
174
  processor=None,
175
  safety_checker=None,
176
  requires_safety_checker=False,
 
177
  )
178
  check_memory_usage('before qwen')
179
  qwen_text_visual_encoder = Qwen2_5_VLForConditionalGeneration_Quant.from_pretrained(
 
198
  code_freq = json.load(f)
199
  print('='*10, " All models loaded successfully!")
200
 
 
 
 
 
 
201
  @spaces.GPU
202
  def generate_images(style_code, seed, num_prompts, *args):
203
  try:
 
212
  num_prompts = int(num_prompts)
213
  except Exception:
214
  num_prompts = 1
215
+
216
  load_models()
217
  from models.utils import set_seed
218
+
219
  prompts = []
220
  for i in range(num_prompts):
221
+ if i < len(args):
222
+ prompt_text = (args[i] or "").strip()
223
+ if prompt_text:
224
+ prompts.append(prompt_text)
225
+
226
+ print(f"收集到 {len(prompts)} 个有效 prompts")
227
+
228
  if not prompts:
229
  raise gr.Error("Please enter at least one valid prompt!")
230
+
231
  set_seed(style_code)
232
  style_generator_inputs = {
233
  "input_ids": torch.randint(low=0, high=1024, size=(1, 1)).to(device),
 
250
  placeholder_image = Image.new("RGB", (392, 392), (0, 0, 0))
251
  results = []
252
  for i, prompt in enumerate(prompts):
253
+ print(f"正在生成第 {i+1}/{len(prompts)} 张图片")
254
  set_seed(seed)
255
  inputs = {
256
  "image": [placeholder_image],
 
263
  "num_images_per_prompt": 1,
264
  "codebook_id": generated_ids,
265
  }
 
266
  with torch.inference_mode():
267
  output = pipeline(**inputs)
 
268
  results.append(output.images[0])
 
269
 
270
  if torch.cuda.is_available():
271
  torch.cuda.empty_cache()
272
  gc.collect()
273
  del output
274
 
275
+ print(f"成功生成 {len(results)} 张图片")
276
  return results
277
 
278
+ def load_preset_template(template_idx):
279
+ """加载预设模板并返回所有需要更新的组件值"""
280
+ template = PRESET_TEMPLATES[template_idx]
281
+
282
+ outputs = [
283
+ template["style_code"],
284
+ template["seed"],
285
+ 4,
286
+ ]
287
+
288
+ for i in range(4):
289
+ outputs.append(template["prompts"][i])
290
+
291
+ for i in range(2):
292
+ outputs.append("")
293
+
294
+ return tuple(outputs) # 返回 tuple 而不是 list,稍微快一点
295
+
296
+ def create_placeholder_image(text):
297
+ """创建占位符图片"""
298
+ return Image.new('RGB', (300, 200), color=(240, 240, 240))
299
+
300
+ # 使用 Blocks 的 js 参数来加速 UI 更新
301
+ custom_js = """
302
+ function() {
303
+ // 优化 Gradio 的更新性能
304
+ const style = document.createElement('style');
305
+ style.textContent = `
306
+ .gradio-container { transition: none !important; }
307
+ .gr-box { transition: none !important; }
308
+ `;
309
+ document.head.appendChild(style);
310
+ }
311
+ """
312
+
313
  with gr.Blocks(
314
+ # theme=gr.themes.midnight(),
315
+ theme = 'Taithrah/Minimal',
316
+ js=custom_js, # 添加自定义 JS 来禁用不必要的动画
317
  css="""
318
  .prompt-hint {
319
  font-size: 0.9em;
 
321
  margin-top: -8px;
322
  margin-bottom: 12px;
323
  }
324
+ .preset-container {
325
+ border: 2px solid #e0e0e0;
326
+ border-radius: 12px;
327
+ padding: 12px;
328
+ cursor: pointer;
329
+ transition: all 0.3s ease;
330
+ background: white;
331
+ height: 100%;
332
+ display: flex;
333
+ flex-direction: column;
334
+ max-width: 280px;
335
+ margin: 0 auto;
336
+ }
337
+ .preset-container:hover {
338
+ border-color: #2196F3;
339
+ box-shadow: 0 4px 12px rgba(33, 150, 243, 0.2);
340
+ transform: translateY(-2px);
341
+ }
342
+ .preset-image-container {
343
+ width: 100%;
344
+ height: 240px;
345
+ overflow: hidden;
346
+ border-radius: 8px;
347
+ margin-bottom: 1px;
348
+ background: white;
349
+ display: flex;
350
+ align-items: center;
351
+ justify-content: center;
352
+ }
353
+ .preset-image-container img {
354
+ width: 100%;
355
+ height: 100%;
356
+ object-fit: cover;
357
+ }
358
+ .preset-text {
359
+ text-align: center;
360
+ font-weight: bold;
361
+ font-size: 1.0em;
362
+ color: #333;
363
+ padding: 3px 0;
364
+ }
365
+ .preset-row {
366
+ margin-bottom: 10px;
367
+ justify-content: center;
368
+ gap: 15px;
369
+ }
370
+ .preset-section {
371
+ max-width: 1900px;
372
+ margin: 0 auto;
373
+ padding: 0 20px;
374
+ }
375
+ /* 禁用不必要的过渡动画以加速 */
376
+ .gr-box, .gr-form, .gr-input {
377
+ transition: none !important;
378
+ }
379
  """
380
  ) as demo:
381
+ gr.HTML(
382
  """
383
+ <div align="center" style="font-size: 40px;">
384
+ 🎨 CoTyle: Unlocking Code-to-Style Image Generation with Discrete Style Space
 
 
385
  <div style="display: flex; justify-content: center; gap: 10px; flex-wrap: wrap; margin: 15px 0;">
386
  <a href="xxx"><img alt="Project Page" src="https://img.shields.io/badge/Project%20Page-Homepage-yellow"></a>
387
  <a href="xxx"><img alt="GitHub" src="https://img.shields.io/badge/GitHub-Code-f8f0f0.svg"></a>
388
  <a href="xxx"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-Paper-da282a.svg"></a>
389
+ <a href="https://huggingface.co/spaces/Kwai-Kolors/CoTyle"><img alt="Hugging Face Demo" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-fd8b02"></a>
390
  </div>
 
391
  </div>
392
  """
393
  )
394
 
395
  with gr.Row():
396
  with gr.Column():
397
+ # style_code = gr.Number(label="Style Code", value=1234567, step=1)
398
+ style_code = gr.Slider(
399
+ minimum=1,
400
+ maximum=4294967296,
401
+ value=1234567,
402
+ step=1,
403
+ label="Style Code",
404
+ )
405
+
406
 
407
  num_prompts = gr.Slider(
408
  minimum=1,
409
  maximum=6,
410
+ value=4,
411
  step=1,
412
  label="Number of Prompts (You can choose how many prompt images to generate at once)",
413
  )
414
 
415
+ text_inputs = []
416
+ for i in range(6):
417
+ default_prompt = SUGGESTED_PROMPTS[i] if i < len(SUGGESTED_PROMPTS) else ""
418
+ textbox = gr.Textbox(
419
+ value=default_prompt,
420
+ label=f"Prompt {i+1}",
421
+ lines=3,
422
+ max_lines=10,
423
+ placeholder="Enter your prompt here...",
424
+ visible=(i < 4),
425
+ )
426
+ text_inputs.append(textbox)
 
 
 
 
 
 
427
 
428
+ # seed = gr.Number(label="Seed", value=42, step=1)
 
429
 
430
+ seed = gr.Slider(
431
+ minimum=1,
432
+ maximum=4294967296,
433
+ value=42,
434
+ step=1,
435
+ label="Seed",
436
+ )
 
437
 
 
438
  run_btn = gr.Button("✨ Generate All Images", variant="primary", size="lg")
439
 
440
  with gr.Column():
 
442
  label="Generated Results",
443
  show_label=True,
444
  columns=2,
 
445
  object_fit="contain",
446
+ height="100%",
447
  )
448
+ gr.Markdown(
449
+ """
450
+ > <strong>Tips</strong>:
451
+ > - Adjust the <strong>Number of Prompts</strong> slider to add or remove input rows.
452
+ > - Type your own prompts directly in the text boxes.
453
+ > - All images share the same style_code.
454
+ > - You can click any template below to quickly load preset style code and prompts.
455
+ """
456
+ )
457
+ # 优化的可见性更新函数
458
+ def update_textboxes_visibility(n):
459
+ # 使用列表推导式,更快
460
+ return [gr.update(visible=(i < n)) for i in range(6)]
461
 
462
+ # 使用 queue=False 来加速不需要排队的操作
 
 
 
 
 
 
 
463
  num_prompts.change(
464
+ fn=update_textboxes_visibility,
465
  inputs=num_prompts,
466
+ outputs=text_inputs,
467
+ queue=False, # 关键:禁用队列以加速
468
  )
469
 
470
+ input_components = [style_code, seed, num_prompts] + text_inputs
 
 
471
 
472
  run_btn.click(
473
  fn=generate_images,
 
475
  outputs=gallery,
476
  )
477
 
 
 
 
 
 
 
 
 
478
 
479
+ # 预先创建输出组件列表(在循环外)
480
+ output_components = [style_code, seed, num_prompts] + text_inputs
481
+
482
+ # 添加预设模板区域
483
+ with gr.Column(elem_classes="preset-section"):
484
+ gr.Markdown("## 🎯 Examples")
485
+ gr.Markdown("Click any example below to quickly load preset style code, seed, and prompts")
486
+
487
+ # 第一行3个预设
488
+ with gr.Row(elem_classes="preset-row"):
489
+ for i in range(5):
490
+ with gr.Column(scale=1, min_width=250):
491
+ template = PRESET_TEMPLATES[i]
492
+ with gr.Column(elem_classes="preset-container"):
493
+ if os.path.exists(template["image_path"]):
494
+ preset_img = gr.Image(
495
+ value=template["image_path"],
496
+ show_label=False,
497
+ interactive=False,
498
+ container=False,
499
+ height=280,
500
+ elem_classes="preset-image-container"
501
+ )
502
+ else:
503
+ placeholder = create_placeholder_image(template["name"])
504
+ preset_img = gr.Image(
505
+ value=placeholder,
506
+ show_label=False,
507
+ interactive=False,
508
+ container=False,
509
+ height=280,
510
+ elem_classes="preset-image-container"
511
+ )
512
+
513
+ preset_btn = gr.Button(
514
+ value=template["name"],
515
+ variant="secondary",
516
+ size="lg"
517
+ )
518
+
519
+ # 使用 partial 和 queue=False 加速
520
+ preset_btn.click(
521
+ fn=partial(load_preset_template, i),
522
+ inputs=None,
523
+ outputs=output_components,
524
+ queue=False, # 关键:禁用队列
525
+ )
526
+
527
+ # # 第二行3个预设
528
+ # with gr.Row(elem_classes="preset-row"):
529
+ # for i in range(3, 6):
530
+ # with gr.Column(scale=1, min_width=250):
531
+ # template = PRESET_TEMPLATES[i]
532
+ # with gr.Column(elem_classes="preset-container"):
533
+ # if os.path.exists(template["image_path"]):
534
+ # preset_img = gr.Image(
535
+ # value=template["image_path"],
536
+ # show_label=False,
537
+ # interactive=False,
538
+ # container=False,
539
+ # height=280,
540
+ # elem_classes="preset-image-container"
541
+ # )
542
+ # else:
543
+ # placeholder = create_placeholder_image(template["name"])
544
+ # preset_img = gr.Image(
545
+ # value=placeholder,
546
+ # show_label=False,
547
+ # interactive=False,
548
+ # container=False,
549
+ # height=280,
550
+ # elem_classes="preset-image-container"
551
+ # )
552
+
553
+ # preset_btn = gr.Button(
554
+ # value=template["name"],
555
+ # variant="secondary",
556
+ # size="lg"
557
+ # )
558
+
559
+ # # 使用 partial 和 queue=False 加速
560
+ # preset_btn.click(
561
+ # fn=partial(load_preset_template, i),
562
+ # inputs=None,
563
+ # outputs=output_components,
564
+ # queue=False, # 关键:禁用队列
565
+ # )
566
 
567
  if __name__ == "__main__":
568
  load_models()
569
+ # 调整 queue 参数以优化性能
570
+ demo.queue(
571
+ max_size=20, # 减小队列大小
572
+ default_concurrency_limit=1
573
+ ).launch(
574
+ max_threads=1,
575
+ share=True
576
+ )
assets/10241024.jpg ADDED

Git LFS Details

  • SHA256: 8de8dec69d93d9b6092f68847249ce0c9e5c3257a71bd96ed6b36bbb53aadca7
  • Pointer size: 131 Bytes
  • Size of remote file: 584 kB
assets/1234567.jpg ADDED

Git LFS Details

  • SHA256: f4e227b8d193e7f4ec91e5c9ed3bdb587a098b615e8a3dca96754b3fde06398f
  • Pointer size: 131 Bytes
  • Size of remote file: 414 kB
assets/4396.jpg ADDED

Git LFS Details

  • SHA256: 5f05e335b46ef9a405279e123f00b02e6100e2f5999457cef1599ba134c6c9e3
  • Pointer size: 131 Bytes
  • Size of remote file: 757 kB
assets/666666666.jpg ADDED

Git LFS Details

  • SHA256: f3641e3651354992bb7889fd8552db5d5a4b9e1b4048c38e2acfebdbf2781023
  • Pointer size: 131 Bytes
  • Size of remote file: 653 kB
assets/886.jpg ADDED

Git LFS Details

  • SHA256: 6f75b102f5c6afcf6691d846a6033b431b69a2b1cf3724a55fee2967cb18448f
  • Pointer size: 131 Bytes
  • Size of remote file: 630 kB
models/__pycache__/model.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/model.cpython-310.pyc and b/models/__pycache__/model.cpython-310.pyc differ
 
models/__pycache__/pipe.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/pipe.cpython-310.pyc and b/models/__pycache__/pipe.cpython-310.pyc differ
 
models/__pycache__/quant.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/quant.cpython-310.pyc and b/models/__pycache__/quant.cpython-310.pyc differ
 
models/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/utils.cpython-310.pyc and b/models/__pycache__/utils.cpython-310.pyc differ
 
models/__pycache__/vitamin.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/vitamin.cpython-310.pyc and b/models/__pycache__/vitamin.cpython-310.pyc differ
 
models/__pycache__/vlm_unitok.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/vlm_unitok.cpython-310.pyc and b/models/__pycache__/vlm_unitok.cpython-310.pyc differ
 
models/__pycache__/vqvae.cpython-310.pyc CHANGED
Binary files a/models/__pycache__/vqvae.cpython-310.pyc and b/models/__pycache__/vqvae.cpython-310.pyc differ