File size: 2,593 Bytes
08a6c8d
5f6343c
08a6c8d
f7151f4
08a6c8d
 
 
 
f7151f4
08a6c8d
 
 
 
 
 
 
 
 
 
 
 
f7151f4
 
 
08a6c8d
 
 
 
 
f7151f4
08a6c8d
 
659f477
 
 
 
 
 
08a6c8d
 
 
f7151f4
659f477
 
 
 
 
 
f7151f4
08a6c8d
f7151f4
08a6c8d
 
659f477
08a6c8d
659f477
08a6c8d
f7151f4
08a6c8d
 
f7151f4
08a6c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7151f4
 
 
08a6c8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from threading import Thread
from typing import Dict

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer


# Heading shown at the top of the page.
TITLE = (
    "<h1><center>"
    "Chat with PaliGemma-3B-Chat-v0.1"
    "</center></h1>"
)

# Sub-heading that links to the model card.
DESCRIPTION = (
    "<h3><center>"
    "Visit <a href='https://huggingface.co/hiyouga/PaliGemma-3B-Chat-v0.1' target='_blank'>"
    "our model page</a> for details."
    "</center></h3>"
)

# Stylesheet for the element tagged with the "duplicate-button" class.
# The empty first and last items reproduce the leading and trailing newline.
CSS = "\n".join(
    (
        "",
        ".duplicate-button {",
        "  margin: auto !important;",
        "  color: white !important;",
        "  background: black !important;",
        "  border-radius: 100vh !important;",
        "}",
        "",
    )
)


# Hub repository that supplies the tokenizer, processor and model weights.
model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"

# All three objects are created once, when the module is imported.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)


@spaces.GPU
def stream_chat(message: Dict[str, object], history: list):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Payload from the chat UI, e.g.
            ``{'text': 'what is this', 'files': ['image-xxx.jpg']}``.
            The first entry of ``files`` is opened as the image to discuss.
        history: Earlier turns as ``(prompt, answer)`` pairs.

    Yields:
        The reply accumulated so far, growing as new text is generated.

    Raises:
        gr.Error: If the message does not carry an image file.
    """
    files = message.get("files") or []
    if not files:
        # Without this guard the lookup below fails with a bare IndexError.
        raise gr.Error("Please upload an image to chat about.")

    # Decode inside a context manager so the file handle is released, and
    # normalise to RGB so palette, alpha or greyscale uploads reach the
    # processor as a 3-channel image.
    with Image.open(files[0]) as raw_image:
        image = raw_image.convert("RGB")
    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    conversation = []
    for prompt, answer in history:
        # Only plain-text turns are forwarded to the chat template; entries
        # whose prompt or answer is not a string (for example a file-upload
        # row with no reply) are skipped.
        if not isinstance(prompt, str) or not isinstance(answer, str):
            continue
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})

    conversation.append({"role": "user", "content": message["text"]})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")

    # Prepend one "<image>" placeholder token per image position ahead of
    # the text tokens.
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full((1, processor.image_seq_length), image_token_id, dtype=input_ids.dtype)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
    )

    # generate() blocks until it is done, so it runs in a worker thread while
    # this generator relays the text the streamer hands over.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output

    # Wait for the worker so generation has fully finished before returning.
    worker.join()


# Chat display component handed to the ChatInterface below.
chatbot = gr.Chatbot(height=450)

# Page layout: heading, model link, duplicate button, then the chat UI.
demo = gr.Blocks(css=CSS)
with demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        elem_classes="duplicate-button",
        value="Duplicate Space for private use",
    )
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        multimodal=True,
        cache_examples=False,
        fill_height=True,
    )


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()