# Robust-R1 / app.py
import gradio as gr
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import os
# Model path configuration - can be loaded from environment variable or default path
MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1")
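# Example (assumed usage): point the app at a different checkpoint without editing code,
# e.g. `MODEL_PATH=/path/to/local/checkpoint python app.py` (the path is a placeholder).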
# Global variables to store model and processor
model = None
processor = None
def load_model():
"""Load model and processor"""
global model, processor
if model is None or processor is None:
print(f"Loading model: {MODEL_PATH}")
# Load model
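        # device_map="auto" relies on accelerate to place the weights on the available
        # GPU(s), falling back to CPU; bfloat16 halves memory use versus float32.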
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_PATH,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Load processor
processor = AutoProcessor.from_pretrained(MODEL_PATH)
print("Model loaded successfully!")
return model, processor
def inference(image, question, max_new_tokens=1024, temperature=0.7):
"""Perform inference"""
try:
# Ensure model is loaded
model, processor = load_model()
# Validate multimodal inputs
if image is None:
return "⚠️ Error: Please upload an image. This is a multimodal model that requires both an image and text input."
if not question or question.strip() == "":
return "⚠️ Error: Please enter your question. This is a multimodal model that requires both an image and text input."
# Build multimodal messages (image + text)
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": image, # Image input
},
{"type": "text", "text": question}, # Text input
],
}
]
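        # Per qwen_vl_utils, the "image" entry may be a PIL image (as here, from
        # gr.Image(type="pil")), a local file path, or a URL.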
# Prepare inputs
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
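        # process_vision_info walks the messages and returns the collected image and
        # video inputs in the format the processor expects (video_inputs is None here).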
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
# Move inputs to the device where the model is located
device = next(model.parameters()).device
inputs = inputs.to(device)
# Generate response
generated_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
            do_sample=temperature > 0,
)
# Decode output
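        # generate() returns the prompt tokens followed by the new tokens, so slice off
        # the prompt portion and decode only the model's reply.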
generated_ids_trimmed = [
out_ids[len(in_ids):]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)
return output_text[0]
except Exception as e:
return f"An error occurred: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="Robust-R1", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
## Citation
The following is a BibTeX reference:
"""
)
with gr.Row():
with gr.Column(scale=1):
            # Note: gr.Image does not accept an `info` argument (unlike Textbox/Slider),
            # so the "upload an image to ask questions about" hint stays in the label.
            image_input = gr.Image(
                type="pil",
                label="📸 Upload Image (Required)",
                height=400,
            )
question_input = gr.Textbox(
label="💬 Your Question (Required)",
placeholder="e.g., Describe the content of this image",
lines=3,
info="Enter your question about the uploaded image"
)
with gr.Row():
max_tokens = gr.Slider(
minimum=64,
maximum=2048,
value=512,
step=64,
label="Max Generation Length"
)
temperature = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature"
)
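                # Lower temperature gives more deterministic answers; higher values
                # produce more varied output.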
submit_btn = gr.Button("Submit", variant="primary", size="lg")
clear_btn = gr.Button("Clear", variant="secondary")
with gr.Column(scale=1):
output = gr.Textbox(
label="Model Response",
lines=15,
interactive=False
)
# Examples
gr.Examples(
examples=[
["What is the name of the Garage?\n0. polo\n1. imam\n2. leke\n3. akd\nFirst output the the types of degradations in image briefly in <TYPE> <TYPE_END> tags, and thenoutput what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags, then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END>tags, and then sunmmarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END>tags,provides the user with the answer briefly in<ANSWER> <ANSWER_END>.i.e., <TYPE> degradation type here <TYPE_END>\n<INFLUENCE> influence here<INFLUENCE_END>\n<REASONING> reasoning process here<REASONING_END>\n<CONCLUSION>summary here<CONCLUSION_END>\n<ANSWER>final answer<ANSWER_END>."],
],
inputs=[question_input],
label="Example Questions"
)
# Bind events
submit_btn.click(
fn=inference,
inputs=[image_input, question_input, max_tokens, temperature],
outputs=output
)
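    # The lambda returns one value per output component, resetting the sliders to their
    # defaults along with clearing the image, question, and response.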
clear_btn.click(
fn=lambda: (None, "", 512, 0.7, ""),
outputs=[image_input, question_input, max_tokens, temperature, output]
)
# Show message when page loads
demo.load(
fn=lambda: "Model is loading, please wait...",
outputs=output
)
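    # Note: the model weights are loaded lazily on the first inference call, so the
    # first request will be slow; calling load_model() eagerly here would warm it up.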
if __name__ == "__main__":
# When running in Space, Gradio will automatically handle the port
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)