Caroline Mai Chan commited on
Commit
aef0baa
1 Parent(s): 8aa3f87

add new style

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. app.py +17 -7
  3. model2.pth +3 -0
.gitattributes CHANGED
@@ -26,3 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
  model.pth filter=lfs diff=lfs merge=lfs -text
 
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
  model.pth filter=lfs diff=lfs merge=lfs -text
29
+ model2.pth filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -83,18 +83,26 @@ class Generator(nn.Module):
83
 
84
  return out
85
 
86
- model = Generator(3, 1, 3)
87
- model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
88
- model.eval()
89
 
90
- def predict(input_img):
 
 
 
 
91
  input_img = Image.open(input_img)
92
  transform = transforms.Compose([transforms.Resize(256, Image.BICUBIC), transforms.ToTensor()])
93
  input_img = transform(input_img)
94
  input_img = torch.unsqueeze(input_img, 0)
95
 
 
96
  with torch.no_grad():
97
- drawing = model(input_img)[0].detach()
 
 
 
98
 
99
  drawing = transforms.ToPILImage()(drawing)
100
  return drawing
@@ -102,9 +110,11 @@ def predict(input_img):
102
  title="informative-drawings"
103
  description="Gradio Demo for line drawing generation. "
104
  # article = "<p style='text-align: center'><a href='TODO' target='_blank'>Project Page</a> | <a href='codelink' target='_blank'>Github</a></p>"
105
- examples=[['cat.png'], ['lizard.png'], ['bridge.png']]
106
 
107
 
108
- iface = gr.Interface(predict, gr.inputs.Image(type='filepath'), "image", title=title,description=description,examples=examples)
 
 
109
 
110
  iface.launch()
83
 
84
  return out
85
 
86
+ model1 = Generator(3, 1, 3)
87
+ model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
88
+ model1.eval()
89
 
90
+ model2 = Generator(3, 1, 3)
91
+ model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
92
+ model2.eval()
93
+
94
+ def predict(input_img, ver):
95
  input_img = Image.open(input_img)
96
  transform = transforms.Compose([transforms.Resize(256, Image.BICUBIC), transforms.ToTensor()])
97
  input_img = transform(input_img)
98
  input_img = torch.unsqueeze(input_img, 0)
99
 
100
+ drawing = 0
101
  with torch.no_grad():
102
+ if ver == 'style 2':
103
+ drawing = model2(input_img)[0].detach()
104
+ else:
105
+ drawing = model1(input_img)[0].detach()
106
 
107
  drawing = transforms.ToPILImage()(drawing)
108
  return drawing
110
  title="informative-drawings"
111
  description="Gradio Demo for line drawing generation. "
112
  # article = "<p style='text-align: center'><a href='TODO' target='_blank'>Project Page</a> | <a href='codelink' target='_blank'>Github</a></p>"
113
+ examples=[['cat.png', 'style 1'], ['bridge.png', 'style 1'], ['lizard.png', 'style 2'],]
114
 
115
 
116
+ iface = gr.Interface(predict, [gr.inputs.Image(type='filepath'),
117
+ gr.inputs.Radio(['style 1','style 2'], type="value", default='style 1', label='version')],
118
+ gr.outputs.Image(type="pil"), title=title,description=description,examples=examples)
119
 
120
  iface.launch()
model2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30a534781061f34e83bb9406b4335da4ff2616c95d22a585c1245aa8363e74e0
3
+ size 17173511