mehdiabruee commited on
Commit
a2621e6
1 Parent(s): 3822ccd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -10
app.py CHANGED
@@ -134,24 +134,16 @@ model = Generator(3, 32, 3, 4).cpu() # input_dim, num_filter, output_dim, num_re
134
  model.load_state_dict(torch.load('G_A_HW4_SAVE.pt',map_location=torch.device('cpu')))
135
  print(model)
136
  model.eval()
137
- model_2 = Generator(3, 32, 3, 4)
138
- model_2.load_state_dict(torch.load('G_B_HW4_SAVE.pt',map_location=torch.device('cpu')))
139
- model_2.eval()
140
 
141
  totensor = torchvision.transforms.ToTensor()
142
  normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
143
  topilimage = torchvision.transforms.ToPILImage()
144
 
145
- def predict(input_1, input_2):
146
  im1 = normalize_fn(totensor(input_1))
147
  print(im1.shape)
148
  preds1 = model(im1.unsqueeze(0))/2 + 0.5
149
  print(preds1.shape)
150
-
151
- im2 = normalize_fn(totensor(input_2))
152
- print(im2.shape)
153
- preds2 = model(im2.unsqueeze(0))/2 + 0.5
154
- print(preds2.shape)
155
- return topilimage(preds1.squeeze(0).detach()), topilimage(preds2.squeeze(0).detach())
156
 
157
  gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256,256)), outputs="image", title='Emoji_CycleGAN').launch()
 
134
  model.load_state_dict(torch.load('G_A_HW4_SAVE.pt',map_location=torch.device('cpu')))
135
  print(model)
136
  model.eval()
 
 
 
137
 
138
  totensor = torchvision.transforms.ToTensor()
139
  normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
140
  topilimage = torchvision.transforms.ToPILImage()
141
 
142
+ def predict(input_1):
143
  im1 = normalize_fn(totensor(input_1))
144
  print(im1.shape)
145
  preds1 = model(im1.unsqueeze(0))/2 + 0.5
146
  print(preds1.shape)
147
+ return topilimage(preds1.squeeze(0).detach())
 
 
 
 
 
148
 
149
  gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256,256)), outputs="image", title='Emoji_CycleGAN').launch()