# jonathanjordan21's picture
# Create app.py
# 1909f20
# raw
# history blame
# No virus
# 2.03 kB
import gradio as gr
import re
from PIL import Image
from io import BytesIO
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Hub identifier shared by the processor and the model checkpoint.
_DONUT_CHECKPOINT = "jonathanjordan21/donut_fine_tuning_food_composition_id"

# Lazily filled (processor, model, device) triple so the checkpoint is loaded
# once per process instead of on every request.
_DONUT_STATE = None


def _load_donut():
    """Return the cached ``(processor, model, device)`` triple, loading it on first use."""
    global _DONUT_STATE
    if _DONUT_STATE is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        processor = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
        model = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
        # The inputs are moved to `device` before generation, so the model
        # has to live on the same device.
        model.to(device)
        _DONUT_STATE = (processor, model, device)
    return _DONUT_STATE


def get_komposisi(image_path, image=None):
    """Run the Donut model on one image and return the parsed result.

    Args:
        image_path: Path of the image file to read. Only used when ``image``
            is ``None``.
        image: Optional already-opened PIL image; takes precedence over
            ``image_path``.

    Returns:
        Whatever ``processor.token2json`` produces for the decoded sequence.
    """
    processor, model, device = _load_donut()

    if image is None:
        # Close the file handle as soon as the pixels have been copied.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")
    else:
        image = image.convert("RGB")

    # Task start token used to prime the decoder.
    task_prompt = "<s_kmpsi>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "")
    sequence = sequence.replace(processor.tokenizer.pad_token, "")
    # Remove the first task start token.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)


def predict(inp):
    """Gradio callback: parse the uploaded image with the Donut model.

    Args:
        inp: The uploaded image. The interface is configured with
            ``gr.Image(type="pil")``, so this is normally a PIL image; raw
            encoded image bytes are accepted as well.

    Returns:
        The JSON-serialisable structure returned by ``get_komposisi``.

    Raises:
        ValueError: If no image was provided.
    """
    if inp is None:
        raise ValueError("No image was provided.")
    if isinstance(inp, (bytes, bytearray)):
        inp = Image.open(BytesIO(inp))
    return get_komposisi("", inp)
def upload_file(files):
    """Return the ``name`` attribute of every item in ``files``, in order."""
    return [uploaded.name for uploaded in files]
# Build the web UI: one image input (handed to `predict` as a PIL image) and a
# JSON output, then start the server.
# NOTE(review): the example files are resolved by Gradio at startup; confirm
# that "lion.jpg" and "cheetah.jpg" actually ship alongside this script.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
    examples=["lion.jpg", "cheetah.jpg"],
)
demo.launch()