Spaces: ginipick (Running on Zero)

ginipick committed
Commit 7ac3a5d · 1 Parent(s): 1f17448

Update app.py

Files changed (1): app.py (+45, -23)
app.py CHANGED
@@ -175,29 +175,46 @@ def contains_korean(text):
     return any(ord('가') <= ord(char) <= ord('힣') for char in text)
 
 
-# Main feature functions
+# Revised model initialization function
+@spaces.GPU()
+def initialize_fashion_pipe():
+    try:
+        pipe = DiffusionPipeline.from_pretrained(
+            BASE_MODEL,
+            torch_dtype=torch.float16,
+            safety_checker=None,
+            requires_safety_checker=False
+        ).to("cuda")
+        pipe.enable_model_cpu_offload()
+        return pipe
+    except Exception as e:
+        print(f"Error initializing fashion pipe: {e}")
+        raise
+
+# Revised generation function
 @spaces.GPU()
 def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
     try:
         # Handle Korean input
         if contains_korean(prompt):
-            translator = get_translator()
-            translated = translator(prompt)[0]['translation_text']
-            actual_prompt = translated
+            with torch.inference_mode():
+                translator = get_translator()
+                translated = translator(prompt)[0]['translation_text']
+                actual_prompt = translated
         else:
             actual_prompt = prompt
 
-        # Fetch the pipeline
+        # Initialize the pipeline
         pipe = initialize_fashion_pipe()
 
         # Set up LoRA
         if mode == "Generate Model":
-            pipe = load_lora(pipe, MODEL_LORA_REPO)
+            pipe.load_lora_weights(MODEL_LORA_REPO)
            trigger_word = "fashion photography, professional model"
         else:
-            pipe = load_lora(pipe, CLOTHES_LORA_REPO)
+            pipe.load_lora_weights(CLOTHES_LORA_REPO)
             trigger_word = "upper clothing, fashion item"
-
+
         # Cap the parameters
         width = min(width, 768)
         height = min(height, 768)
@@ -206,27 +223,32 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
         # Set the seed
         if randomize_seed:
             seed = random.randint(0, MAX_SEED)
-        generator = torch.Generator(device="cuda").manual_seed(seed)
-
-        # Show progress
-        progress(0, "Starting fashion generation...")
+        generator = torch.Generator("cuda").manual_seed(seed)
 
         # Generate the image
-        image = pipe(
-            prompt=f"{actual_prompt} {trigger_word}",
-            num_inference_steps=steps,
-            guidance_scale=cfg_scale,
-            width=width,
-            height=height,
-            generator=generator,
-            joint_attention_kwargs={"scale": lora_scale},
-        ).images[0]
+        with torch.inference_mode():
+            output = pipe(
+                prompt=f"{actual_prompt} {trigger_word}",
+                num_inference_steps=steps,
+                guidance_scale=cfg_scale,
+                width=width,
+                height=height,
+                generator=generator,
+                cross_attention_kwargs={"scale": lora_scale},
+            )
+
+        image = output.images[0]
+
+        # Clean up memory
+        del pipe
+        torch.cuda.empty_cache()
+        gc.collect()
 
         return image, seed
-
+
     except Exception as e:
         print(f"Error in generate_fashion: {str(e)}")
-        raise
+        raise gr.Error(f"Generation failed: {str(e)}")
 
 @safe_model_call
 def leffa_predict(src_image_path, ref_image_path, control_type):
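
A few notes on the patterns this commit relies on. The contains_korean check in the context lines works because precomposed Hangul syllables occupy the contiguous Unicode block U+AC00 ('가') through U+D7A3 ('힣'), so a plain character comparison suffices. A minimal standalone sketch:

def contains_korean(text: str) -> bool:
    # Hangul syllables are the contiguous range U+AC00..U+D7A3.
    return any('가' <= ch <= '힣' for ch in text)

print(contains_korean("red dress"))    # False
print(contains_korean("빨간 드레스"))  # True

This covers only precomposed syllables, not isolated jamo (the U+1100 block), which is fine for ordinary user prompts.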
 
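
One thing worth flagging in the new initialize_fashion_pipe(): in diffusers, enable_model_cpu_offload() manages device placement itself, moving each submodule to the GPU only while it runs, so it is normally called instead of .to("cuda") rather than after it. A minimal sketch of the usual pattern, with BASE_MODEL's value assumed for illustration (the real constant lives elsewhere in app.py):

import torch
from diffusers import DiffusionPipeline

BASE_MODEL = "runwayml/stable-diffusion-v1-5"  # assumed placeholder

pipe = DiffusionPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,   # halves weight memory on the GPU
    safety_checker=None,
    requires_safety_checker=False,
)
pipe.enable_model_cpu_offload()  # handles CPU<->GPU movement per forward pass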
 
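
The switch from the custom load_lora() helper to pipe.load_lora_weights() uses the loader built into diffusers, and the lora_scale is then applied at call time through cross_attention_kwargs, exactly as the new code does. A sketch continuing with the pipe from the previous block (the repo id is a placeholder for the MODEL_LORA_REPO / CLOTHES_LORA_REPO constants in app.py):

pipe.load_lora_weights("some-user/fashion-model-lora")  # hypothetical repo id

image = pipe(
    prompt="fashion photography, professional model, red dress",
    cross_attention_kwargs={"scale": 0.8},  # LoRA strength, 0.0..1.0
).images[0]

pipe.unload_lora_weights()  # drop the adapter before loading a different one

Unloading matters mainly when the pipeline is cached between requests; here the function rebuilds the pipe on every call, so stale adapters cannot accumulate.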
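
Passing a torch.Generator seeded on the pipeline's device makes sampling reproducible: two runs with the same seed draw identical latents. torch.Generator("cuda") and torch.Generator(device="cuda") are equivalent; the diff merely shortens the call. A small demonstration, assuming a CUDA device is available:

import torch

gen_a = torch.Generator("cuda").manual_seed(42)
gen_b = torch.Generator("cuda").manual_seed(42)

# Identically seeded generators produce identical draws.
assert torch.equal(
    torch.randn(4, generator=gen_a, device="cuda"),
    torch.randn(4, generator=gen_b, device="cuda"),
)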
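
The cleanup block matters on ZeroGPU, where many requests share the same device: dropping the last reference, collecting garbage, and emptying PyTorch's allocator cache returns VRAM to the pool. Note the diff's gc.collect() call assumes gc is imported at the top of app.py. A reusable sketch of the same idiom:

import gc
import torch

def release_gpu_memory() -> None:
    # Collect unreachable Python objects first so tensor references die,
    # then hand PyTorch's cached CUDA blocks back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Usage, mirroring the diff:
#     del pipe
#     release_gpu_memory()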
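
Finally, replacing the bare raise with raise gr.Error(...) is the Gradio-idiomatic way to fail: the message is shown to the user as an error toast instead of an opaque server error. Minimal example:

import gradio as gr

def safe_generate(prompt: str) -> str:
    if not prompt.strip():
        # gr.Error surfaces this text in the UI instead of a generic failure.
        raise gr.Error("Prompt must not be empty")
    return prompt

demo = gr.Interface(fn=safe_generate, inputs="text", outputs="text")
# demo.launch()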