dongyi commited on
Commit
13def2b
1 Parent(s): 3adb51d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -220,7 +220,7 @@ def generate_one_shot(src_img, img_prompt):
220
 
221
  def generate_zero_shot(src_img, txt_prompt):
222
  # init model
223
- state_dict = torch.load(f"./checkpoints/{img_prompt[-2:]}/epoch_latest.pth", map_location='cpu')
224
  model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
225
  model.to(device)
226
  model.eval()
@@ -274,6 +274,7 @@ with gr.Blocks() as demo:
274
 
275
  with gr.TabItem("One-Shot"):
276
  one_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
 
277
  with gr.Row():
278
  gr.Image(shape=(100, 100), value = Image.open("example/reference/01.png"), type='pil', label="ref01")
279
  gr.Image(shape=(100, 100), value = Image.open("example/reference/02.png"), type='pil', label="ref02")
@@ -289,6 +290,7 @@ with gr.Blocks() as demo:
289
 
290
  with gr.TabItem("Zero-Shot"):
291
  zero_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
 
292
  zero_shot_ref_prompt = gr.Dropdown(
293
  label="Txt Prompt",
294
  info="Select a reference style prompt",
 
220
 
221
  def generate_zero_shot(src_img, txt_prompt):
222
  # init model
223
+ state_dict = torch.load(f"./checkpoints/{txt_prompt.replace(' ', '_')}/epoch_latest.pth", map_location='cpu')
224
  model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
225
  model.to(device)
226
  model.eval()
 
274
 
275
  with gr.TabItem("One-Shot"):
276
  one_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
277
+ gr.Examples(examples=["./example/source/01.png", "./example/source/02.png", "./example/source/03.png", "./example/source/04.png"], inputs=one_shot_src_img)
278
  with gr.Row():
279
  gr.Image(shape=(100, 100), value = Image.open("example/reference/01.png"), type='pil', label="ref01")
280
  gr.Image(shape=(100, 100), value = Image.open("example/reference/02.png"), type='pil', label="ref02")
 
290
 
291
  with gr.TabItem("Zero-Shot"):
292
  zero_shot_src_img = gr.Image(label="Upload Input Face Image", type='filepath', height=400)
293
+ gr.Examples(examples=["./example/source/01.png", "./example/source/02.png", "./example/source/03.png", "./example/source/04.png"], inputs=zero_shot_src_img)
294
  zero_shot_ref_prompt = gr.Dropdown(
295
  label="Txt Prompt",
296
  info="Select a reference style prompt",