Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- app.py +14 -11
- inference.py +8 -6
app.py
CHANGED
@@ -41,7 +41,7 @@ for style_name, style_img_path in style_options.items():
|
|
41 |
cached_style_features[style_name] = style_features
|
42 |
|
43 |
@spaces.GPU(duration=10)
|
44 |
-
def run(content_image, style_name, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
|
45 |
yield None
|
46 |
img_size = 1024 if output_quality else 512
|
47 |
content_img, original_size = preprocess_img(content_image, img_size)
|
@@ -62,7 +62,9 @@ def run(content_image, style_name, style_strength, output_quality, progress=gr.P
|
|
62 |
model=model,
|
63 |
content_image=content_img,
|
64 |
style_features=style_features,
|
65 |
-
lr=converted_lr
|
|
|
|
|
66 |
)
|
67 |
et = time.time()
|
68 |
print('TIME TAKEN:', et-st)
|
@@ -88,7 +90,8 @@ with gr.Blocks(css=css) as demo:
|
|
88 |
with gr.Column(elem_id='container'):
|
89 |
content_and_output = gr.Image(label='Content', show_label=False, type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
|
90 |
style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', info='Note: Adjustments automatically optimize for different styles.', value='Starry Night', type='value')
|
91 |
-
|
|
|
92 |
with gr.Group():
|
93 |
style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=100, step=1, value=50)
|
94 |
|
@@ -98,6 +101,10 @@ with gr.Blocks(css=css) as demo:
|
|
98 |
high_button = gr.Button('High', size='sm').click(fn=lambda: set_slider(100), outputs=[style_strength_slider])
|
99 |
with gr.Group():
|
100 |
output_quality = gr.Checkbox(label='More Realistic', info='Note: If unchecked, the resulting image will have a more artistic flair.')
|
|
|
|
|
|
|
|
|
101 |
|
102 |
submit_button = gr.Button('Submit', variant='primary')
|
103 |
download_button = gr.DownloadButton(label='Download Image', visible=False)
|
@@ -109,7 +116,7 @@ with gr.Blocks(css=css) as demo:
|
|
109 |
|
110 |
submit_button.click(
|
111 |
fn=run,
|
112 |
-
inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality],
|
113 |
outputs=[content_and_output]
|
114 |
).then(
|
115 |
fn=save_image,
|
@@ -138,13 +145,9 @@ with gr.Blocks(css=css) as demo:
|
|
138 |
)
|
139 |
|
140 |
examples = gr.Examples(
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
['./content_images/SeaTurtle.jpg', 'Oil Painting', *optimal_settings['Oil Painting']],
|
145 |
-
['./content_images/NYCSkyline.jpg', 'Scream', *optimal_settings['Scream']]
|
146 |
-
],
|
147 |
-
inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality]
|
148 |
)
|
149 |
|
150 |
demo.queue = False
|
|
|
41 |
cached_style_features[style_name] = style_features
|
42 |
|
43 |
@spaces.GPU(duration=10)
|
44 |
+
def run(content_image, style_name, style_strength, output_quality, num_iters, optimizer, progress=gr.Progress(track_tqdm=True)):
|
45 |
yield None
|
46 |
img_size = 1024 if output_quality else 512
|
47 |
content_img, original_size = preprocess_img(content_image, img_size)
|
|
|
62 |
model=model,
|
63 |
content_image=content_img,
|
64 |
style_features=style_features,
|
65 |
+
lr=converted_lr,
|
66 |
+
iterations=num_iters,
|
67 |
+
optim_caller=getattr(torch.optim, optimizer),
|
68 |
)
|
69 |
et = time.time()
|
70 |
print('TIME TAKEN:', et-st)
|
|
|
90 |
with gr.Column(elem_id='container'):
|
91 |
content_and_output = gr.Image(label='Content', show_label=False, type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
|
92 |
style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', info='Note: Adjustments automatically optimize for different styles.', value='Starry Night', type='value')
|
93 |
+
|
94 |
+
with gr.Accordion('Adjustments', open=True):
|
95 |
with gr.Group():
|
96 |
style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=100, step=1, value=50)
|
97 |
|
|
|
101 |
high_button = gr.Button('High', size='sm').click(fn=lambda: set_slider(100), outputs=[style_strength_slider])
|
102 |
with gr.Group():
|
103 |
output_quality = gr.Checkbox(label='More Realistic', info='Note: If unchecked, the resulting image will have a more artistic flair.')
|
104 |
+
|
105 |
+
with gr.Accordion('Advanced Settings', open=False):
|
106 |
+
num_iters_slider = gr.Slider(label='Iterations', minimum=1, maximum=50, step=1, value=35)
|
107 |
+
optimizer_radio = gr.Radio(label='Optimizer', choices=['Adam', 'AdamW', 'LBFGS'], value='AdamW')
|
108 |
|
109 |
submit_button = gr.Button('Submit', variant='primary')
|
110 |
download_button = gr.DownloadButton(label='Download Image', visible=False)
|
|
|
116 |
|
117 |
submit_button.click(
|
118 |
fn=run,
|
119 |
+
inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality, num_iters_slider, optimizer_radio],
|
120 |
outputs=[content_and_output]
|
121 |
).then(
|
122 |
fn=save_image,
|
|
|
145 |
)
|
146 |
|
147 |
examples = gr.Examples(
|
148 |
+
label='Example',
|
149 |
+
examples=[['./content_images/Bridge.jpg', 'Starry Night', 100, False, 35, 'AdamW']],
|
150 |
+
inputs=[content_and_output, style_dropdown, style_strength_slider, output_quality, num_iters_slider, optimizer_radio]
|
|
|
|
|
|
|
|
|
151 |
)
|
152 |
|
153 |
demo.queue = False
|
inference.py
CHANGED
@@ -27,22 +27,24 @@ def inference(
|
|
27 |
style_features,
|
28 |
lr,
|
29 |
iterations=35,
|
|
|
30 |
alpha=1,
|
31 |
beta=1
|
32 |
):
|
33 |
generated_image = content_image.clone().requires_grad_(True)
|
34 |
-
optimizer =
|
35 |
|
36 |
with torch.no_grad():
|
37 |
content_features = model(content_image)
|
38 |
-
|
39 |
-
|
40 |
optimizer.zero_grad()
|
41 |
-
|
42 |
generated_features = model(generated_image)
|
43 |
total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
|
44 |
-
|
45 |
total_loss.backward()
|
46 |
-
|
|
|
|
|
|
|
47 |
|
48 |
return generated_image
|
|
|
27 |
style_features,
|
28 |
lr,
|
29 |
iterations=35,
|
30 |
+
optim_caller=optim.AdamW,
|
31 |
alpha=1,
|
32 |
beta=1
|
33 |
):
|
34 |
generated_image = content_image.clone().requires_grad_(True)
|
35 |
+
optimizer = optim_caller([generated_image], lr=lr)
|
36 |
|
37 |
with torch.no_grad():
|
38 |
content_features = model(content_image)
|
39 |
+
|
40 |
+
def closure():
|
41 |
optimizer.zero_grad()
|
|
|
42 |
generated_features = model(generated_image)
|
43 |
total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
|
|
|
44 |
total_loss.backward()
|
45 |
+
return total_loss
|
46 |
+
|
47 |
+
for _ in tqdm(range(iterations), desc='The magic is happening ✨'):
|
48 |
+
optimizer.step(closure)
|
49 |
|
50 |
return generated_image
|