File size: 1,374 Bytes
a074951
 
 
b10bfcc
a074951
 
 
 
 
 
b10bfcc
a074951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b10bfcc
a074951
b10bfcc
 
 
 
 
 
 
 
a074951
b10bfcc
a074951
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gradio as gr
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
import torch


# Image preprocessor configuration loaded from the same checkpoint as the model,
# so inputs are prepared the way the fine-tuned weights expect.
processor = ViTImageProcessor.from_pretrained(
    pretrained_model_name_or_path='Rageshhf/fine-tuned-model',
)

# Class names, indexed in the same order as the model's output probabilities
# (predict() maps probability i to labels[i]).
labels = ['Mild_Demented', 'Moderate_Demented', 'Non_Demented', 'Very_Mild_Demented']
# Both lookup tables are derived from `labels` so the three definitions can
# never drift out of sync.
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in enumerate(labels)}

# Load the fine-tuned ViT classifier and attach the label maps to its config.
# num_labels is taken from the label map instead of being hard-coded, so the
# head size and the label names cannot disagree.
# NOTE(review): ignore_mismatched_sizes=True makes from_pretrained re-initialize
# (with only a logged warning) any checkpoint weights whose shapes don't match,
# e.g. the classification head if the checkpoint was saved with a different
# number of labels -- that would produce untrained predictions. Confirm the
# checkpoint has a 4-way head and consider dropping the flag so a mismatch
# fails loudly.
model = ViTForImageClassification.from_pretrained(
    'Rageshhf/fine-tuned-model',
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True)

# Display text passed to gr.Interface as its title and description.
title, description = (
    "Medi- classifier",
    "Trained to classify disease based on image data.",
)



def predict(image):
    """Classify a single image and return per-class probabilities.

    Args:
        image: The value delivered by the Gradio ``"image"`` input (presumably
            a numpy array with Gradio's default settings -- TODO confirm);
            anything ``processor`` accepts will work.

    Returns:
        dict mapping each class name in ``labels`` to its softmax probability,
        which is the format ``gr.Label`` displays.
    """
    inputs = processor(images=image, return_tensors="pt")

    # Pure inference: don't build an autograd graph (saves memory and time on
    # every request).
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension (last axis); row 0 is the one image in
    # the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    return {labels[i]: prob for i, prob in enumerate(probabilities)}

# Build the Gradio UI. `demo` holds the Interface object itself (not the
# return value of .launch()), so it can be imported and launched/reused.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    # Only the 3 most probable classes are shown.
    outputs=gr.Label(num_top_classes=3),
    title=title,
    examples=["examples/image_1.png", "examples/image_2.png", "examples/image_3.png"],
    description=description,
)

if __name__ == "__main__":
    # Pass debug=True to launch() to see tracebacks in the console while developing.
    demo.launch()