import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from PIL import Image
import uuid
import io
# Text-only model setup
DESCRIPTION = """
# GWQ PREV
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "prithivMLmods/GWQ2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
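# Cap the model's sliding-window attention at 4096 tokens.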
model.config.sliding_window = 4096
model.eval()
# Multimodal model setup
MULTIMODAL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
multimodal_model = Qwen2VLForConditionalGeneration.from_pretrained(
    MULTIMODAL_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda").eval()
multimodal_processor = AutoProcessor.from_pretrained(MULTIMODAL_MODEL_ID, trust_remote_code=True)
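# Extensions Pillow recognizes (e.g. ".png", ".jpg"), used to detect image uploads by filename.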
image_extensions = Image.registered_extensions()
def identify_and_save_image(blob_path):
    """Identifies whether the blob is an image and saves it as a temporary file."""
    try:
        with open(blob_path, "rb") as file:
            blob_content = file.read()

        # Verify that the blob is a valid image
        try:
            Image.open(io.BytesIO(blob_content)).verify()
            extension = ".png"  # Default to PNG for saving
            media_type = "image"
        except (IOError, SyntaxError):
            raise ValueError("Unsupported media type. Please upload an image.")

        # Write the blob to a uniquely named temporary file
        filename = f"temp_{uuid.uuid4()}_media{extension}"
        with open(filename, "wb") as f:
            f.write(blob_content)

        return filename, media_type
    except FileNotFoundError:
        raise ValueError(f"The file {blob_path} was not found.")
    except Exception as e:
        raise ValueError(f"An error occurred while processing the file: {e}")
@spaces.GPU()
def generate(
    message: str | dict,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    files: list = None,
) -> Iterator[str]:
    # gr.MultimodalTextbox delivers the message as a dict holding the text and
    # any attached files, so unpack it before branching on the input type.
    if isinstance(message, dict):
        files = message.get("files") or files
        message = message.get("text", "")
    if files and len(files) > 0:
        # Multimodal input (image only)
        media_path = files[0]
        if media_path.endswith(tuple(image_extensions)):
            media_type = "image"
        else:
            try:
                media_path, media_type = identify_and_save_image(media_path)
            except Exception:
                raise ValueError("Unsupported media type. Please upload an image.")
        # Load the image
        image = Image.open(media_path).convert("RGB")
        # Prepare the chat-formatted prompt for the multimodal model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": message},
                ],
            }
        ]
        prompt = multimodal_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Process the text prompt together with the loaded image
        inputs = multimodal_processor(
            text=[prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        ).to("cuda")
        # Stream the output
        streamer = TextIteratorStreamer(
            multimodal_processor, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        # Start the generation in a separate thread
        thread = Thread(target=multimodal_model.generate, kwargs=generation_kwargs)
        thread.start()
        # Stream the output token by token
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer
    else:
        # Text-only input: rebuild the conversation from the chat history, which
        # gr.ChatInterface(type="messages") supplies as role/content dicts.
        conversation = []
        for entry in chat_history:
            if isinstance(entry.get("content"), str):
                conversation.append({"role": entry["role"], "content": entry["content"]})
        conversation.append({"role": "user", "content": message})
        # Apply the chat template
        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
        input_ids = input_ids.to(model.device)
        # Stream the output
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            {"input_ids": input_ids},
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            num_beams=1,
            repetition_penalty=repetition_penalty,
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you briefly explain what the Python programming language is?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
    multimodal=True,
    textbox=gr.MultimodalTextbox(),
)
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)