Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,29 @@ model_name = "runaksh/chest_xray_pneumonia_detection"
|
|
9 |
model = ViTForImageClassification.from_pretrained(model_name)
|
10 |
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
11 |
|
12 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# Convert the PIL Image to a format compatible with the feature extractor
|
14 |
image = np.array(image)
|
15 |
# Preprocess the image and prepare it for the model
|
@@ -47,8 +69,8 @@ def make_block(dem):
|
|
47 |
in_prompt_2 = gr.Image()
|
48 |
out_response_2 = gr.Label()
|
49 |
b2 = gr.Button("Enter")
|
50 |
-
b1.click(
|
51 |
-
b2.click(
|
52 |
|
53 |
if __name__ == '__main__':
|
54 |
|
|
|
9 |
model = ViTForImageClassification.from_pretrained(model_name)
|
10 |
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
11 |
|
12 |
+
def classify_image_pneumonia(image):
|
13 |
+
# Convert the PIL Image to a format compatible with the feature extractor
|
14 |
+
image = np.array(image)
|
15 |
+
# Preprocess the image and prepare it for the model
|
16 |
+
inputs = feature_extractor(images=image, return_tensors="pt")
|
17 |
+
# Make prediction
|
18 |
+
with torch.no_grad():
|
19 |
+
outputs = model(**inputs)
|
20 |
+
logits = outputs.logits
|
21 |
+
# Retrieve the highest probability class label index
|
22 |
+
predicted_class_idx = logits.argmax(-1).item()
|
23 |
+
# Define a manual mapping of label indices to human-readable labels
|
24 |
+
index_to_label = {
|
25 |
+
0: "NORMAL",
|
26 |
+
1: "PNEUMONIA"
|
27 |
+
}
|
28 |
+
|
29 |
+
# Convert the index to the model's class label
|
30 |
+
label = index_to_label.get(predicted_class_idx, "Unknown Label")
|
31 |
+
|
32 |
+
return label
|
33 |
+
|
34 |
+
def classify_image_tuberculosis(image):
|
35 |
# Convert the PIL Image to a format compatible with the feature extractor
|
36 |
image = np.array(image)
|
37 |
# Preprocess the image and prepare it for the model
|
|
|
69 |
in_prompt_2 = gr.Image()
|
70 |
out_response_2 = gr.Label()
|
71 |
b2 = gr.Button("Enter")
|
72 |
+
b1.click(classify_image_pneumonia, inputs=in_prompt_1, outputs=out_response_1)
|
73 |
+
b2.click(classify_image_tuberculosis, inputs=in_prompt_2, outputs=out_response_2)
|
74 |
|
75 |
if __name__ == '__main__':
|
76 |
|