blancamartin committed on
Commit
c31bf4e
·
1 Parent(s): e40a5c5

Create app.py

Files changed (1)
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
+ from fastai.basics import *
+ from fastai.vision import models
+ from fastai.vision.all import *
+ from fastai.metrics import *
+ from fastai.data.all import *
+ from fastai.callback import *
+
+
+ from pathlib import Path
+ import random
+
+ import PIL
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+ import gradio as gr
+
+
+ # Load the model (the fastai learner export is kept commented out for reference)
+ #learn = load_learner('export.pkl')
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = torch.jit.load("unet.pth")
+ model = model.to(device)
+ model.eval()
+
+ # The labels of our model (would come from the learner's vocab if load_learner were used)
+ #labels = learn.dls.vocab
+
+ def transform_image(image):
+     # Convert to a tensor, apply ImageNet normalization and add a batch dimension
+     my_transforms = transforms.Compose([transforms.ToTensor(),
+                                         transforms.Normalize(
+                                             [0.485, 0.456, 0.406],
+                                             [0.229, 0.224, 0.225])])
+     image_aux = image
+     return my_transforms(image_aux).unsqueeze(0).to(device)
+
+
+
+ # Define the function in charge of carrying out the predictions
+ def predict(img):
+     img = PILImage.create(img)
+
+     image = transforms.Resize((480, 640))(img)
+     tensor = transform_image(image=image)
+
+     with torch.no_grad():
+         outputs = model(tensor)
+
+     # Per-pixel class indices
+     outputs = torch.argmax(outputs, 1)
+
+     # Map each class index to a distinct grey level so the mask can be displayed as an image
+     mask = np.array(outputs.cpu())
+     mask[mask==0] = 255  # grape
+     mask[mask==1] = 150  # leaves
+     mask[mask==2] = 76   # pole
+     mask[mask==3] = 29   # wood
+
+     mask = np.reshape(mask, (480, 640))
+
+     return Image.fromarray(mask.astype('uint8'))
+     #pred,pred_idx,probs = learn.predict(img)
+     #return {labels[i]: float(probs[i]) for i in range(len(labels))}
+
+
+ # Create the interface and launch it.
+ gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(128, 128)), outputs=gr.outputs.Image(), examples=['color_154.jpg', 'color_155.jpg']).launch(share=False)