Rageshhf commited on
Commit
a074951
1 Parent(s): efc49c4

initial commit

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from transformers import ViTImageProcessor, ViTForImageClassification
4
+
5
+
6
+
7
+ processor = ViTImageProcessor.from_pretrained('Rageshhf/fine-tuned-model')
8
+
9
+ id2label = {0: 'Mild_Demented', 1: 'Moderate_Demented', 2: 'Non_Demented', 3: 'Very_Mild_Demented'}
10
+ label2id = {'Mild_Demented': 0, 'Moderate_Demented': 1, 'Non_Demented': 2, 'Very_Mild_Demented': 3}
11
+
12
+
13
+ model = ViTForImageClassification.from_pretrained(
14
+ 'Rageshhf/fine-tuned-model',
15
+ num_labels=4,
16
+ id2label=id2label,
17
+ label2id=label2id,
18
+ ignore_mismatched_sizes=True)
19
+
20
+ title = "Medi- classifier"
21
+ description = """Trained to classify disease based on image data."""
22
+
23
+
24
+
25
+ def predict(image):
26
+
27
+ inputs = processor(images=image, return_tensors="pt")
28
+ outputs = model(**inputs)
29
+ logits = outputs.logits
30
+ # model predicts one of the 1000 ImageNet classes
31
+ predicted_class_idx = logits.argmax(-1).item()
32
+ return(model.config.id2label[predicted_class_idx])
33
+
34
+ demo = gr.Interface(fn=predict, inputs="image", outputs="text", title=title,
35
+ description=description,).launch()
36
+
37
+ # demo.launch(debug=True)