5258-vikram commited on
Commit
8157e89
·
verified ·
1 Parent(s): 78eeb48

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +20 -0
  2. main.py +163 -0
  3. model_cnn_final.pth +3 -0
  4. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.8
3
+
4
+ # Set the working directory to /flask_app
5
+ WORKDIR /flask_app
6
+
7
+ # Copy the current directory contents into the container at /flask_app
8
+ COPY . /flask_app
9
+
10
+ # Install any needed packages specified in requirements.txt
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ # Make port 5000 available to the world outside this container
14
+ EXPOSE 5000
15
+
16
+ # Define environment variable
17
+ ENV NAME World
18
+
19
+ # Run app.py when the container launches
20
+ CMD ["python3", "main.py"]
main.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask,request, send_file
2
+ import os
3
+ import io
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from torchvision.datasets import ImageFolder
9
+ import torchvision.transforms as transforms
10
+ from PIL import Image
11
+ import matplotlib.pyplot as plt
12
+ from datetime import datetime
13
+
14
+ app = Flask(__name__)
15
+
16
+ @app.route('/', methods=['GET'])
17
+ def dummy_get():
18
+ return "Welcome to Flask App"
19
+
20
+ @app.route('/upload', methods=['POST'])
21
+ def upload_file():
22
+ class CNN_Stage3(nn.Module):
23
+ def __init__(self, in_channels, out_channels):
24
+ super(CNN_Stage3, self).__init__()
25
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, dilation=2, padding=1)
26
+ self.relu = nn.ReLU()
27
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=1)
28
+
29
+ def forward(self, x):
30
+ x = self.conv1(x)
31
+ x = self.relu(x)
32
+ x = self.pool(x)
33
+ x = self.relu(x)
34
+ return x
35
+
36
+ class CNN_Stage1(nn.Module):
37
+ def __init__(self, in_channels, out_channels):
38
+ super(CNN_Stage1, self).__init__()
39
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=1)
40
+ self.relu = nn.ReLU()
41
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=1)
42
+
43
+ def forward(self, x):
44
+ x = self.conv1(x)
45
+ x = self.relu(x)
46
+ x = self.pool(x)
47
+ x = self.relu(x)
48
+ return x
49
+
50
+ class CNN(nn.Module):
51
+ def __init__(self, num_classes):
52
+ super(CNN, self).__init__()
53
+ self.cnn_stage_1 = CNN_Stage1(3, 6)
54
+ self.cnn_stage_2 = CNN_Stage1(6, 12)
55
+ self.cnn_stage_3 = CNN_Stage3(12, 24)
56
+ self.cnn_stage_4 = CNN_Stage1(24, 48)
57
+ self.cnn_stage_5 = CNN_Stage1(48, 96)
58
+ self.fc1 = nn.Linear(96 * 3 * 3, 64)
59
+ self.fc2 = nn.Linear(64, num_classes)
60
+ self.relu = nn.ReLU()
61
+
62
+ def forward(self, x):
63
+ x = self.cnn_stage_1(x)
64
+ x = self.cnn_stage_2(x)
65
+ x = self.cnn_stage_3(x)
66
+ x = self.cnn_stage_4(x)
67
+ x = self.cnn_stage_5(x)
68
+ x = x.view(x.size(0), -1)
69
+ x = self.fc1(x)
70
+ x = self.relu(x)
71
+ x = self.fc2(x)
72
+ return x
73
+
74
+ class CustomDataset(Dataset):
75
+ def __init__(self, root_dir, transform=None):
76
+ self.dataset = ImageFolder(root_dir, transform=transform)
77
+ self.classes = self.dataset.classes
78
+
79
+ def __len__(self):
80
+ return len(self.dataset)
81
+
82
+ def __getitem__(self, idx):
83
+ image, label = self.dataset[idx]
84
+ return image, label
85
+
86
+ # Example usage:
87
+ dataset_path = 'aug_data'
88
+
89
+
90
+ transform = transforms.Compose([
91
+ transforms.Resize((22, 22)),
92
+ transforms.ToTensor(),
93
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
94
+ ])
95
+
96
+ custom_dataset = CustomDataset(root_dir=dataset_path, transform=transform)
97
+
98
+ num_classes = len(custom_dataset.classes)
99
+ batch_size = 32
100
+
101
+ data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)
102
+
103
+ model = CNN(num_classes)
104
+
105
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
106
+
107
+ # Load the model
108
+
109
+ checkpoint = torch.load("model_cnn_final.pth")
110
+ model.load_state_dict(checkpoint['model_state_dict'])
111
+
112
+ # Assuming optimizer was saved in the checkpoint
113
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
114
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
115
+ epoch = checkpoint['epoch']
116
+ loss = checkpoint['loss']
117
+
118
+
119
+ # Print model's parameter names
120
+ for name, param in model.named_parameters():
121
+ print(name)
122
+
123
+ if 'file' not in request.files:
124
+ return 'No file part'
125
+
126
+ file = request.files['file']
127
+
128
+ # Generate a unique filename using a timestamp
129
+ timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
130
+ unique_filename = f"{timestamp}_{file.filename}"
131
+
132
+
133
+ file.save(f'uploads/{unique_filename}')
134
+
135
+
136
+ input_image = Image.open(f'uploads/{unique_filename}')
137
+ input_tensor = transform(input_image)
138
+ input_batch = input_tensor.unsqueeze(0)
139
+
140
+ # Use the loaded model to make predictions
141
+ with torch.no_grad():
142
+ output = model(input_batch)
143
+
144
+
145
+ # If the user does not select a file, the browser submits an empty file without a filename
146
+ if file.filename == '':
147
+ return 'No selected file'
148
+ else:
149
+ # Interpret the predictions
150
+ class_names = ['cancer', 'no- cancer']
151
+ _, predicted_class = torch.max(output, 1)
152
+ predicted_label = class_names[predicted_class.item()]
153
+ print(f'The image is classified as: {predicted_label}')
154
+
155
+ plt.imshow(input_image)
156
+ # print(f'The image is classified as: {predicted_label}')
157
+ return f'The image is classified as: {predicted_label}'
158
+
159
+
160
+
161
+ if __name__ == "__main__":
162
+ app.run(host='0.0.0.0',debug=True, port=5000)
163
+
model_cnn_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea3e34197acc62b9ff2504483863b8cb6ace980cf35ec3bf8654dabe6fbbcfe2
3
+ size 2526350
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ flask
2
+ torch
3
+ torchvision
4
+ matplotlib
5
+ datetime
6
+ pillow
7
+ uvicorn