File size: 1,448 Bytes
9aafde8
9aade84
9aafde8
 
dbd7df0
 
9aafde8
 
 
 
 
 
 
 
 
 
 
9aade84
609a852
 
9aade84
 
 
 
26177f6
9aade84
 
9aafde8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import gradio as gr
import torch

# The preprocessing config and the classifier weights are both loaded from the
# same Hub checkpoint, so it is named once here.
_CHECKPOINT = "microsoft/dit-base-finetuned-rvlcdip"

feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)

def classify_image(image):
    """Classify a document image with the DiT RVL-CDIP classifier.

    Args:
        image: A PIL image (the Gradio input component for this demo is
            created with ``type="pil"``).

    Returns:
        dict mapping every label name in ``model.config.id2label`` to its
        softmax probability. This is the format ``gr.Label`` needs in order
        to display the top-N classes (the demo asks for the top 3).
    """
    # The feature extractor returns a dict-like BatchFeature keyed by
    # "pixel_values", not a bare tensor, so it must be unpacked into keyword
    # arguments instead of being passed positionally to the model.
    inputs = feature_extractor(image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # One image was passed in, so the batch has a single row; take it.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()
    return {
        model.config.id2label[index]: probability
        for index, probability in enumerate(probabilities)
    }

# Input/output components. The legacy ``gr.inputs`` / ``gr.outputs``
# namespaces are deprecated (removed in Gradio 4); the top-level component
# classes are used instead.
image = gr.Image(type="pil")
# classify_image returns label -> probability; show the three most likely.
label = gr.Label(num_top_classes=3)
title = "Document Image Transformer"
description = "Gradio Demo for DiT, the Document Image Transformer pre-trained on IIT-CDIP, a dataset that includes 42 million document images and fine-tuned on RVL-CDIP, a dataset consisting of 400,000 grayscale images in 16 classes, with 25,000 images per class. To use it, simply add your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/microsoft/dit-base-finetuned-rvlcdip' target='_blank'>Huggingface Model</a></p>"
examples = [
    ["coca_cola_advertisement.png"],
]

# ``article`` holds the links the description refers to, so it is passed to
# the interface. Queuing is enabled via ``.queue()`` because ``enable_queue``
# is no longer accepted as an Interface argument.
gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    title=title,
    description=description,
    article=article,
    examples=examples,
).queue().launch(debug=True)