Krypton / app.py
import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import threading
import spaces
import accelerate  # required by low_cpu_mem_usage=True in from_pretrained
import time
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton 🕋</h1>
<p>This demo uses the open-source model <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a>.</p>
</div>
'''
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

# Load the model in half precision; low_cpu_mem_usage avoids materialising a full fp32 copy in RAM.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
model.to("cuda:0")  # inputs are moved to cuda:0 in krypton(), so the model must live there too

processor = AutoProcessor.from_pretrained(model_id)

# Make sure generation stops at the tokenizer's EOS token.
model.generation_config.eos_token_id = processor.tokenizer.eos_token_id

@spaces.GPU(duration=120)
def krypton(input, history):
    # Use an image from the current message if one was uploaded; otherwise
    # fall back to the most recent image found in the chat history.
    if input["files"]:
        image = input["files"][-1]["path"] if isinstance(input["files"][-1], dict) else input["files"][-1]
    else:
        image = None
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if not image:
        raise gr.Error("You need to upload an image for Krypton to work.")

    # Build the prompt around the <image> placeholder the processor expects.
    prompt = f"user\n\n<image>\n{input['text']}\nassistant\n\n"
    image = Image.open(image)
    inputs = processor(text=prompt, images=image, return_tensors='pt').to(0, torch.float16)
    # Stream tokens as they are generated; skip the prompt and special tokens
    # so only the assistant's reply reaches the chat window.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True, skip_prompt=True)

    # Pass the full processor output (input_ids, attention_mask, pixel_values)
    # so the model actually receives the image features.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False
    )

    # Run generation in a background thread so partial text can be yielded.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    time.sleep(0.5)  # give generation a moment to start before reading the stream
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.06)  # small pause to smooth out the streaming effect
        yield buffer
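
# --- UI wiring (not shown in this revision) ---
# A minimal sketch of how the krypton() generator could be hooked up to a
# Gradio multimodal ChatInterface. The chatbot label, textbox placeholder,
# and launch block below are assumptions, not part of the original file.
chatbot = gr.Chatbot(height=450, label='Krypton')
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter a message or upload an image...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=krypton,
        chatbot=chatbot,
        fill_height=True,
        multimodal=True,
        textbox=chat_input,
    )

if __name__ == "__main__":
    demo.launch()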