jamino30 commited on
Commit
d879848
1 Parent(s): 0803fb8

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +4 -4
  2. inference.py +4 -4
app.py CHANGED
@@ -81,7 +81,7 @@ def run(content_image, style_name, style_strength=5):
81
  future_all = executor.submit(run_inference, False)
82
  future_bg = executor.submit(run_inference, True)
83
  generated_img_all, _ = future_all.result()
84
- generated_img_bg, bg_ratio = future_bg.result()
85
 
86
  et = time.time()
87
  print('TIME TAKEN:', et-st)
@@ -89,7 +89,7 @@ def run(content_image, style_name, style_strength=5):
89
  yield (
90
  (content_image, postprocess_img(generated_img_all, original_size)),
91
  (content_image, postprocess_img(generated_img_bg, original_size)),
92
- f'{bg_ratio:.2f}'
93
  )
94
 
95
  def set_slider(value):
@@ -126,7 +126,7 @@ with gr.Blocks(css=css) as demo:
126
  download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
127
  with gr.Group():
128
  output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
129
- bg_ratio_label = gr.Label(label='Background Ratio')
130
  download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)
131
 
132
  def save_image(img_tuple1, img_tuple2):
@@ -143,7 +143,7 @@ with gr.Blocks(css=css) as demo:
143
  submit_button.click(
144
  fn=run,
145
  inputs=[content_image, style_dropdown, style_strength_slider],
146
- outputs=[output_image_all, output_image_background, bg_ratio_label]
147
  ).then(
148
  fn=save_image,
149
  inputs=[output_image_all, output_image_background],
 
81
  future_all = executor.submit(run_inference, False)
82
  future_bg = executor.submit(run_inference, True)
83
  generated_img_all, _ = future_all.result()
84
+ generated_img_bg, salient_object_ratio = future_bg.result()
85
 
86
  et = time.time()
87
  print('TIME TAKEN:', et-st)
 
89
  yield (
90
  (content_image, postprocess_img(generated_img_all, original_size)),
91
  (content_image, postprocess_img(generated_img_bg, original_size)),
92
+ f'{salient_object_ratio:.2f}'
93
  )
94
 
95
  def set_slider(value):
 
126
  download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
127
  with gr.Group():
128
  output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
129
+ salient_object_ratio_label = gr.Label(label='Salient Object Ratio')
130
  download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)
131
 
132
  def save_image(img_tuple1, img_tuple2):
 
143
  submit_button.click(
144
  fn=run,
145
  inputs=[content_image, style_dropdown, style_strength_slider],
146
+ outputs=[output_image_all, output_image_background, salient_object_ratio_label]
147
  ).then(
148
  fn=save_image,
149
  inputs=[output_image_all, output_image_background],
inference.py CHANGED
@@ -60,16 +60,16 @@ def inference(
60
  content_features = model(content_image)
61
 
62
  resized_bg_masks = []
63
- background_ratio = None
64
  if apply_to_background:
65
  segmentation_output = segmentation_model(content_image)['out']
66
  segmentation_mask = segmentation_output.argmax(dim=1)
67
  background_mask = (segmentation_mask == 0).float()
68
  foreground_mask = 1 - background_mask
69
 
70
- background_pixel_count = background_mask.sum().item()
71
  total_pixel_count = segmentation_mask.numel()
72
- background_ratio = background_pixel_count / total_pixel_count
73
 
74
  for cf in content_features:
75
  _, _, h_i, w_i = cf.shape
@@ -106,4 +106,4 @@ def inference(
106
  if DEV_MODE:
107
  writer.flush()
108
  writer.close()
109
- return generated_image, background_ratio
 
60
  content_features = model(content_image)
61
 
62
  resized_bg_masks = []
63
+ salient_object_ratio = None
64
  if apply_to_background:
65
  segmentation_output = segmentation_model(content_image)['out']
66
  segmentation_mask = segmentation_output.argmax(dim=1)
67
  background_mask = (segmentation_mask == 0).float()
68
  foreground_mask = 1 - background_mask
69
 
70
+ salient_object_pixel_count = foreground_mask.sum().item()
71
  total_pixel_count = segmentation_mask.numel()
72
+ salient_object_ratio = salient_object_pixel_count / total_pixel_count
73
 
74
  for cf in content_features:
75
  _, _, h_i, w_i = cf.shape
 
106
  if DEV_MODE:
107
  writer.flush()
108
  writer.close()
109
+ return generated_image, salient_object_ratio