jonathanjordan21's picture
Code change : Load model before prediction for faster inference
f130d9b
raw
history blame
No virus
1.92 kB
import gradio as gr
import re
from PIL import Image
from io import BytesIO
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Pick the GPU when one is available; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor once at import time. It supplies both the tokenizer and
# the image preprocessing used by `predict`.
processor = DonutProcessor.from_pretrained("jonathanjordan21/donut_fine_tuning_food_composition_id")

# Load the model once at import time so each request reuses the same weights.
model = VisionEncoderDecoderModel.from_pretrained("jonathanjordan21/donut_fine_tuning_food_composition_id")

# The inputs are sent to `device` at prediction time, so the weights must live
# there too; otherwise generation fails with a device mismatch on GPU hosts.
model.to(device)
def predict(inp):
    """Run the Donut model on an image and return the parsed result.

    Args:
        inp: PIL image to analyse (the Gradio input is configured with
            ``type="pil"``).

    Returns:
        Whatever ``processor.token2json`` produces for the generated
        sequence.

    Raises:
        ValueError: If ``inp`` is ``None`` (no image was supplied).
    """
    if inp is None:
        # Fail with a clear message instead of an unrelated file-open error.
        raise ValueError("No image provided; please upload an image.")

    image = inp.convert("RGB")
    tokenizer = processor.tokenizer

    # Generation is seeded with the task start token as the decoder prompt.
    task_prompt = "<s_kmpsi>"
    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Send the inputs to wherever the model weights actually are, so the call
    # works regardless of which device the model was placed on.
    target_device = model.device
    outputs = model.generate(
        pixel_values.to(target_device),
        decoder_input_ids=decoder_input_ids.to(target_device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[tokenizer.unk_token_id]],  # never emit <unk>
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Strip end-of-sequence and padding markers from the decoded text.
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    # Remove the first task start token.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
# Build the web UI: one PIL image in, the JSON result of `predict` out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
)

# Start serving the interface.
demo.launch()