eyegazer-demo / app.py
Nick Doiron
dont use my example
877a841
raw
history blame
2.31 kB
import gradio as gr
import os
from peft import PeftModel
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
# Base ViT checkpoint and the LoRA adapter that was fine-tuned on top of it.
model_name = 'google/vit-large-patch16-224'
adapter = 'monsoon-nlp/eyegazer-vit-binary'

# The base checkpoint's processor supplies the normalization statistics and
# the (square) input size reused by both torchvision pipelines below.
image_processor = AutoImageProcessor.from_pretrained(model_name)
_side = image_processor.size["height"]
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Augmenting pipeline: random resized crop + horizontal flip.
# Not used by query(); presumably kept to record the training-time setup.
train_transforms = Compose([
    RandomResizedCrop(_side),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Deterministic pipeline used at inference time: resize the shorter edge,
# center-crop to a square, convert to a tensor, then normalize.
val_transforms = Compose([
    Resize(_side),
    CenterCrop(_side),
    ToTensor(),
    normalize,
])

# Load the base model with a 2-way classification head; ignore_mismatched_sizes
# lets that head be re-created when its shape differs from the checkpoint's.
# NOTE(review): a re-created head starts randomly initialized, so predictions
# are only meaningful if the adapter checkpoint also stores the classifier
# weights (e.g. via modules_to_save) -- confirm against the adapter config.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True,
)
# Wrap the base model with the LoRA adapter weights.
lora_model = PeftModel.from_pretrained(model, adapter)
def query(img):
    """Classify a fundus camera image as unaffected vs. affected.

    Args:
        img: PIL image delivered by the ``gr.Image(type='pil')`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        A human-readable prediction string rendered by the Markdown output.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a clear message
            instead of an internal ``AttributeError``.
    """
    if img is None:
        raise gr.Error("Please upload a fundus camera image.")

    # Deterministic preprocessing, then add a leading batch dimension of 1.
    batch = val_transforms(img.convert("RGB")).unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = lora_model(batch).logits

    vals = logits[0].tolist()
    # This demo treats logit index 0 as "unaffected"; ties fall through to
    # "affected", exactly as the strict comparison always did.
    if vals[0] > vals[1]:
        return "Predicted unaffected"
    return "Predicted affected to some degree"
# Bundled example images, resolved relative to this file so the demo works
# regardless of the current working directory.
_EXAMPLE_DIR = os.path.join(os.path.dirname(__file__), "images")
_EXAMPLE_FILES = [
    "0a09aa7356c0.png",
    "0a4e1a29ffff.png",
    "0c43c79e8cfb.png",
    "0c7e82daf5a0.png",
]

# Single-image-in, Markdown-out Gradio UI around query().
iface = gr.Interface(
    fn=query,
    examples=[os.path.join(_EXAMPLE_DIR, name) for name in _EXAMPLE_FILES],
    inputs=[
        gr.Image(
            image_mode='RGB',
            sources=['upload', 'clipboard'],
            type='pil',
            label='Input Fundus Camera Image',
            show_label=True,
        ),
    ],
    outputs=[
        gr.Markdown(value="", label="Predicted label"),
    ],
    title="ViT retinopathy model",
    description="Diabetic retinopathy model trained on APTOS 2019 dataset; demonstration, not medical advice",
    allow_flagging="never",
)
iface.launch()