nurihp commited on
Commit
b3b2e46
1 Parent(s): b64a4ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -4
app.py CHANGED
@@ -13,17 +13,48 @@ import gradio as gr
13
 
14
 
15
  # Cargamos el learner
16
- learn = load_learner('export.pkl')
 
17
 
18
  # Definimos las etiquetas de nuestro modelo
19
- labels = learn.dls.vocab
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  # Definimos una función que se encarga de llevar a cabo las predicciones
23
  def predict(img):
24
  img = PILImage.create(img)
25
- pred,pred_idx,probs = learn.predict(img)
26
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Creamos la interfaz y la lanzamos.
29
  gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Label(num_top_classes=3),examples=['color_154.jpg','color_155.jpg']).launch(share=False)
 
13
 
14
 
15
  # Cargamos el learner
16
+ #learn = load_learner('export.pkl')
17
+ model = torch.jit.load("unet.pth")
18
 
19
  # Definimos las etiquetas de nuestro modelo
20
+ #labels = learn.dls.vocab
21
+
22
+ def transform_image(image):
23
+ my_transforms = transforms.Compose([transforms.ToTensor(),
24
+ transforms.Normalize(
25
+ [0.485, 0.456, 0.406],
26
+ [0.229, 0.224, 0.225])])
27
+ image_aux = image
28
+ return my_transforms(image_aux).unsqueeze(0).to(device)
29
+
30
 
31
 
32
  # Definimos una función que se encarga de llevar a cabo las predicciones
33
  def predict(img):
34
  img = PILImage.create(img)
35
+
36
+ image = transforms.Resize((480,640))(img)
37
+ tensor = transform_image(image=image)
38
+
39
+ with torch.no_grad():
40
+ outputs = model(tensor)
41
+
42
+ outputs = torch.argmax(outputs,1)
43
+
44
+ mask = np.array(outputs.cpu())
45
+ mask[mask==0]=255 #grape
46
+ mask[mask==1]=150 #leaves
47
+ mask[mask==2]=76 #pole
48
+ mask[mask==2]=74 #pole
49
+ mask[mask==3]=29 #wood
50
+ mask[mask==3]=25 #wood
51
+
52
+ mask=np.reshape(mask,(480,640))
53
+
54
+ return Image.fromarray(mask.astype('uint8'))
55
+ #pred,pred_idx,probs = learn.predict(img)
56
+ #return {labels[i]: float(probs[i]) for i in range(len(labels))}
57
+
58
 
59
  # Creamos la interfaz y la lanzamos.
60
  gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Label(num_top_classes=3),examples=['color_154.jpg','color_155.jpg']).launch(share=False)