artelabsuper commited on
Commit
66432b9
1 Parent(s): eba1c6b

app use models

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +46 -6
  3. test.py +3 -1
.gitignore CHANGED
@@ -2,4 +2,6 @@ venv
2
  *.pyc
3
  __pycache__
4
  sr.png
 
 
5
  test.png
2
  *.pyc
3
  __pycache__
4
  sr.png
5
+ sr2.png
6
+ sr_pred.png
7
  test.png
app.py CHANGED
@@ -1,21 +1,58 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import torchvision
 
4
  import torch
 
 
 
 
 
5
 
6
  # load model
7
- MODELS_TYPE = ["ModelA", "ModelB", "ModelC"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def predict(input_image, model_name):
10
  pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
11
  # transform image to torch and do preprocessing
12
- torch_image = torchvision.transforms.ToTensor()(pil_image)
13
  # model predict
14
- prediction = torch.rand(torch_image.shape)
 
 
15
  # transform torch to image
16
- predicted_pil_image = torchvision.transforms.ToPILImage()(prediction)
 
 
 
 
 
 
 
 
 
 
17
  # return correct image
18
- return predicted_pil_image
19
 
20
  iface = gr.Interface(
21
  fn=predict,
@@ -23,7 +60,10 @@ iface = gr.Interface(
23
  gr.Image(shape=(512,512)),
24
  gr.inputs.Radio(MODELS_TYPE)
25
  ],
26
- outputs=gr.Image(shape=(512,512)),
 
 
 
27
  examples=[
28
  ["demo_imgs/fake.jpg", MODELS_TYPE[0]] # use real image
29
  ],
1
  import gradio as gr
2
  from PIL import Image
3
  import torchvision
4
+ from torchvision import transforms
5
  import torch
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from models.modelNetA import Generator as GA
9
+ from models.modelNetB import Generator as GB
10
+ from models.modelNetC import Generator as GC
11
 
12
  # load model
13
+ modeltype2path = {
14
+ 'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
15
+ 'ModelB': 'DTM_exp_train10%_model_b/g-best.pth',
16
+ 'ModelC': 'DTM_exp_train10%_model_c/g-best.pth',
17
+ }
18
+ DEVICE='cpu'
19
+ MODELS_TYPE = list(modeltype2path.keys())
20
+ generators = [GA(), GB(), GC()]
21
+
22
+ for i in range(len(generators)):
23
+ generators[i] = torch.nn.DataParallel(generators[i])
24
+ state_dict = torch.load(modeltype2path[MODELS_TYPE[i]], map_location=torch.device('cpu'))
25
+ generators[i].load_state_dict(state_dict)
26
+ generators[i] = generators[i].module.to(DEVICE)
27
+ generators[i].eval()
28
+
29
+ preprocess = transforms.Compose([
30
+ transforms.Grayscale(),
31
+ transforms.ToTensor()
32
+ ])
33
 
34
  def predict(input_image, model_name):
35
  pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
36
  # transform image to torch and do preprocessing
37
+ torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0).to(DEVICE)
38
  # model predict
39
+ with torch.no_grad():
40
+ output = generators[MODELS_TYPE.index(model_name)](torch_img)
41
+ sr, sr_dem_selected = output[0], output[1]
42
  # transform torch to image
43
+ sr = sr.squeeze(0).cpu()
44
+ torchvision.utils.save_image(sr, 'sr_pred.png')
45
+ sr = np.array(Image.open('sr_pred.png'))
46
+
47
+ sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
48
+ fig, ax = plt.subplots()
49
+ im = ax.imshow(sr_dem_selected, cmap='jet', vmin=0, vmax=np.max(sr_dem_selected))
50
+ plt.colorbar(im, ax=ax)
51
+ fig.canvas.draw()
52
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
53
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
54
  # return correct image
55
+ return sr, data
56
 
57
  iface = gr.Interface(
58
  fn=predict,
60
  gr.Image(shape=(512,512)),
61
  gr.inputs.Radio(MODELS_TYPE)
62
  ],
63
+ outputs=[
64
+ gr.Image(),
65
+ gr.Image()
66
+ ],
67
  examples=[
68
  ["demo_imgs/fake.jpg", MODELS_TYPE[0]] # use real image
69
  ],
test.py CHANGED
@@ -35,7 +35,7 @@ generator.eval()
35
 
36
  preprocess = transforms.Compose([
37
  transforms.Grayscale(),
38
- transforms.Resize((512, 512)),
39
  transforms.ToTensor()
40
  ])
41
  input_img = Image.open('demo_imgs/fake.jpg')
@@ -47,6 +47,8 @@ sr = sr.squeeze(0).cpu()
47
 
48
  print(sr.shape)
49
  torchvision.utils.save_image(sr, 'sr.png')
 
 
50
 
51
  sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
52
  print(sr_dem_selected.shape)
35
 
36
  preprocess = transforms.Compose([
37
  transforms.Grayscale(),
38
+ # transforms.Resize((512, 512)),
39
  transforms.ToTensor()
40
  ])
41
  input_img = Image.open('demo_imgs/fake.jpg')
47
 
48
  print(sr.shape)
49
  torchvision.utils.save_image(sr, 'sr.png')
50
+ # sr = Image.fromarray(sr.squeeze(0).detach().numpy() * 255, 'L')
51
+ # sr.save('sr2.png')
52
 
53
  sr_dem_selected = sr_dem_selected.squeeze().cpu().detach().numpy()
54
  print(sr_dem_selected.shape)