# streamlit_qwen / app.py
# Commit 7901fac by lukiod: "Add application file" (2.39 kB).
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from PIL import Image
from byaldi import RAGMultiModalModel
from qwen_vl_utils import process_vision_info
# Model and processor names.
# Byaldi multimodal retriever checkpoint (loaded in load_models).
RAG_MODEL = "vidore/colpali"
# Qwen2-VL checkpoint used for answer generation.
QWN_MODEL = "Qwen/Qwen2-VL-7B-Instruct"
# The processor/tokenizer must come from the same checkpoint as the model
# weights; this previously pointed at the 2B repo while 7B weights were loaded.
QWN_PROCESSOR = QWN_MODEL
@st.cache_resource
def load_models():
    """Load the retriever, the Qwen2-VL model and its processor/tokenizer.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    objects across script reruns instead of reloading the checkpoints.

    Returns:
        tuple: ``(RAG, model, processor, tokenizer)``.
            ``RAG`` is the byaldi ``RAGMultiModalModel`` built from ``RAG_MODEL``.
            ``model`` is the Qwen2-VL model from ``QWN_MODEL`` in bfloat16,
            moved to the GPU and switched to eval mode.
            ``processor`` and ``tokenizer`` are loaded from ``QWN_PROCESSOR``.

    NOTE(review): ``RAG`` is returned but not used by ``document_rag`` in this
    file; loading it costs memory and startup time.
    """
    # Qwen2-VL is a vision-language model whose config is not registered with
    # AutoModelForCausalLM, so it has to be loaded through its dedicated class.
    # The import is local so the module-level import list stays unchanged.
    from transformers import Qwen2VLForConditionalGeneration

    RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL)
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        QWN_MODEL,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).cuda().eval()
    processor = AutoProcessor.from_pretrained(QWN_PROCESSOR, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(QWN_PROCESSOR, trust_remote_code=True)
    return RAG, model, processor, tokenizer
# Executed at module level, i.e. on every Streamlit script run; the
# @st.cache_resource decorator on load_models lets later runs reuse the
# cached objects instead of reloading the checkpoints.
RAG, model, processor, tokenizer = load_models()
def document_rag(text_query, image, *, max_new_tokens=50):
    """Answer ``text_query`` about ``image`` with the Qwen2-VL model.

    Builds a single-turn chat containing the image and the question and
    renders it with the tokenizer's chat template. It then runs generation
    and decodes only the newly generated tokens.

    NOTE(review): despite the name, no retrieval happens here. The byaldi
    ``RAG`` model returned by ``load_models`` is never consulted.

    Args:
        text_query: The user's question as plain text.
        image: The image to ask about. It is passed through to
            ``process_vision_info``; the UI in this file passes a PIL image.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 50, the previously hard-coded value; increase it
            for longer answers.

    Returns:
        str: The decoded answer with special tokens stripped.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_query},
            ],
        }
    ]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's placement instead of hard-coding "cuda", so the
    # inputs always land on the same device as the weights.
    inputs = inputs.to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; drop the prompt tokens so only
    # the model's answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = tokenizer.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]
# ---- Streamlit page --------------------------------------------------------
# The whole script reruns on each interaction; widgets are created in the same
# order every time: title, uploader, text box, then (conditionally) the button.
st.title("Document Processor")

uploaded_file = st.file_uploader(
    "Choose an image file",
    type=["jpg", "jpeg", "png"],
)
text_query = st.text_input("Enter your text query")

# The "Process Document" button is only offered once an image has been
# uploaded and a non-empty query typed.
inputs_ready = bool(text_query) and uploaded_file is not None
if inputs_ready:
    image = Image.open(uploaded_file)
    clicked = st.button("Process Document")
    if clicked:
        with st.spinner("Processing..."):
            result = document_rag(text_query, image)
        st.success("Processing complete!")
        st.write("Result:", result)