erika_cats commited on
Commit
eb327b2
1 Parent(s): 244f45d

feat: Add Gradio interface and integrate trained CNN model

Browse files
Files changed (2) hide show
  1. app.py +59 -4
  2. pcos_cnn_model.pth +3 -0
app.py CHANGED
@@ -1,7 +1,62 @@
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
  import gradio as gr
5
 
6
+ # Define the CNN model architecture
7
+ class CNNModel(torch.nn.Module):
8
+ def __init__(self):
9
+ super(CNNModel, self).__init__()
10
+ self.conv1 = torch.nn.Conv2d(3, 12, kernel_size=5, padding=2)
11
+ self.pool = torch.nn.MaxPool2d(2, 2)
12
+ self.conv2 = torch.nn.Conv2d(12, 8, kernel_size=5, padding=2)
13
+ self.conv3 = torch.nn.Conv2d(8, 4, kernel_size=5, padding=2)
14
+ self._to_linear = None
15
+ self.convs(torch.randn(1, 3, 224, 224))
16
+ self.fc1 = torch.nn.Linear(self._to_linear, 1)
17
 
18
+ def convs(self, x):
19
+ x = self.pool(F.relu(self.conv1(x)))
20
+ x = self.pool(F.relu(self.conv2(x)))
21
+ x = self.pool(F.relu(self.conv3(x)))
22
+ if self._to_linear is None:
23
+ self._to_linear = x.view(-1).size(0)
24
+ return x
25
+
26
+ def forward(self, x):
27
+ x = self.convs(x)
28
+ x = x.view(-1, self._to_linear)
29
+ x = self.fc1(x)
30
+ return x
31
+
32
+ # Load the model
33
+ model = CNNModel()
34
+ model.load_state_dict(torch.load('pcos_cnn_model.pth', map_location=torch.device('cpu')))
35
+ model.eval()
36
+
37
+ # Define the image transforms
38
+ transform = transforms.Compose([
39
+ transforms.Resize((224, 224)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
42
+ ])
43
+
44
+ # Define the prediction function
45
+ def predict(image):
46
+ image = transform(image).unsqueeze(0)
47
+ output = model(image)
48
+ prediction = torch.sigmoid(output).item()
49
+ return "Infected" if prediction == 0 else "Not Infected"
50
+
51
+ # Define the Gradio interface
52
+ interface = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.inputs.Image(type="pil"),
55
+ outputs="text",
56
+ title="PCOS Diagnosis",
57
+ description="Upload an ultrasound image to predict if it is infected or not."
58
+ )
59
+
60
+ # Launch the Gradio app
61
+ if __name__ == "__main__":
62
+ interface.launch()
pcos_cnn_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00c4fc9da38e203d0155b672ec7e82c961fa1ccc64a73b1493efa4c73f9f6b1b
3
+ size 32276