jamino30 commited on
Commit
e9e9628
1 Parent(s): 880c0d9

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +12 -17
  2. inference.py +6 -8
app.py CHANGED
@@ -16,8 +16,6 @@ else: device = 'cpu'
16
  print('DEVICE:', device)
17
  if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
18
 
19
- torch.backends.cuda.matmul.allow_tf32 = False
20
-
21
  model = VGG_19().to(device).eval()
22
  for param in model.parameters():
23
  param.requires_grad = False
@@ -43,7 +41,7 @@ for style_name, style_img_path in style_options.items():
43
  cached_style_features[style_name] = style_features
44
 
45
  @spaces.GPU(duration=10)
46
- def run(content_image, style_name, style_strength, output_quality, num_iters, optimizer, progress=gr.Progress(track_tqdm=True)):
47
  yield None
48
  img_size = 1024 if output_quality else 512
49
  content_img, original_size = preprocess_img(content_image, img_size)
@@ -64,9 +62,7 @@ def run(content_image, style_name, style_strength, output_quality, num_iters, op
64
  model=model,
65
  content_image=content_img,
66
  style_features=style_features,
67
- lr=converted_lr,
68
- iterations=num_iters,
69
- optim_caller=getattr(torch.optim, optimizer),
70
  )
71
  et = time.time()
72
  print('TIME TAKEN:', et-st)
@@ -92,8 +88,7 @@ with gr.Blocks(css=css) as demo:
92
  with gr.Column(elem_id='container'):
93
  content_and_output = gr.Image(label='Content', show_label=False, type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
94
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', info='Note: Adjustments automatically optimize for different styles.', value='Starry Night', type='value')
95
-
96
- with gr.Accordion('Adjustments', open=True):
97
  with gr.Group():
98
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=100, step=1, value=50)
99
 
@@ -103,10 +98,6 @@ with gr.Blocks(css=css) as demo:
103
  high_button = gr.Button('High', size='sm').click(fn=lambda: set_slider(100), outputs=[style_strength_slider])
104
  with gr.Group():
105
  output_quality = gr.Checkbox(label='More Realistic', info='Note: If unchecked, the resulting image will have a more artistic flair.')
106
-
107
- with gr.Accordion('Advanced Settings', open=False):
108
- num_iters_slider = gr.Slider(label='Iterations', minimum=1, maximum=50, step=1, value=35)
109
- optimizer_radio = gr.Radio(label='Optimizer', choices=['Adam', 'AdamW', 'LBFGS'], value='AdamW')
110
 
111
  submit_button = gr.Button('Submit', variant='primary')
112
  download_button = gr.DownloadButton(label='Download Image', visible=False)
@@ -118,7 +109,7 @@ with gr.Blocks(css=css) as demo:
118
 
119
  submit_button.click(
120
  fn=run,
121
- inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality, num_iters_slider, optimizer_radio],
122
  outputs=[content_and_output]
123
  ).then(
124
  fn=save_image,
@@ -147,11 +138,15 @@ with gr.Blocks(css=css) as demo:
147
  )
148
 
149
  examples = gr.Examples(
150
- label='Example',
151
- examples=[['./content_images/Bridge.jpg', 'Starry Night', 100, False, 35, 'AdamW']],
152
- inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality, num_iters_slider, optimizer_radio]
 
 
 
 
153
  )
154
 
155
  demo.queue = False
156
  demo.config['queue'] = False
157
- demo.launch(show_api=False)
 
16
  print('DEVICE:', device)
17
  if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
18
 
 
 
19
  model = VGG_19().to(device).eval()
20
  for param in model.parameters():
21
  param.requires_grad = False
 
41
  cached_style_features[style_name] = style_features
42
 
43
  @spaces.GPU(duration=10)
44
+ def run(content_image, style_name, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
45
  yield None
46
  img_size = 1024 if output_quality else 512
47
  content_img, original_size = preprocess_img(content_image, img_size)
 
62
  model=model,
63
  content_image=content_img,
64
  style_features=style_features,
65
+ lr=converted_lr
 
 
66
  )
67
  et = time.time()
68
  print('TIME TAKEN:', et-st)
 
88
  with gr.Column(elem_id='container'):
89
  content_and_output = gr.Image(label='Content', show_label=False, type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
90
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', info='Note: Adjustments automatically optimize for different styles.', value='Starry Night', type='value')
91
+ with gr.Accordion('Adjustments', open=False):
 
92
  with gr.Group():
93
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=100, step=1, value=50)
94
 
 
98
  high_button = gr.Button('High', size='sm').click(fn=lambda: set_slider(100), outputs=[style_strength_slider])
99
  with gr.Group():
100
  output_quality = gr.Checkbox(label='More Realistic', info='Note: If unchecked, the resulting image will have a more artistic flair.')
 
 
 
 
101
 
102
  submit_button = gr.Button('Submit', variant='primary')
103
  download_button = gr.DownloadButton(label='Download Image', visible=False)
 
109
 
110
  submit_button.click(
111
  fn=run,
112
+ inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality],
113
  outputs=[content_and_output]
114
  ).then(
115
  fn=save_image,
 
138
  )
139
 
140
  examples = gr.Examples(
141
+ examples=[
142
+ ['./content_images/Bridge.jpg', 'Starry Night', *optimal_settings['Starry Night']],
143
+ ['./content_images/GoldenRetriever.jpg', 'Lego Bricks', *optimal_settings['Lego Bricks']],
144
+ ['./content_images/SeaTurtle.jpg', 'Oil Painting', *optimal_settings['Oil Painting']],
145
+ ['./content_images/NYCSkyline.jpg', 'Scream', *optimal_settings['Scream']]
146
+ ],
147
+ inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality]
148
  )
149
 
150
  demo.queue = False
151
  demo.config['queue'] = False
152
+ demo.launch(show_api=False)
inference.py CHANGED
@@ -27,24 +27,22 @@ def inference(
27
  style_features,
28
  lr,
29
  iterations=35,
30
- optim_caller=optim.AdamW,
31
  alpha=1,
32
  beta=1
33
  ):
34
  generated_image = content_image.clone().requires_grad_(True)
35
- optimizer = optim_caller([generated_image], lr=lr)
36
 
37
  with torch.no_grad():
38
  content_features = model(content_image)
39
-
40
- def closure():
41
  optimizer.zero_grad()
 
42
  generated_features = model(generated_image)
43
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
 
44
  total_loss.backward()
45
- return total_loss
46
-
47
- for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
48
- optimizer.step(closure)
49
 
50
  return generated_image
 
27
  style_features,
28
  lr,
29
  iterations=35,
 
30
  alpha=1,
31
  beta=1
32
  ):
33
  generated_image = content_image.clone().requires_grad_(True)
34
+ optimizer = optim.AdamW([generated_image], lr=lr)
35
 
36
  with torch.no_grad():
37
  content_features = model(content_image)
38
+
39
+ for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
40
  optimizer.zero_grad()
41
+
42
  generated_features = model(generated_image)
43
  total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
44
+
45
  total_loss.backward()
46
+ optimizer.step()
 
 
 
47
 
48
  return generated_image