import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from threading import Thread
from PIL import Image
import torch
# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and processor
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    # bfloat16 halves memory on GPU; fall back to float32 on CPU
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
).to(device)
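# Minimal non-streaming sanity check (a sketch, not called by the app). It assumes a local
# image path of your choosing and shows the basic processor -> generate -> batch_decode
# round trip that the streaming handler below builds on.
def _quick_check(image_path: str, question: str) -> str:
    image = Image.open(image_path).convert("RGB")
    messages = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    check_inputs = processor(text=prompt, images=[image], return_tensors="pt").to(device)
    output_ids = model.generate(**check_inputs, max_new_tokens=128)
    # Note: the decoded string includes the prompt as well as the generated answer.
    return processor.batch_decode(output_ids, skip_special_tokens=True)[0]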
# Inference function
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    text = input_dict["text"]
    if input_dict["files"]:
        images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
    else:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in range(len(images))] + [
                {"type": "text", "text": text}
            ],
        }
    ]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[images], return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
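    # `inputs` is expected to hold input_ids, attention_mask and the image tensors
    # (e.g. pixel_values); the exact keys depend on the SmolVLM processor version.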
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    assert decoding_strategy in ["Greedy", "Top P Sampling"]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p
    generation_args.update(inputs)
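    # Streaming pattern: model.generate runs in a background thread and pushes decoded
    # tokens into the TextIteratorStreamer queue, while this generator drains the queue
    # and yields the growing buffer so the Gradio chat updates incrementally.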
    # Stream generation
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_args["streamer"] = streamer  # keep the decoding settings built above
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
# Gradio interface
demo = gr.ChatInterface(
    fn=model_inference,
    title="Geoscience AI Interpreter",
    description=(
        "This app interprets thin sections, seismic images, and other geoscience imagery. "
        "Upload an image together with a text query. It works best in single-turn "
        "conversations, so clear the chat after each turn."
    ),
    textbox=gr.MultimodalTextbox(
        label="Query Input", file_types=["image"], file_count="multiple"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="Decoding strategy",
            info="Greedy always picks the most likely next token; Top P samples from the "
            "smallest set of tokens whose cumulative probability exceeds Top P.",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values produce more diverse outputs.",
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 is equivalent to no penalty.",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values are equivalent to sampling more low-probability tokens.",
        ),
    ],
    cache_examples=False,
)
# Launch Gradio app
demo.launch(debug=True)
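# To run locally (a sketch; the exact dependency set is an assumption, not pinned here):
#   pip install gradio transformers torch pillow
#   python app.py
# On a Hugging Face Space, app.py is launched automatically and the same dependencies
# would go in requirements.txt.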