nanoLLaVA / app.py
qnguyen3's picture
Update app.py
58c422e verified
raw history blame
No virus
5.81 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
from modeling_llava_qwen2 import LlavaQwen2ForCausalLM
from threading import Thread
import re
import time
from PIL import Image
import torch
import spaces
import subprocess
import os
import sys

# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD is *added* to
# the inherited environment: passing a bare dict as `env` would replace the
# whole environment (PATH, HOME, ...) of the child process. Invoking pip via
# the running interpreter ensures the package lands in this interpreter's
# environment. The call stays best-effort (no check=True), as in the original.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

# Every tensor created without an explicit device is placed on the GPU.
torch.set_default_device('cuda')

# Tokenizer and model are loaded once at import time and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(
    'qnguyen3/nanoLLaVA',
    trust_remote_code=True,
)

model = LlavaQwen2ForCausalLM.from_pretrained(
    'qnguyen3/nanoLLaVA',
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.to("cuda:0")
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once a stop keyword has been generated.

    A keyword counts as generated when either its token ids equal the tail
    of the output ids, or its text occurs in the decoded tail of the output.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """Tokenize the stop keywords and remember the prompt length.

        Args:
            keywords: Iterable of stop strings.
            tokenizer: Tokenizer used both to encode the keywords and to
                decode generated ids.
            input_ids: Prompt ids; only ``input_ids.shape[1]`` is read.
        """
        self.tokenizer = tokenizer
        self.keywords = keywords
        self.start_len = input_ids.shape[1]
        self.keyword_ids = []
        self.max_keyword_len = 0
        for word in keywords:
            token_ids = tokenizer(word).input_ids
            # Drop a leading BOS token so only the keyword's own ids are kept.
            starts_with_bos = len(token_ids) > 1 and token_ids[0] == tokenizer.bos_token_id
            if starts_with_bos:
                token_ids = token_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(token_ids))
            self.keyword_ids.append(torch.tensor(token_ids))

    @spaces.GPU
    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the single sequence in ``output_ids`` hit a keyword.

        Only row 0 of ``output_ids`` is inspected for the id-level match.
        """
        # Keep the keyword tensors on the same device as the generated ids.
        target_device = output_ids.device
        self.keyword_ids = [ids.to(target_device) for ids in self.keyword_ids]

        # Fast path: the last tokens are exactly one of the keywords.
        if any(torch.equal(output_ids[0, -ids.shape[0]:], ids) for ids in self.keyword_ids):
            return True

        # Slow path: decode the most recent tokens and search the text.
        window = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        decoded_tail = self.tokenizer.batch_decode(output_ids[:, -window:], skip_special_tokens=True)[0]
        return any(word in decoded_tail for word in self.keywords)

    @spaces.GPU
    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every sequence in the batch hit a keyword."""
        # Evaluate every row (no short-circuit) before combining the verdicts.
        row_verdicts = [
            self.call_for_batch(output_ids[row].unsqueeze(0), scores)
            for row in range(output_ids.shape[0])
        ]
        return all(row_verdicts)
@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's answer to a multimodal chat message.

    Args:
        message: Dict with a ``"text"`` entry (the user's prompt) and a
            ``"files"`` list; each file entry is indexed by ``"path"``.
        history: Sequence of ``(user, assistant)`` pairs from earlier turns.
            A turn whose ``user`` element is a tuple is treated as a file
            upload, with the file path as the tuple's first item.

    Yields:
        The response text accumulated so far, growing with every chunk the
        streamer produces.

    Raises:
        gr.Error: If neither the current message nor the history provides
            an image.
    """
    # Prefer an image attached to the current message; otherwise fall back to
    # the most recent image uploaded earlier in the conversation.
    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # Rebuild the conversation from the text turns only; file-upload turns
    # carry no text and must not reach the chat template.
    messages = []
    for human, assistant in history:
        if isinstance(human, tuple):
            continue
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message['text']})

    # The image placeholder goes in front of the first user turn.
    messages[0]["content"] = f"<image>\n{messages[0]['content']}"

    image = Image.open(image).convert("RGB")

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True)

    # Split only at the first placeholder (the one inserted above) so that a
    # literal "<image>" typed by the user cannot truncate the prompt.
    before_image, after_image = (
        tokenizer(chunk).input_ids for chunk in text.split('<image>', 1)
    )
    # NOTE(review): -200 is presumably the image-token placeholder id expected
    # by the model code — confirm against modeling_llava_qwen2.
    input_ids = torch.tensor(
        before_image + [-200] + after_image, dtype=torch.long
    ).unsqueeze(0).to("cuda:0")

    stop_str = '<|im_end|>'
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    image_tensor = model.process_images([image], model.config).to("cuda:0")
    generation_kwargs = dict(
        input_ids=input_ids,
        images=image_tensor,
        streamer=streamer,
        max_new_tokens=100,
        stopping_criteria=[stopping_criteria],
    )

    # Generation runs in a background thread while this generator drains the
    # streamer and forwards partial text to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # The streamer is created with skip_prompt=True, so the buffer holds only
    # generated text and must not have a prompt-length prefix sliced off.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.04)
        yield buffer

    thread.join()
# Multimodal chat UI backed by bot_streaming. Title and description name the
# model this app actually loads (qnguyen3/nanoLLaVA).
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="nanoLLaVA",
    examples=[
        {"text": "What is on the flower?", "files": ["./bee.jpg"]},
        {"text": "How to make this pastry?", "files": ["./baklava.png"]},
    ],
    description=(
        "Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA) in this demo. "
        "Upload an image and start chatting about it, or simply try one of the "
        "examples below. If you don't upload an image, you will receive an error."
    ),
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)