Yiqin commited on
Commit
f99efcc
·
1 Parent(s): 49b3986

split different users' data

Browse files
.gitignore CHANGED
@@ -3,6 +3,7 @@ output_*/
3
  icl_inference_output/
4
  .vscode/
5
  tmp/
 
6
 
7
  # Byte-compiled / optimized / DLL files
8
  __pycache__/
 
3
  icl_inference_output/
4
  .vscode/
5
  tmp/
6
+ gradio_cached_examples/
7
 
8
  # Byte-compiled / optimized / DLL files
9
  __pycache__/
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import argparse
2
  import time
3
 
4
  import gradio as gr
@@ -7,64 +6,73 @@ from config.config_utils import get_config
7
  from model import Captioner, VicunaHandler
8
 
9
 
10
- def set_example_video(example: list) -> dict:
11
- return gr.Video.update(value=example[0])
12
 
13
 
14
- def upload_file(files):
15
- file_paths = [file.name for file in files]
16
- return file_paths
17
 
18
 
19
- def upload_video(video):
20
- print(video)
21
- return video
22
 
23
 
24
- def respond(input, chat_history):
25
- bot_response = handler.gr_chat(input)
26
  chat_history.append((input, bot_response))
27
  time.sleep(0.1)
28
- return "", chat_history
29
-
30
-
31
- def clear_chat(chat_history):
32
- handler.chatbot.clear_conv_()
33
-
34
- return "", []
35
-
36
 
37
 
 
38
  config = get_config('config/infer.yaml')
39
-
40
- captioner = Captioner(config) # global
41
-
42
- global handler
43
  handler = VicunaHandler(config['vicuna'])
44
 
45
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
46
- gr.Markdown("## <h1><center>ChatVID</center></h1>")
47
- gr.Markdown("""
48
- ChatVID is a video chatbot that can chat about any video.
49
- """)
 
 
 
 
 
 
50
  with gr.Row():
51
  with gr.Column():
52
  video_path = gr.Video(label="Video")
53
 
54
  with gr.Column():
55
- upload_button = gr.Button(
56
- "Upload & Watch. (Click once and wait 3min )")
57
- chat_button = gr.Button("Let's Chat!", interactive=False)
58
  num_frames = gr.Slider(
59
  minimum=5,
60
  value=12,
61
  maximum=12,
62
  step=1,
63
- label="Number of frames (no more than 12)")
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  with gr.Column():
 
66
  chatbot = gr.Chatbot()
67
- captions = gr.State("")
 
68
  with gr.Row(visible=False) as input:
69
  with gr.Column(scale=0.7):
