File size: 2,307 Bytes
3da4879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877a841
3da4879
 
 
 
 
 
ae753cc
3da4879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import os
from peft import PeftModel
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

model_name = 'google/vit-large-patch16-224'
adapter = 'monsoon-nlp/eyegazer-vit-binary'

# Preprocessing configuration (target size, normalisation mean/std) published
# alongside the base checkpoint; every transform below is derived from it.
image_processor = AutoImageProcessor.from_pretrained(model_name)
_side = image_processor.size["height"]

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Training-time augmentation pipeline (random crop + horizontal flip).
# Kept for reference; inference goes through val_transforms instead.
train_transforms = Compose([
    RandomResizedCrop(_side),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Deterministic evaluation pipeline: resize, centre-crop to a square of the
# processor's target size, convert to a tensor, then normalise.
val_transforms = Compose([
    Resize(_side),
    CenterCrop(_side),
    ToTensor(),
    normalize,
])

# Base ViT with a 2-way classification head. ignore_mismatched_sizes lets the
# checkpoint's original head be swapped for one of a different size.
# NOTE(review): presumably the adapter below supplies the trained head
# weights — confirm against the adapter's saved modules.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True,
)

# Attach the LoRA adapter weights from the Hub on top of the base model.
lora_model = PeftModel.from_pretrained(model, adapter)

def query(img):
    """Classify a fundus camera image as unaffected / affected by retinopathy.

    Args:
        img: PIL image supplied by the Gradio input component, or None when
            the user submits without an image.

    Returns:
        A short human-readable prediction string for the Markdown output.
    """
    # Gradio passes None for an empty image input; answer politely instead of
    # crashing with AttributeError on ``None.convert``.
    if img is None:
        return "Please upload a fundus camera image."

    # ToTensor yields (C, H, W); add a leading batch dimension of size 1.
    batch = val_transforms(img.convert("RGB")).unsqueeze(0)

    # Pure inference: don't build an autograd graph for every request.
    with torch.no_grad():
        logits = lora_model(batch).logits[0].tolist()

    # Logit 0 is reported as "unaffected", logit 1 as "affected"; a tie is
    # reported as affected (same as the original comparison).
    if logits[0] > logits[1]:
        return "Predicted unaffected"
    return "Predicted affected to some degree"

# Example images bundled with the app, expected in an ``images`` directory
# next to this file.
_EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "images")

iface = gr.Interface(
    fn=query,
    examples=[
        os.path.join(_EXAMPLES_DIR, filename)
        for filename in (
            "0a09aa7356c0.png",
            "0a4e1a29ffff.png",
            "0c43c79e8cfb.png",
            "0c7e82daf5a0.png",
        )
    ],
    inputs=[
        gr.Image(
            image_mode='RGB',
            sources=['upload', 'clipboard'],
            type='pil',
            label='Input Fundus Camera Image',
            show_label=True,
        ),
    ],
    outputs=[
        gr.Markdown(value="", label="Predicted label"),
    ],
    title="ViT retinopathy model",
    description="Diabetic retinopathy model trained on APTOS 2019 dataset; demonstration, not medical advice",
    allow_flagging="never",
)
# Start the web server at import time (the app is run as a script).
iface.launch()