# Dental_ADA / app.py
# Hugging Face Space by Laskari-Naveen (commit 3909336, "Update app.py", verified)
import streamlit as st
import torch
from transformers import AutoProcessor, VisionEncoderDecoderModel
import requests
from PIL import Image
import pandas as pd
# App entry point: extract structured fields from American Dental
# Association (ADA) claim-form images with a Donut-style model.
st.write("American Dental Association (ADA) claims")

# Run on GPU when available; the model and all inputs are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model_and_processor():
    """Load the processor/model once and reuse them across Streamlit reruns.

    Without caching, every widget interaction re-runs the script and
    re-loads the checkpoint from disk/network.
    """
    proc = AutoProcessor.from_pretrained("Laskari-Naveen/ADA_005")
    mdl = VisionEncoderDecoderModel.from_pretrained("Laskari-Naveen/ADA_005").to(device)
    return proc, mdl


processor, model = _load_model_and_processor()
def run_prediction(image, model, processor):
    """Run encoder-decoder OCR inference on *image* and parse it to JSON.

    Parameters
    ----------
    image : PIL.Image.Image
        RGB image of the claim form.
    model : VisionEncoderDecoderModel
        Donut-style model already moved to the module-level ``device``.
    processor : AutoProcessor
        Matching processor providing the tokenizer and ``token2json``.

    Returns
    -------
    tuple
        ``(parsed, generated)`` where ``parsed`` is the decoded JSON-like
        dict and ``generated`` is the raw ``model.generate`` output.
    """
    pixel_values = processor(image, return_tensors="pt").pixel_values
    prompt_ids = processor.tokenizer(
        "<s>", add_special_tokens=False, return_tensors="pt"
    ).input_ids

    generated = model.generate(
        pixel_values.to(device),
        decoder_input_ids=prompt_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=2,
        # Never emit the unknown token.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    decoded = processor.batch_decode(generated.sequences)[0]
    return processor.token2json(decoded), generated
def split_and_expand(row):
    """Expand a semicolon-delimited ``Value`` cell into one row per item.

    For the special key ``"33_Missing_Teeth"`` only the first item is
    kept (as a scalar, broadcast to a single row); for every other key
    the key is repeated once per item in ``Value``.

    Parameters
    ----------
    row : mapping
        Must expose ``row['Key']`` (str) and ``row['Value']`` (str).

    Returns
    -------
    pandas.DataFrame
        Columns ``Key`` and ``Value``, one expanded row per item.
    """
    items = row['Value'].split(';')
    if row['Key'] == "33_Missing_Teeth":
        keys, values = [row['Key']], items[0]
    else:
        keys, values = [row['Key']] * len(items), items
    return pd.DataFrame({'Key': keys, 'Value': values})
# Upload a claim-form image, run inference, and show the extracted fields.
uploaded_file = st.file_uploader("Choose a file")
if uploaded_file is not None:
    # NOTE: the original code did `content = uploaded_file.read()` and never
    # used `content`; the read also left the buffer at EOF so Image.open
    # could see an empty stream. The dead read is removed and the buffer is
    # rewound before decoding.
    st.image(uploaded_file)
    uploaded_file.seek(0)
    image = Image.open(uploaded_file).convert("RGB")
    prediction, _ = run_prediction(image, model, processor)
    st.dataframe(prediction, width=600)