sleepytaco commited on
Commit
e01b77b
1 Parent(s): 0e06666

add number field for num epochs

Browse files
Files changed (1) hide show
  1. app.py +8 -12
app.py CHANGED
@@ -3,27 +3,23 @@ import os
3
  from model.model import TextureSynthesisCNN
4
  from model.utils import convert_tensor_to_PIL_image
5
 
6
-
7
- def image_mod(image):
8
- return image.rotate(45)
9
-
10
- def synth_image(image):
11
  synthesizer = TextureSynthesisCNN(tex_exemplar_image=image)
12
- output_tensor = synthesizer.synthesize_texture(num_epochs=10)
13
  output_image = convert_tensor_to_PIL_image(output_tensor)
14
  return output_image
15
 
16
-
17
  demo = gr.Interface(
18
  fn=synth_image,
19
- inputs=[gr.Image(type="numpy")],
 
20
  outputs=[gr.Image(type="pil")],
21
  flagging_options=["blurry", "incorrect"],
22
  examples=[
23
- os.path.join(os.path.dirname(__file__), "images/blotchy_0025.png"),
24
- os.path.join(os.path.dirname(__file__), "images/blotchy_0027.png"),
25
- os.path.join(os.path.dirname(__file__), "images/cracked_0080.png"),
26
- os.path.join(os.path.dirname(__file__), "images/scenery.png"),
27
  ],
28
  )
29
 
 
3
  from model.model import TextureSynthesisCNN
4
  from model.utils import convert_tensor_to_PIL_image
5
 
6
+ def synth_image(image, epochs=10):
 
 
 
 
7
  synthesizer = TextureSynthesisCNN(tex_exemplar_image=image)
8
+ output_tensor = synthesizer.synthesize_texture(num_epochs=int(epochs))
9
  output_image = convert_tensor_to_PIL_image(output_tensor)
10
  return output_image
11
 
 
12
  demo = gr.Interface(
13
  fn=synth_image,
14
+ inputs=[gr.Image(type="numpy"),
15
+ gr.Number(label="Num epochs to optimize for", value=1, minimum=1, maximum=400)],
16
  outputs=[gr.Image(type="pil")],
17
  flagging_options=["blurry", "incorrect"],
18
  examples=[
19
+ [os.path.join(os.path.dirname(__file__), "images/blotchy_0025.png"), 10],
20
+ [os.path.join(os.path.dirname(__file__), "images/blotchy_0027.png"), 10],
21
+ [os.path.join(os.path.dirname(__file__), "images/cracked_0080.png"), 10],
22
+ [os.path.join(os.path.dirname(__file__), "images/scenery.png"), 10],
23
  ],
24
  )
25