Inoichan committed
Commit 1cdd8f7
1 Parent(s): ff04e06

initial commit for demo

Files changed (4)
  1. app.py +209 -0
  2. images/heron.png +0 -0
  3. images/user_icon.png +0 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,209 @@
+ import logging
+ import os
+ import sys
+ import time
+ from threading import Thread
+
+ import gradio as gr
+ import torch
+ from transformers import AutoProcessor, StoppingCriteria, TextIteratorStreamer
+
+
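+ # Install the Heron library at startup so its model classes can be imported below.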
+ os.system(
+     "git clone https://github.com/turingmotors/heron.git"
+     "&& export CUDA_HOME=/usr/local/cuda; pip install -e heron"
+ )
+
+ sys.path.insert(0, "./heron")
+ from heron.models.git_llm.git_llama import GitLlamaForCausalLM, GitLlamaConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ # This class is copied from llava: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
+ class KeywordsStoppingCriteria(StoppingCriteria):
+     def __init__(self, keywords, tokenizer, input_ids):
+         self.keywords = keywords
+         self.keyword_ids = []
+         for keyword in keywords:
+             cur_keyword_ids = tokenizer(keyword).input_ids
+             # Drop a leading BOS token so the keyword ids match generated output.
+             if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+                 cur_keyword_ids = cur_keyword_ids[1:]
+             self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+         self.tokenizer = tokenizer
+         self.start_len = input_ids.shape[1]
+
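+     # Returns True once any keyword matches the tail of the generated ids, or
+     # appears in the decoded text of the last few tokens.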
+     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
+         offset = min(output_ids.shape[1] - self.start_len, 3)
+         self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+         for keyword_id in self.keyword_ids:
+             # torch.equal gives an exact whole-tensor match; a bare `==` returns an
+             # elementwise tensor and cannot be used as a bool for multi-token keywords.
+             if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
+                 return True
+         outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+         for keyword in self.keywords:
+             if keyword in outputs:
+                 return True
+         return False
+
+
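+ # For a history like [("What is this?", "A heron."), ("Where?", None)], the
+ # flattened prompt below is "##human: What is this?\n##gpt: ##human: Where?\n##gpt: "
+ # (note: earlier model replies are not fed back into the prompt).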
+ def preprocess(history, image):
+     text = ""
+     for one_history in history:
+         text += f"##human: {one_history[0]}\n##gpt: "
+     # do preprocessing
+     inputs = processor(
+         text,
+         image,
+         return_tensors="pt",
+         truncation=True,
+     )
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     return inputs
+
+
+ def add_text(textbox, history):
+     # hard text threshold: truncate user input to 512 characters
+     if len(textbox) > 512:
+         textbox = textbox[:512]
+     history = history + [(textbox, None)]
+     return "", history
+
+
+ title_markdown = ("""
+ # Heron Chat Demo
+
+ - Model: [turing-motors/heron-chat-git-ELYZA-fast-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-ELYZA-fast-7b-v0)
+ - Training code: [Heron](https://github.com/turingmotors/heron)
+ """)
+
+
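+ # Generation runs on a background thread while TextIteratorStreamer yields
+ # decoded tokens, so the chatbot can update incrementally as text is produced.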
+ def stream_bot(imagebox, history):
+     # do preprocessing
+     inputs = preprocess(history, imagebox)
+
+     # streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
+     streamer = TextIteratorStreamer(
+         processor.tokenizer,
+         skip_prompt=True,
+         skip_special_tokens=True,
+     )
+     stopping_criteria = KeywordsStoppingCriteria(
+         [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
+     )
+
+     inputs.update(
+         dict(
+             streamer=streamer,
+             max_new_tokens=max_length,
+             stopping_criteria=[stopping_criteria],
+             do_sample=True,
+             temperature=0.5,
+             no_repeat_ngram_size=2,
+         )
+     )
+     thread = Thread(target=model.generate, kwargs=inputs)
+     thread.start()
+
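+     # Drain the streamer token by token, stripping the "##" stop marker and
+     # yielding the updated history so Gradio re-renders the partial reply.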
+     history[-1][1] = ""
+     for new_text in streamer:
+         history[-1][1] += new_text
+         history[-1][1] = history[-1][1].replace(EOS_WORDS, "")
+         time.sleep(0.05)
+         yield history
+
+
+ def regenerate(history):
+     # Drop the last bot reply so stream_bot regenerates it for the same user turn.
+     history[-1] = (history[-1][0], None)
+     return history
+
+
+ def clear_history():
+     # Reset the chatbot history, textbox, and image.
+     return [], "", None
+
+
+ def build_demo():
+     textbox = gr.Textbox(
+         show_label=False, placeholder="Enter text and press ENTER", visible=True, container=False
+     )
+     with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
+         gr.Markdown(title_markdown)
+         with gr.Row():
+             with gr.Column(scale=3):
+                 imagebox = gr.Image(type="pil")
+
+                 gr.Examples(
+                     examples=[
+                         [
+                             "./images/heron.png",
+                             "What is this image?",
+                         ],
+                     ],
+                     inputs=[imagebox, textbox],
+                 )
+
+             with gr.Column(scale=6):
+                 chatbot = gr.Chatbot(
+                     elem_id="chatbot",
+                     label="Heron Chatbot",
+                     visible=True,
+                     height=550,
+                     avatar_images=("./images/user_icon.png", "./images/heron.png"),
+                 )
+                 with gr.Row():
+                     with gr.Column(scale=8):
+                         textbox.render()
+                     with gr.Column(scale=1, min_width=60):
+                         submit_btn = gr.Button(value="Submit", visible=True)
+                 with gr.Row():
+                     regenerate_btn = gr.Button(value="Regenerate", visible=True)
+                     clear_btn = gr.Button(value="Clear history", visible=True)
+
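+         # Each trigger first updates the chat state synchronously, then chains
+         # into the streaming generator via .then() so the reply renders live.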
+         regenerate_btn.click(regenerate, chatbot, chatbot).then(
+             stream_bot,
+             [imagebox, chatbot],
+             [chatbot],
+         )
+         clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])
+
+         textbox.submit(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
+             stream_bot,
+             [imagebox, chatbot],
+             [chatbot],
+         )
+         submit_btn.click(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
+             stream_bot,
+             [imagebox, chatbot],
+             [chatbot],
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
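+     # "##" marks each speaker turn in the prompt format built by preprocess,
+     # so it doubles as the keyword that stops generation.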
+     EOS_WORDS = "##"
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     max_length = 512
+
+     vision_model_name = "openai/clip-vit-large-patch14-336"
+     MODEL_NAME = "turing-motors/heron-chat-git-ELYZA-fast-7b-v0"
+     PROCESSOR_PATH = "turing-motors/heron-chat-git-ELYZA-fast-7b-v0"
+
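+     # set_vision_configs attaches the CLIP vision tower settings to the language
+     # model config; weights are then loaded in float16 to reduce GPU memory use.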
+     # prepare a pretrained model
+     git_config = GitLlamaConfig.from_pretrained(MODEL_NAME)
+     git_config.set_vision_configs(
+         num_image_with_embedding=1, vision_model_name=vision_model_name
+     )
+     model = GitLlamaForCausalLM.from_pretrained(
+         MODEL_NAME, config=git_config, torch_dtype=torch.float16
+     )
+
+     model.eval()
+     model.to(device)
+
+     # prepare a processor
+     processor = AutoProcessor.from_pretrained(PROCESSOR_PATH)
+
+     demo = build_demo()
+     demo.queue(concurrency_count=1, max_size=5, api_open=False).launch()
images/heron.png ADDED
images/user_icon.png ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ accelerate
+ protobuf
+ sentencepiece
+ torch>=2.0.1
+ pillow
+ transformers
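+ # gradio is not pinned here; it is assumed to come from the Spaces SDK runtime.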