File size: 2,118 Bytes
168da77
 
 
 
 
06de88f
bfd4b05
5ae7f9c
d364219
 
 
 
 
 
 
 
 
168da77
 
 
 
0963b2f
168da77
 
d364219
0963b2f
 
5659ce7
69eca47
0963b2f
5659ce7
0963b2f
5659ce7
0963b2f
5659ce7
0963b2f
5659ce7
5ae7f9c
0963b2f
 
 
 
 
5ae7f9c
0963b2f
 
daa8caf
0963b2f
 
daa8caf
 
0963b2f
 
153de5a
2c4c1d7
 
daa8caf
0963b2f
daa8caf
 
0963b2f
5ae7f9c
 
 
 
 
 
0963b2f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import threading
import spaces
import accelerate
import time

# HTML header text for the demo page (presumably rendered by the Gradio UI
# that imports or builds on this module -- confirm where it is consumed).
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton 🕋</h1>
<p>This uses an Open Source model from <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a></p>
</div>
'''

model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

# Load the LLaVA-Llama-3 checkpoint in half precision and move it to CUDA.
# `from_pretrained` without a device_map leaves the weights on the CPU, while
# `krypton` sends its inputs to CUDA device 0 with `.to(0, torch.float16)`;
# without this move generation fails with a device-mismatch error.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to("cuda")

# Text tokenizer + image preprocessor paired with the checkpoint.
processor = AutoProcessor.from_pretrained(model_id)

# Always make generation stop on the tokenizer's eos token, so the model's
# generation config agrees with the tokenizer.
model.generation_config.eos_token_id = processor.tokenizer.eos_token_id

def _latest_image_path(message, history):
    """Return the path of the image the reply should be conditioned on.

    The newest file attached to the current message wins. If the message has
    no files, the history is scanned and the last uploaded image found there
    is used. Returns None when no image is available at all.
    """
    files = message["files"]
    if files:
        latest = files[-1]
        # A file entry is either a plain path or a dict carrying a "path" key.
        return latest["path"] if isinstance(latest, dict) else latest

    image_path = None
    for turn in history:
        # An uploaded file shows up as a tuple whose first element is its
        # path (assumed Gradio tuple-style history -- confirm with the UI).
        if isinstance(turn[0], tuple):
            image_path = turn[0][0]
    return image_path


@spaces.GPU(duration=120)
def krypton(input, history):
    """Stream a reply about an image for a multimodal chat turn.

    Args:
        input: message dict with a "text" prompt and a "files" list of
            uploads (each a path or a dict with a "path" key).
        history: earlier chat turns, scanned for an image when the current
            message has none.

    Yields:
        The reply generated so far, growing with each streamed chunk.

    Raises:
        gr.Error: if neither the message nor the history contains an image.
    """
    image_path = _latest_image_path(input, history)
    if not image_path:
        # gr.Error only reaches the user when raised; constructing it is a no-op.
        raise gr.Error("You need to upload an image for Krypton to work.")

    # Llama-3 chat format documented on the model card; <image> marks where
    # the visual features are spliced in.
    prompt = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"<image>\n{input['text']}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # Load the pixels eagerly so the file handle is closed right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Keyword arguments keep this call valid whichever order the processor's
    # positional parameters use across transformers versions.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(0, torch.float16)

    # skip_special_tokens=True keeps "<|eot_id|>" out of the streamed reply.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,  # input_ids, attention_mask AND pixel_values (the image)
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False,
    )

    # generate() blocks, so run it in the background and drain the streamer
    # here; iterating the streamer already waits for new text to arrive.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    reply = ""
    for chunk in streamer:
        reply += chunk
        yield reply
    thread.join()