import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import threading
import spaces
import accelerate  # not used directly, but must be installed for low_cpu_mem_usage=True
import time
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton π</h1>
<p>This demo uses the open-source model <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a>.</p>
</div>
'''
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
model.to("cuda:0")  # the inputs below are moved to device 0, so the model must live there too
processor = AutoProcessor.from_pretrained(model_id)
# Make sure generation stops at the tokenizer's EOS token
model.generation_config.eos_token_id = processor.tokenizer.eos_token_id
@spaces.GPU(duration=120)
def krypton(input, history):
if input["files"]:
image = input["files"][-1]["path"] if isinstance(input["files"][-1], dict) else input["files"][-1]
else:
image = None
for hist in history:
if isinstance(hist[0], tuple):
image = hist[0][0]
if not image:
gr.Error("You need to upload an image for Krypton to work.")
return
prompt = f"user\n\n<image>\n{input['text']}\nassistant\n\n"
    image = Image.open(image)
    inputs = processor(prompt, images=image, return_tensors='pt').to(0, torch.float16)
    # Stream decoded tokens as they are produced; skip special tokens so <|eot_id|> doesn't leak into the chat
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True, skip_prompt=True)
    # Generation kwargs; pixel_values must be forwarded, otherwise the <image> placeholder has no image features
    generation_kwargs = dict(
        inputs=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        pixel_values=inputs['pixel_values'],
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False
    )
    # Run generate() on a worker thread so the main thread can iterate over the streamer
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
buffer = ""
time.sleep(0.5)
for new_text in streamer:
buffer += new_text
generated_text_without_prompt = buffer
time.sleep(0.06)
yield generated_text_without_prompt |