tanahhh committed
Commit 5205227
1 Parent(s): 590285e

remove demo

Files changed (2):
  1. app.py +2 -205
  2. requirements.txt +1 -8
app.py CHANGED
@@ -1,219 +1,16 @@
-import logging
-import os
-import sys
-import time
-from threading import Thread
-
 import gradio as gr
-import torch
-from transformers import AutoProcessor, LlamaTokenizer, StoppingCriteria, TextIteratorStreamer
-
-os.system("git clone https://github.com/turingmotors/heron && cd heron && pip install -e .")
-
-sys.path.insert(0, "./heron")
-from heron.models.video_blip import VideoBlipForConditionalGeneration, VideoBlipProcessor
-
-logger = logging.getLogger(__name__)
-
-title_markdown = """
-# Heronチャットデモ
-
-- モデル: [turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0](https://huggingface.co/turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0)
-- 学習コード: [Heron](https://github.com/turingmotors/heron)
-"""
-
-
-# This class is copied from llava: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
-class KeywordsStoppingCriteria(StoppingCriteria):
-    def __init__(self, keywords, tokenizer, input_ids):
-        self.keywords = keywords
-        self.keyword_ids = []
-        for keyword in keywords:
-            cur_keyword_ids = tokenizer(keyword).input_ids
-            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
-                cur_keyword_ids = cur_keyword_ids[1:]
-            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
-        self.tokenizer = tokenizer
-        self.start_len = input_ids.shape[1]
-
-    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
-        offset = min(output_ids.shape[1] - self.start_len, 3)
-        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
-        for keyword_id in self.keyword_ids:
-            if output_ids[0, -keyword_id.shape[0] :] == keyword_id:
-                return True
-        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
-        for keyword in self.keywords:
-            if keyword in outputs:
-                return True
-        return False
-
-
-def preprocess(history, image):
-    text = ""
-    for one_history in history:
-        text += f"##human: {one_history[0]}\n##gpt: "
-    # do preprocessing
-    inputs = processor(
-        text=text,
-        images=image,
-        return_tensors="pt",
-        truncation=True,
-    )
-    inputs = {k: v.to(device) for k, v in inputs.items()}
-    inputs["pixel_values"] = inputs["pixel_values"].to(device, torch.float16)
-    return inputs
-
-
-def add_text(textbox, history):
-    history = history + [(textbox, None)]
-    return "", history
-
-
-def stream_bot(imagebox, history):
-    # do preprocessing
-    inputs = preprocess(history, imagebox)
-
-    # streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
-    streamer = TextIteratorStreamer(
-        processor.tokenizer,
-        skip_prompt=True,
-        skip_special_tokens=True,
-        do_sample=False,
-        temperature=0.2,
-        no_repeat_ngram_size=2,
-    )
-    stopping_criteria = KeywordsStoppingCriteria(
-        [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
-    )

-    inputs.update(
-        dict(
-            streamer=streamer,
-            max_new_tokens=max_length,
-            stopping_criteria=[stopping_criteria],
-            no_repeat_ngram_size=2,
-            eos_token_id=[processor.tokenizer.pad_token_id],
-        )
-    )
-    thread = Thread(target=model.generate, kwargs=inputs)
-    thread.start()
-
-    history[-1][1] = ""
-    for new_text in streamer:
-        history[-1][1] += new_text
-        history[-1][1] = history[-1][1].replace(EOS_WORDS, "")
-        time.sleep(0.05)
-        yield history
-
-
-def regenerate(history):
-    history[-1] = (history[-1][0], None)
-    return history
-
-
-def clear_history():
-    return [], "", None

+title_markdown = "デモはこちら(https://9255-35-232-109-220.ngrok-free.app/)に移転しました。"

 def build_demo():
-    textbox = gr.Textbox(
-        show_label=False, placeholder="Enter text and press ENTER", visible=True, container=False
-    )
+
     with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
         gr.Markdown(title_markdown)
-        with gr.Row():
-            with gr.Column(scale=6):
-                imagebox = gr.Image(type="pil")
-
-                # gr.Examples(
-                #     examples=[
-                #         [
-                #             "./images/bus_kyoto.png",
-                #             "この道路を運転する時には何に気をつけるべきですか?",
-                #         ],
-                #         [
-                #             "./images/bear.png",
-                #             "この画像には何が写っていますか?",
-                #         ],
-                #         [
-                #             "./images/water_bus.png",
-                #             "画像には何が写っていますか?",
-                #         ],
-                #         [
-                #             "./images/extreme_ironing.jpg",
-                #             "この画像の面白い点は何ですか?",
-                #         ],
-                #         [
-                #             "./images/heron.png",
-                #             "この画像はどういう点が面白いですか?",
-                #         ],
-                #     ],
-                #     inputs=[imagebox, textbox],
-                # )
-
-            with gr.Column(scale=6):
-                chatbot = gr.Chatbot(
-                    elem_id="chatbot",
-                    label="Heron Chatbot",
-                    visible=True,
-                    height=550,
-                    avatar_images=("./images/user_icon.png", "./images/heron.png"),
-                )
-                with gr.Row():
-                    with gr.Column(scale=8):
-                        textbox.render()
-                    with gr.Column(scale=1, min_width=60):
-                        submit_btn = gr.Button(value="Submit", visible=True)
-                with gr.Row():
-                    regenerate_btn = gr.Button(value="Regenerate", visible=True)
-                    clear_btn = gr.Button(value="Clear history", visible=True)
-
-        regenerate_btn.click(regenerate, chatbot, chatbot).then(
-            stream_bot,
-            [imagebox, chatbot],
-            [chatbot],
-        )
-        clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])
-
-        textbox.submit(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
-            stream_bot,
-            [imagebox, chatbot],
-            [chatbot],
-        )
-        submit_btn.click(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
-            stream_bot,
-            [imagebox, chatbot],
-            [chatbot],
-        )

     return demo


 if __name__ == "__main__":
-    EOS_WORDS = "##"
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
-    max_length = 512
-    MODEL_NAME = "turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0"
-
-    # prepare a pretrained model
-    model = VideoBlipForConditionalGeneration.from_pretrained(
-        MODEL_NAME, torch_dtype=torch.float16, ignore_mismatched_sizes=True
-    )
-
-    model = model.half()
-    model.eval()
-    model.to(device)
-
-    # prepare a processor
-    processor = VideoBlipProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
-    tokenizer = LlamaTokenizer.from_pretrained(
-        "novelai/nerdstash-tokenizer-v1", additional_special_tokens=["▁▁"]
-    )
-    processor.tokenizer = tokenizer
-
     demo = build_demo()
     demo.queue(max_size=10).launch()
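
For reference, the complete app.py left behind by this commit is the redirect stub below, assembled from the context and added (+) lines above; only the comments are added here for clarity. The new title_markdown string is Japanese for "The demo has moved here (https://9255-35-232-109-220.ngrok-free.app/)."

import gradio as gr


# Japanese for: "The demo has moved here (https://9255-35-232-109-220.ngrok-free.app/)."
title_markdown = "デモはこちら(https://9255-35-232-109-220.ngrok-free.app/)に移転しました。"

def build_demo():
    # The UI is now a single Markdown notice pointing visitors at the new host.
    with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
        gr.Markdown(title_markdown)

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.queue(max_size=10).launch()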
requirements.txt CHANGED
@@ -1,8 +1 @@
-accelerate
-protobuf
-sentencepiece
-torch
-pillow
-transformers==4.33.1
-#accelerate==0.22.0
-deepspeed==0.10.2
+