70
  txt = gr.Textbox(
@@ -76,22 +84,20 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
76
  with gr.Column(scale=0.15, min_width=0):
77
  clear_button = gr.Button("CLEAR")
78
 
79
- upload_button.click(
80
- lambda: gr.update(interactive=False), None, chat_button).then(
81
- lambda: gr.update(visible=False), None,
82
- input).then(lambda: [], None, chatbot).then(
83
- captioner.caption_video, [video_path, num_frames],
84
- [captions]).then(lambda: gr.update(interactive=True), None,
85
- chat_button)
86
-
87
- chat_button.click(handler.gr_chatbot_init, [captions],
88
- None).then(lambda: gr.update(visible=True), None,
89
- input)
90
-
91
- txt.submit(respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
92
  run_button.click(
93
- respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
94
  clear_button.click(
95
- clear_chat, inputs=[chatbot], outputs=[txt, chatbot])
96
 
97
  demo.launch()
 
 
1
  import time
2
 
3
  import gradio as gr
 
6
  from model import Captioner, VicunaHandler
7
 
8
 
9
+ def mirror(x):
10
+ return x
11
 
12
 
13
+ def clear_chat(conv_template):
14
+ return "", [], conv_template
 
15
 
16
 
17
+ def clear_four():
18
+ return [], [], [], []
 
19
 
20
 
21
+ def respond(input, chat_history, conv):
22
+ bot_response, new_conv = handler.gr_chat(input, conv)
23
  chat_history.append((input, bot_response))
24
  time.sleep(0.1)
25
+ return "", chat_history, new_conv
 
 
 
 
 
 
 
26
 
27
 
28
+ # global variables
29
  config = get_config('config/infer.yaml')
30
+ captioner = Captioner(config)
 
 
 
31
  handler = VicunaHandler(config['vicuna'])
32
 
33
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
34
+ gr.Markdown(
35
+ "## <h1><center><img src='https://github.com/InvincibleWyq/ChatVID/assets/37479394/1a7f47ca-ffbd-4720-b43a-4304fcaa8657' height=40/> ChatVID</center></h1>"
36
+ )
37
+ gr.Markdown("""🔥 [ChatVID](https://github.com/InvincibleWyq/ChatVID) is a
38
+ video chatbot. Please give us a ⭐ Star!""")
39
+ gr.Markdown("""🎥 You may use the example video by clicking it.""")
40
+ gr.Markdown("""🚀 For any questions or suggestions, feel free to drop Yiqin
41
+ an email at <a href="mailto:wyq1217@outlook.com">wyq1217@outlook.com</a>
42
+ or open an issue.""")
43
+
44
  with gr.Row():
45
  with gr.Column():
46
  video_path = gr.Video(label="Video")
47
 
48
  with gr.Column():
49
+ upload_button = gr.Button("""Upload & Process.
50
+ (Click and wait 3min until dialog box appears)""")
51
+
52
  num_frames = gr.Slider(
53
  minimum=5,
54
  value=12,
55
  maximum=12,
56
  step=1,
57
+ label="Number of frames")
58
+
59
+ gr.Markdown("## Video Examples")
60
+ gr.Examples(
61
+ examples=[
62
+ "examples/cook_720p.mp4",
63
+ "examples/temple_of_heaven_720p.mp4"
64
+ ],
65
+ inputs=video_path,
66
+ outputs=video_path,
67
+ fn=mirror,
68
+ cache_examples=True,
69
+ )
70
 
71
  with gr.Column():
72
+ caption_box = gr.Textbox("")
73
  chatbot = gr.Chatbot()
74
+ conv_template = gr.State("") # determined by the video
75
+ conv = gr.State("") # updated thourghout the conversation
76
  with gr.Row(visible=False) as input:
77
  with gr.Column(scale=0.7):
78
  txt = gr.Textbox(
 
84
  with gr.Column(scale=0.15, min_width=0):
85
  clear_button = gr.Button("CLEAR")
86
 
87
+ # conv_template and conv are `Conversation` objects
88
+ upload_button.click(lambda: gr.update(visible=False), None, input).then(
89
+ clear_four, None, [chatbot, conv, conv_template, caption_box]).then(
90
+ captioner.caption_video, [video_path, num_frames],
91
+ [conv_template]).then(mirror, [conv_template], [caption_box]).then(
92
+ handler.gr_chatbot_init, [conv_template],
93
+ [conv_template, conv]).then(lambda: gr.update(visible=True),
94
+ None, input)
95
+
96
+ txt.submit(
97
+ respond, inputs=[txt, chatbot, conv], outputs=[txt, chatbot, conv])
 
 
98
  run_button.click(
99
+ respond, inputs=[txt, chatbot, conv], outputs=[txt, chatbot, conv])
100
  clear_button.click(
101
+ clear_chat, inputs=[conv_template], outputs=[txt, chatbot, conv])
102
 
103
  demo.launch()
examples/cook_720p.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa232686c22066b90fe099e9bb4f0ad093693685368eb7590ddd843deb40f574
3
+ size 5320367
examples/references.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ The video [Temple of Heaven - UNESCO World Heritage Site] by YouTube creator [World Heritage Journey] used under Fair Use.
2
+ Link:
3
+ https://www.youtube.com/watch?v=9xLoyYY_5rc
4
+
5
+ The video [做饭糊弄学 十分钟晚餐 今天吃 :番茄西兰花炒蛋] from Bilibili user [香蕉柿子梨] used under Fair Use.
6
+ Link:
7
+ https://www.bilibili.com/video/BV1RY411e74Z
examples/temple_of_heaven_720p.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e60ce16122b0a6277c10efc5c37cc9b89c963913b2d382539fb7cc101bbd0851
3
+ size 4217766
model/Vicuna.py CHANGED
@@ -1,8 +1,6 @@
1
  from model.fastchat.conversation import (Conversation, SeparatorStyle,
2
- compute_skip_echo_len,
3
- get_default_conv_template)
4
- from model.fastchat.serve.inference import (ChatIO, chat_loop, generate_stream,
5
- load_model)
6
 
7
 
8
  class SimpleChatIO(ChatIO):
@@ -35,7 +33,6 @@ class VicunaChatBot:
35
  num_gpus: str,
36
  max_gpu_memory: str,
37
  load_8bit: bool,
38
- conv_template,
39
  ChatIO: ChatIO,
40
  debug: bool,
41
  ):
@@ -48,25 +45,18 @@ class VicunaChatBot:
48
  num_gpus, max_gpu_memory,
49
  load_8bit, debug)
50
 
51
- if conv_template:
52
- self.conv = conv_template.copy()
53
- else:
54
- self.conv = get_default_conv_template(model_path).copy()
55
-
56
- self.conv_template = self.conv.copy()
57
-
58
- def chat(self, inp: str, temperature: float, max_new_tokens: int):
59
  """ Vicuna as a chatbot. """
60
- self.conv.append_message(self.conv.roles[0], inp)
61
- self.conv.append_message(self.conv.roles[1], None)
62
 
63
  generate_stream_func = generate_stream
64
- prompt = self.conv.get_prompt()
65
 
66
- skip_echo_len = compute_skip_echo_len(self.model_path, self.conv,
67
- prompt)
68
  stop_str = (
69
- self.conv.sep if self.conv.sep_style
70
  in [SeparatorStyle.SINGLE, SeparatorStyle.BAIZE] else None)
71
  params = {
72
  "model": self.model_path,
@@ -76,65 +66,13 @@ class VicunaChatBot:
76
  "stop": stop_str,
77
  }
78
  print(prompt)
79
- self.chatio.prompt_for_output(self.conv.roles[1])
80
  output_stream = generate_stream_func(self.model, self.tokenizer,
81
  params, self.device)
82
  outputs = self.chatio.stream_output(output_stream, skip_echo_len)
83
  # NOTE: strip is important to align with the training data.
84
- self.conv.messages[-1][-1] = outputs.strip()
85
- return outputs
86
-
87
- def summarise(self, caption: dict, temperature: float,
88
- max_new_tokens: int):
89
- """ Vicuna as a summariser. """
90
- questions = caption
91
- captions = {}
92
- for id, question in questions.items():
93
- # Reset the conversation for each iteration
94
- self.conv = get_default_conv_template(self.model_path).copy()
95
- self.conv.append_message(self.conv.roles[0], question)
96
- self.conv.append_message(self.conv.roles[1], None)
97
-
98
- generate_stream_func = generate_stream
99
- prompt = self.conv.get_prompt()
100
-
101
- skip_echo_len = compute_skip_echo_len(self.model_path, self.conv,
102
- prompt)
103
- stop_str = (
104
- self.conv.sep if self.conv.sep_style
105
- in [SeparatorStyle.SINGLE, SeparatorStyle.BAIZE] else None)
106
-
107
- params = {
108
- "model": self.model_path,
109
- "prompt": prompt,
110
- "temperature": temperature,
111
- "max_new_tokens": max_new_tokens,
112
- "stop": stop_str,
113
- }
114
-
115
- self.chatio.prompt_for_output(self.conv.roles[1])
116
- output_stream = generate_stream_func(self.model, self.tokenizer,
117
- params, self.device)
118
- outputs = self.chatio.stream_output(output_stream, skip_echo_len)
119
- captions[id] = outputs
120
-
121
- if self.debug:
122
- print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
123
-
124
- print(captions)
125
- return captions
126
-
127
- def clear_conv_(self):
128
- """ Clear the conversation. """
129
- self.conv = self.conv_template.copy()
130
-
131
- def change_conv_template_(self, conv_template):
132
- self.conv_template = conv_template.copy()
133
- self.conv = conv_template.copy()
134
-
135
- def change_conv_(self, conv_template):
136
- """ Change the conversation. """
137
- self.conv = conv_template.copy()
138
 
139
 
140
  class VicunaHandler:
@@ -150,38 +88,25 @@ class VicunaHandler:
150
  self.config['num_gpus'],
151
  self.config['max_gpu_memory'],
152
  self.config['load_8bit'],
153
- None,
154
  self.chat_io,
155
  self.config['debug'],
156
  )
157
 
158
  def chat(self):
159
  """ Chat with the Vicuna. """
160
- template = self._construct_conversation("")
161
- chat_loop(
162
- self.config['model_path'],
163
- self.config['device'],
164
- self.config['num_gpus'],
165
- self.config['max_gpu_memory'],
166
- self.config['load_8bit'],
167
- template,
168
- self.config['temperature'],
169
- self.config['max_new_tokens'],
170
- self.chat_io,
171
- self.config['debug'],
172
- )
173
 
174
  def gr_chatbot_init(self, caption: str):
175
  """ Initialise the chatbot for gradio. """
176
 
177
  template = self._construct_conversation(caption)
178
- self.chatbot.change_conv_template_(template)
179
  print("Chatbot initialised.")
 
180
 
181
- def gr_chat(self, inp):
182
  """ Chat using gradio as the frontend. """
183
  return self.chatbot.chat(inp, self.config['temperature'],
184
- self.config['max_new_tokens'])
185
 
186
  def _construct_conversation(self, prompt):
187
  """ Construct a conversation template.
 
1
  from model.fastchat.conversation import (Conversation, SeparatorStyle,
2
+ compute_skip_echo_len)
3
+ from model.fastchat.serve.inference import ChatIO, generate_stream, load_model
 
 
4
 
5
 
6
  class SimpleChatIO(ChatIO):
 
33
  num_gpus: str,
34
  max_gpu_memory: str,
35
  load_8bit: bool,
 
36
  ChatIO: ChatIO,
37
  debug: bool,
38
  ):
 
45
  num_gpus, max_gpu_memory,
46
  load_8bit, debug)
47
 
48
+ def chat(self, inp: str, temperature: float, max_new_tokens: int,
49
+ conv: Conversation):
 
 
 
 
 
 
50
  """ Vicuna as a chatbot. """
51
+ conv.append_message(conv.roles[0], inp)
52
+ conv.append_message(conv.roles[1], None)
53
 
54
  generate_stream_func = generate_stream
55
+ prompt = conv.get_prompt()
56
 
57
+ skip_echo_len = compute_skip_echo_len(self.model_path, conv, prompt)
 
58
  stop_str = (
59
+ conv.sep if conv.sep_style
60
  in [SeparatorStyle.SINGLE, SeparatorStyle.BAIZE] else None)
61
  params = {
62
  "model": self.model_path,
 
66
  "stop": stop_str,
67
  }
68
  print(prompt)
69
+ self.chatio.prompt_for_output(conv.roles[1])
70
  output_stream = generate_stream_func(self.model, self.tokenizer,
71
  params, self.device)
72
  outputs = self.chatio.stream_output(output_stream, skip_echo_len)
73
  # NOTE: strip is important to align with the training data.
74
+ conv.messages[-1][-1] = outputs.strip()
75
+ return outputs, conv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
  class VicunaHandler:
 
88
  self.config['num_gpus'],
89
  self.config['max_gpu_memory'],
90
  self.config['load_8bit'],
 
91
  self.chat_io,
92
  self.config['debug'],
93
  )
94
 
95
  def chat(self):
96
  """ Chat with the Vicuna. """
97
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def gr_chatbot_init(self, caption: str):
100
  """ Initialise the chatbot for gradio. """
101
 
102
  template = self._construct_conversation(caption)
 
103
  print("Chatbot initialised.")
104
+ return template.copy(), template.copy()
105
 
106
+ def gr_chat(self, inp, conv: Conversation):
107
  """ Chat using gradio as the frontend. """
108
  return self.chatbot.chat(inp, self.config['temperature'],
109
+ self.config['max_new_tokens'], conv)
110
 
111
  def _construct_conversation(self, prompt):
112
  """ Construct a conversation template.