artelabsuper commited on
Commit
1513566
1 Parent(s): 21bf7d6

get model selection field

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -4,8 +4,9 @@ import torchvision
4
  import torch
5
 
6
  # load model
 
7
 
8
- def predict(input_image):
9
  pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
10
  # transform image to torch and do preprocessing
11
  torch_image = torchvision.transforms.ToTensor()(pil_image)
@@ -18,10 +19,13 @@ def predict(input_image):
18
 
19
  iface = gr.Interface(
20
  fn=predict,
21
- inputs=gr.Image(shape=(512,512)),
 
 
 
22
  outputs=gr.Image(shape=(512,512)),
23
  examples=[
24
- ["demo_imgs/fake.jpg"] # use real image
25
  ],
26
  title="DTM Estimation",
27
  description="This demo predict a DTM..."
4
  import torch
5
 
6
  # load model
7
+ MODELS_TYPE = ["ModelA", "ModelB", "ModelC"]
8
 
9
+ def predict(input_image, model_name):
10
  pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
11
  # transform image to torch and do preprocessing
12
  torch_image = torchvision.transforms.ToTensor()(pil_image)
19
 
20
  iface = gr.Interface(
21
  fn=predict,
22
+ inputs=[
23
+ gr.Image(shape=(512,512)),
24
+ gr.inputs.Radio(MODELS_TYPE)
25
+ ],
26
  outputs=gr.Image(shape=(512,512)),
27
  examples=[
28
+ ["demo_imgs/fake.jpg", MODELS_TYPE[0]] # use real image
29
  ],
30
  title="DTM Estimation",
31
  description="This demo predict a DTM..."