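"""Gradio demo: parse the composition section of a food label image into
JSON with a fine-tuned Donut model."""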
import re

import gradio as gr
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Check GPU availability once at startup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and the fine-tuned Donut model once at import time,
# instead of on every prediction call
processor = DonutProcessor.from_pretrained("jonathanjordan21/donut_fine_tuning_food_composition_id")
model = VisionEncoderDecoderModel.from_pretrained("jonathanjordan21/donut_fine_tuning_food_composition_id")
model.to(device)
model.eval()


# Define JSON parser: run the model on an image and decode the output
def get_komposisi(image_path, image=None):
    # Fall back to loading from disk when no PIL image is passed in
    image = Image.open(image_path).convert("RGB") if image is None else image.convert("RGB")

    task_prompt = "<s_kmpsi>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Decode the generated ids, strip special tokens, and remove the
    # first task start token before parsing into JSON
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)


def predict(inp):
    # Generate the structured output for the uploaded image
    return get_komposisi("", inp)
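# A minimal local usage sketch (the file name below is a hypothetical
# placeholder, not part of this Space):
#
#   result = get_komposisi("food_label.jpg")
#   print(result)  # parsed composition fields as a dict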
def upload_file(files):
    # Helper for file uploads (currently unused by the interface below)
    file_paths = [file.name for file in files]
    return file_paths
gr.Interface(fn=predict,
             inputs=gr.Image(type="pil"),
             outputs="json").launch()