PedroMartelleto commited on
Commit
ad937a3
1 Parent(s): a20413a

Added examples

Browse files
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from torchvision.models import resnet50, ResNet50_Weights
3
+ from torchvision import transforms
4
+ import torch.nn as nn
5
+ import torch
6
+
7
+ @staticmethod
8
+ def create_model_from_checkpoint():
9
+ # Loads a model from a checkpoint
10
+ model = resnet50()
11
+ model.fc = nn.Linear(model.fc.in_features, 3)
12
+ model.load_state_dict(torch.load("best_model"))
13
+ model.eval()
14
+ return model
15
+
16
+ def prep_image(img):
17
+ transform = transforms.Compose([
18
+ transforms.Resize(256),
19
+ transforms.CenterCrop(224),
20
+ transforms.ToTensor()
21
+ ])
22
+
23
+ transform_normalize = transforms.Normalize(
24
+ mean=[0.485, 0.456, 0.406],
25
+ std=[0.229, 0.224, 0.225]
26
+ )
27
+
28
+ transformed_img = transform(img)
29
+
30
+ input = transform_normalize(transformed_img)
31
+ input = input.unsqueeze(0)
32
+ return input
33
+
34
+ model = create_model_from_checkpoint()
35
+ labels = [ "benign", "malignant", "normal" ]
36
+
37
+ def predict(img):
38
+ input = prep_image(img)
39
+ with torch.no_grad():
40
+ prediction = torch.nn.functional.softmax(model(input)[0], dim=0)
41
+ confidences = {labels[i]: float(prediction[i]) for i in range(3)}
42
+ return confidences
43
+
44
+ ui = gr.Interface(fn=predict,
45
+ inputs=gr.Image(type="pil"),
46
+ outputs=gr.Label(num_top_classes=3),
47
+ examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
48
+ ui.launch(share=True)
benign (243).png ADDED
benign (52).png ADDED
malignant (127).png ADDED
malignant (201).png ADDED
normal (101).png ADDED
normal (81).png ADDED