jamino30 commited on
Commit
ec7b4ee
·
verified ·
1 Parent(s): e287232

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +14 -11
  2. inference.py +8 -6
app.py CHANGED
@@ -41,7 +41,7 @@ for style_name, style_img_path in style_options.items():
41
  cached_style_features[style_name] = style_features
42
 
43
  @spaces.GPU(duration=10)
44
- def run(content_image, style_name, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
45
  yield None
46
  img_size = 1024 if output_quality else 512
47
  content_img, original_size = preprocess_img(content_image, img_size)
@@ -62,7 +62,9 @@ def run(content_image, style_name, style_strength, output_quality, progress=gr.P
62
  model=model,
63
  content_image=content_img,
64
  style_features=style_features,
65
- lr=converted_lr
 
 
66
  )
67
  et = time.time()
68
  print('TIME TAKEN:', et-st)
@@ -88,7 +90,8 @@ with gr.Blocks(css=css) as demo:
88
  with gr.Column(elem_id='container'):
89
  content_and_output = gr.Image(label='Content', show_label=False, type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
90
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', info='Note: Adjustments automatically optimize for different styles.', value='Starry Night', type='value')
91
- with gr.Accordion('Adjustments', open=False):
 
92
  with gr.Group():
93
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=100, step=1, value=50)
94
 
@@ -98,6 +101,10 @@ with gr.Blocks(css=css) as demo:
98
  high_button = gr.Button('High', size='sm').click(fn=lambda: set_slider(100), outputs=[style_strength_slider])
99
  with gr.Group():
100
  output_quality = gr.Checkbox(label='More Realistic', info='Note: If unchecked, the resulting image will have a more artistic flair.')
 
 
 
 
101
 
102
  submit_button = gr.Button('Submit', variant='primary')
103
  download_button = gr.DownloadButton(label='Download Image', visible=False)
@@ -109,7 +116,7 @@ with gr.Blocks(css=css) as demo:
109
 
110
  submit_button.click(
111
  fn=run,
112
- inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality],
113
  outputs=[content_and_output]
114
  ).then(
115
  fn=save_image,
@@ -138,13 +145,9 @@ with gr.Blocks(css=css) as demo:
138
  )
139
 
140
  examples = gr.Examples(
141
- examples=[
142
- ['./content_images/Bridge.jpg', 'Starry Night', *optimal_settings['Starry Night']],
143
- ['./content_images/GoldenRetriever.jpg', 'Lego Bricks', *optimal_settings['Lego Bricks']],
144
- ['./content_images/SeaTurtle.jpg', 'Oil Painting', *optimal_settings['Oil Painting']],
145
- ['./content_images/NYCSkyline.jpg', 'Scream', *optimal_settings['Scream']]
146
- ],
147
- inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality]
148
  )
149
 
150
  demo.queue = False
 
41
  cached_style_features[style_name] = style_features
42
 
43
  @spaces.GPU(duration=10)
44
+ def run(content_image, style_name, style_strength, output_quality, num_iters, optimizer, progress=gr.Progress(track_tqdm=True)):
45
  yield None
46
  img_size = 1024 if output_quality else 512
47
  content_img, original_size = preprocess_img(content_image, img_size)
 
62
  model=model,
63
  content_image=content_img,
64
  style_features=style_features,
65
+ lr=converted_lr,
66
+ iterations=num_iters,
67
+ optim_caller=getattr(torch.optim, optimizer),
68
  )
69
  et = time.time()
70
  print('TIME TAKEN:', et-st)
 
90
  with gr.Column(elem_id='container'):
91
  content_and_output = gr.Image(label='Content', show_label=False, type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
92
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', info='Note: Adjustments automatically optimize for different styles.', value='Starry Night', type='value')
93
+
94
+ with gr.Accordion('Adjustments', open=True):
95
  with gr.Group():
96
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=100, step=1, value=50)
97
 
 
101
  high_button = gr.Button('High', size='sm').click(fn=lambda: set_slider(100), outputs=[style_strength_slider])
102
  with gr.Group():
103
  output_quality = gr.Checkbox(label='More Realistic', info='Note: If unchecked, the resulting image will have a more artistic flair.')
104
+
105
+ with gr.Accordion('Advanced Settings', open=False):
106
+ num_iters_slider = gr.Slider(label='Iterations', minimum=1, maximum=50, step=1, value=35)
107
+ optimizer_radio = gr.Radio(label='Optimizer', choices=['Adam', 'AdamW', 'LBFGS'], value='AdamW')
108
 
109
  submit_button = gr.Button('Submit', variant='primary')
110
  download_button = gr.DownloadButton(label='Download Image', visible=False)
 
116
 
117
  submit_button.click(
118
  fn=run,
119
+ inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality, num_iters_slider, optimizer_radio],
120
  outputs=[content_and_output]
121
  ).then(
122
  fn=save_image,
 
145
  )
146
 
147
  examples = gr.Examples(
148
+ label='Example',
149
+ examples=[['./content_images/Bridge.jpg', 'Starry Night', 100, False, 35, 'AdamW']],
150
+ inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality, num_iters_slider, optimizer_radio]
 
 
 
 
151
  )
152
 
153
  demo.queue = False
inference.py CHANGED
@@ -27,22 +27,24 @@ def inference(
27
  style_features,
28
  lr,
29
  iterations=35,
 
30
  alpha=1,
31
  beta=1
32
  ):
33
  generated_image = content_image.clone().requires_grad_(True)
34
- optimizer = optim.AdamW([generated_image], lr=lr)
35
 
36
  with torch.no_grad():
37
  content_features = model(content_image)
38
-
39
- for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
40
  optimizer.zero_grad()
41
-
42
  generated_features = model(generated_image)
43
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
44
-
45
  total_loss.backward()
46
- optimizer.step()
 
 
 
47
 
48
  return generated_image
 
27
  style_features,
28
  lr,
29
  iterations=35,
30
+ optim_caller=optim.AdamW,
31
  alpha=1,
32
  beta=1
33
  ):
34
  generated_image = content_image.clone().requires_grad_(True)
35
+ optimizer = optim_caller([generated_image], lr=lr)
36
 
37
  with torch.no_grad():
38
  content_features = model(content_image)
39
+
40
+ def closure():
41
  optimizer.zero_grad()
 
42
  generated_features = model(generated_image)
43
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
 
44
  total_loss.backward()
45
+ return total_loss
46
+
47
+ for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
48
+ optimizer.step(closure)
49
 
50
  return generated_image