root committed on
Commit 1f8bf61
1 Parent(s): 093dbc9
app.py ADDED
@@ -0,0 +1,130 @@
+ from threading import Thread
+
+ import gradio as gr
+ import torch
+ from transformers import AutoModel, AutoProcessor
+ from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
+
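+ # Prefer the first GPU; CPU inference works but streams slowly.
+ # trust_remote_code=True is required because UForm-Gen2 ships its
+ # modeling and processing code inside the checkpoint repository.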
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-halfB", trust_remote_code=True).to(device)
+ processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-halfB", trust_remote_code=True)
+
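+ # Halt generation once the model emits its end-of-turn token. 151645
+ # should be "<|im_end|>" in the Qwen tokenizer this checkpoint builds on;
+ # check with processor.tokenizer.convert_tokens_to_ids("<|im_end|>").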
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [151645]
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
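+ # Streaming chat handler. Gradio re-renders the chatbot with every yield,
+ # and @torch.no_grad() keeps inference from building autograd state.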
+ @torch.no_grad()
+ def response(message, history, image):
+     stop = StopOnTokens()
+
+     messages = [{"role": "system", "content": "You are a helpful assistant."}]
+
+     for user_msg, assistant_msg in history:
+         messages.append({"role": "user", "content": user_msg})
+         messages.append({"role": "assistant", "content": assistant_msg})
+
+     # First turn only: prepend the <image> placeholder that the model swaps
+     # for image latents; keep the raw text separate so the transcript shown
+     # in the chatbot stays clean.
+     prompt = f" <image>{message}" if len(messages) == 1 else message
+
+     messages.append({"role": "user", "content": prompt})
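+
+     # apply_chat_template folds the conversation into the model's prompt
+     # format and returns a [1, seq_len] tensor of token ids.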
+     model_inputs = processor.tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt"
+     )
+
+     image = (
+         processor.feature_extractor(image)
+         .unsqueeze(0)
+     )
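+
+     # The single <image> token is swapped for processor.num_image_latents
+     # visual embeddings inside the model, so the mask covers
+     # seq_len + num_image_latents - 1 positions.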
+     attention_mask = torch.ones(
+         1, model_inputs.shape[1] + processor.num_image_latents - 1
+     )
+
+     model_inputs = {
+         "input_ids": model_inputs,
+         "images": image,
+         "attention_mask": attention_mask
+     }
+
+     model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
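+
+     # Run generate() on a background thread and pull decoded text pieces
+     # off the streamer as they arrive; timeout=30.0 raises if the model
+     # stalls for half a minute.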
+     streamer = TextIteratorStreamer(processor.tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         stopping_criteria=StoppingCriteriaList([stop])
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
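+
+     # Stream the growing transcript to the UI, hiding Submit and showing
+     # Stop while tokens are arriving.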
+     history.append([message, ""])
+     partial_response = ""
+     for new_token in streamer:
+         partial_response += new_token
+         history[-1][1] = partial_response
+         yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)
+
+
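+ # Layout: image input beside the chat column; Stop stays hidden until a
+ # generation is running. Examples prefill the image and prompt fields.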
+ with gr.Blocks() as demo:
+     with gr.Row():
+         image = gr.Image(type="pil")
+
+         with gr.Column():
+             chat = gr.Chatbot(show_label=False)
+             message = gr.Textbox(interactive=True, show_label=False, container=False)
+
+             with gr.Row():
+                 gr.ClearButton([chat, message])
+                 stop = gr.Button(value="Stop", variant="stop", visible=False)
+                 submit = gr.Button(value="Submit", variant="primary")
+
+     with gr.Row():
+         gr.Examples(
+             [
+                 ["images/interior.jpg", "Describe the image accurately."],
+                 ["images/cat.jpg", "Describe the image in three sentences."],
+                 ["images/child.jpg", "Describe the image in one sentence."],
+             ],
+             [image, message],
+             label="Captioning"
+         )
+         gr.Examples(
+             [
+                 ["images/scream.jpg", "What is the main emotion of this image?"],
+                 ["images/louvre.jpg", "Where is this landmark located?"],
+                 ["images/three_people.jpg", "What are these people doing?"]
+             ],
+             [image, message],
+             label="VQA"
+         )
+
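+     # (fn, inputs, outputs) tuples unpacked into the listeners below so the
+     # Enter key and the Submit button share one definition; the second
+     # tuple restores the buttons once a response completes.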
+     response_handler = (
+         response,
+         [message, chat, image],
+         [chat, submit, stop]
+     )
+     postresponse_handler = (
+         lambda: (gr.Button(visible=False), gr.Button(visible=True)),
+         None,
+         [stop, submit]
+     )
+
+     event1 = message.submit(*response_handler)
+     event1.then(*postresponse_handler)
+     event2 = submit.click(*response_handler)
+     event2.then(*postresponse_handler)
+
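+     # Stop cancels any in-flight generation events; cancels= requires the
+     # queue enabled just below.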
+     stop.click(None, None, None, cancels=[event1, event2])
+
+ demo.queue()
+ demo.launch()
images/cat.jpg ADDED
images/child.jpg ADDED
images/interior.jpg ADDED
images/louvre.jpg ADDED
images/scream.jpg ADDED
images/three_people.jpg ADDED