gustavoaq committed
Commit 7ee6941 • 1 Parent(s): 5c471d1

Update demo/app.py

Files changed (1)
  demo/app.py  +67 -80
demo/app.py CHANGED
@@ -4,6 +4,7 @@ import logging
 import sys
 import gradio as gr
 import torch
+import gc
 from app_modules.utils import *
 from app_modules.presets import *
 from app_modules.overwrites import *
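Review note: `import gc` backs the explicit `gc.collect()` added further down, but the new `os.system("nvidia-smi")` call in `predict` gets no matching `import os` in this commit; it only resolves if one of the `app_modules` star imports happens to re-export `os`, otherwise the fiftieth request dies with a `NameError`. A hedged sketch of the explicit import list the file now effectively depends on:

```python
# Sketch only: spelling the imports out instead of relying on what
# `from app_modules.utils import *` happens to re-export.
import gc
import logging
import os
import sys

import gradio as gr
import torch
```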
@@ -15,53 +16,48 @@ logging.basicConfig(
 
 base_model = "decapoda-research/llama-7b-hf"
 adapter_model = "/home/user/app/checkpoint-100"
-tokenizer, model, device = load_tokenizer_and_model(base_model, adapter_model)
-
-
-def predict(
-    text,
-    chatbot,
-    history,
-    top_p,
-    temperature,
-    max_length_tokens,
-    max_context_length_tokens,
-):
-    if text == "":
-        yield chatbot, history, "Empty context."
+tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
+
+total_count = 0
+def predict(text,
+    chatbot,
+    history,
+    top_p,
+    temperature,
+    max_length_tokens,
+    max_context_length_tokens,):
+    if text=="":
+        yield chatbot,history,"Empty context."
+        return
+    try:
+        model
+    except:
+        yield [[text,"No Model Found"]],[],"No Model Found"
         return
 
-    inputs = generate_prompt_with_history(
-        text, history, tokenizer, max_length=max_context_length_tokens
-    )
+    inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
     if inputs is None:
-        yield chatbot, history, "Input too long."
-        return
+        yield chatbot,history,"Input too long."
+        return
     else:
-        prompt, inputs = inputs
+        prompt,inputs=inputs
         begin_length = len(prompt)
-    input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
+    input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
     torch.cuda.empty_cache()
-
+    global total_count
+    total_count += 1
+    print(total_count)
+    if total_count % 50 == 0 :
+        os.system("nvidia-smi")
     with torch.no_grad():
-        for x in sample_decode(
-            input_ids,
-            model,
-            tokenizer,
-            stop_words=["[|Human|]", "[|AI|]"],
-            max_length=max_length_tokens,
-            temperature=temperature,
-            top_p=top_p,
-        ):
-            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
+        for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
+            if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
                 if "[|Human|]" in x:
-                    x = x[: x.index("[|Human|]")].strip()
+                    x = x[:x.index("[|Human|]")].strip()
                 if "[|AI|]" in x:
-                    x = x[: x.index("[|AI|]")].strip()
-                x = x.strip(" ")
-                a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
-                    [text, convert_to_markdown(x)]
-                ], history + [[text, x]]
+                    x = x[:x.index("[|AI|]")].strip()
+                x = x.strip()
+                a, b= [[y[0],convert_to_markdown(y[1])] for y in history]+[[text, convert_to_markdown(x)]],history + [[text,x]]
                 yield a, b, "Generating..."
             if shared_state.interrupted:
                 shared_state.recover()
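The decode switch from `sample_decode` to `greedy_search` keeps the same streaming contract: each yielded `x` is the full decoded text so far, the loop suppresses output while `x` could still be growing into a stop word, and once a full `[|Human|]` or `[|AI|]` marker appears it truncates there. `is_stop_word_or_prefix` lives in `app_modules.utils`; a minimal sketch of the behavior the loop assumes (an assumption, not the module's actual code):

```python
# Hypothetical re-implementation of the helper the loop relies on.
# Returns True while the tail of `text` is a stop word or could still
# grow into one, telling the caller to wait for more tokens.
def is_stop_word_or_prefix(text: str, stop_words: list) -> bool:
    for stop in stop_words:
        if text.endswith(stop):
            return True  # text currently ends in a full stop word
        for i in range(1, len(stop)):
            if text.endswith(stop[:i]):
                return True  # tail may be a stop word in progress
    return False
```

Only when this returns False does the loop cut at `x.index("[|Human|]")` / `x.index("[|AI|]")` and surface the text to the chatbot.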
@@ -70,40 +66,33 @@ def predict(
                     return
                 except:
                     pass
+    del input_ids
+    gc.collect()
     torch.cuda.empty_cache()
-    print(prompt)
-    print(x)
-    print("=" * 80)
+    #print(text)
+    #print(x)
+    #print("="*80)
     try:
-        yield a, b, "Generate: Success"
+        yield a,b,"Generate: Success"
     except:
         pass
-
-
+
 def retry(
-    text,
-    chatbot,
-    history,
-    top_p,
-    temperature,
-    max_length_tokens,
-    max_context_length_tokens,
-):
-    logging.info("Retry...")
-    if len(history) == 0:
-        yield chatbot, history, "Empty context."
-        return
-    chatbot.pop()
-    inputs = history.pop()[0]
-    for x in predict(
-        inputs,
+    text,
     chatbot,
     history,
     top_p,
     temperature,
     max_length_tokens,
     max_context_length_tokens,
-    ):
+):
+    logging.info("Retry...")
+    if len(history) == 0:
+        yield chatbot, history, f"Empty context"
+        return
+    chatbot.pop()
+    inputs = history.pop()[0]
+    for x in predict(inputs,chatbot,history,top_p,temperature,max_length_tokens,max_context_length_tokens):
         yield x
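The `del input_ids` / `gc.collect()` / `torch.cuda.empty_cache()` trio added here is the usual idiom for returning GPU memory between requests; it works because `input_ids` is the only live reference to the prompt tensor at that point. A self-contained sketch of the same pattern (stand-in tensor, not the app's):

```python
import gc
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = torch.randint(0, 32000, (1, 1024), device=device)  # stand-in prompt

# ... generation would run here, under torch.no_grad() ...

del input_ids             # drop the last Python reference to the tensor
gc.collect()              # collect anything still pinning its storage
torch.cuda.empty_cache()  # hand freed, cached blocks back to the CUDA driver
```

`empty_cache()` does not make the process faster by itself; its point is that freed memory becomes visible to other processes and to the periodic `nvidia-smi` check added in `predict`.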
@@ -132,13 +121,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                         submitBtn = gr.Button("Send")
                     with gr.Column(min_width=70, scale=1):
                         cancelBtn = gr.Button("Stop")
-
                 with gr.Row(scale=1):
                     emptyBtn = gr.Button(
                         "🧹 New Conversation",
                     )
                     retryBtn = gr.Button("🔄 Regenerate")
-                    delLastBtn = gr.Button("🗑️ Remove Last Turn")
+                    delLastBtn = gr.Button("🗑️ Remove Last Turn")
         with gr.Column():
             with gr.Column(min_width=50, scale=1):
                 with gr.Tab(label="Parameter Setting"):
@@ -162,7 +150,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                     max_length_tokens = gr.Slider(
                         minimum=0,
                         maximum=512,
-                        value=512,
+                        value=256,
                         step=8,
                         interactive=True,
                         label="Max Generation Tokens",
@@ -206,20 +194,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
         show_progress=True,
     )
 
-    reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display])
-
+    reset_args = dict(
+        fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
+    )
+
     # Chatbot
-    cancelBtn.click(cancel_outputing, [], [status_display])
     transfer_input_args = dict(
-        fn=transfer_input,
-        inputs=[user_input],
-        outputs=[user_question, user_input, submitBtn, cancelBtn],
-        show_progress=True,
+        fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
     )
 
-    user_input.submit(**transfer_input_args).then(**predict_args)
+    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
 
-    submitBtn.click(**transfer_input_args).then(**predict_args)
+    predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)
 
     emptyBtn.click(
         reset_state,
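The reason the events are now assigned to names: in Gradio 3.x, `.submit()` and `.click()` return event objects, `.then()` chains a follow-up step onto them (`transfer_input` stashes the question and clears the textbox, then `predict` streams the answer), and the handles `predict_event1`/`predict_event2` are what the Stop button's `cancels=` list needs at the bottom of the file. A standalone sketch of the chaining pattern with toy functions (names are illustrative, not the app's):

```python
import gradio as gr

with gr.Blocks() as demo:
    user_input = gr.Textbox(label="Input")
    user_question = gr.Textbox(visible=False)  # staging slot, as in the app
    answer = gr.Textbox(label="Answer")

    def transfer(text):
        return text, ""  # step 1: stash the question, clear the box

    def predict(question):
        acc = ""
        for ch in question:  # step 2: stream the answer piece by piece
            acc += ch
            yield acc

    predict_event = user_input.submit(
        transfer, [user_input], [user_question, user_input]
    ).then(predict, [user_question], [answer])

demo.queue().launch()
```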
@@ -228,7 +214,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     )
     emptyBtn.click(**reset_args)
 
-    retryBtn.click(**retry_args)
+    predict_event3 = retryBtn.click(**retry_args)
 
     delLastBtn.click(
         delete_last_conversation,
@@ -236,11 +222,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
         [chatbot, history, status_display],
         show_progress=True,
     )
-
+    cancelBtn.click(
+        cancel_outputing, [], [status_display],
+        cancels=[
+            predict_event1,predict_event2,predict_event3
+        ]
+    )
 demo.title = "Baize"
 
-if __name__ == "__main__":
-    reload_javascript()
-    demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
-        share=False, favicon_path="/home/user/app/demo/assets/favicon.ico", inbrowser=True
-    )
+demo.queue(concurrency_count=1).launch()
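Moving `cancelBtn.click(...)` below the event definitions is what makes Stop actually interrupt generation: Gradio's `cancels=` kills the listed events mid-stream, but it can only reference handles that already exist, and it only applies to queued events, hence the unconditional `demo.queue(...)` at the end. `concurrency_count=1` serializes requests, which fits one model instance on one GPU. A minimal sketch of the cancel wiring (toy function, not the app's):

```python
import time

import gradio as gr

with gr.Blocks() as demo:
    out = gr.Textbox()
    start = gr.Button("Start")
    stop = gr.Button("Stop")

    def slow_count():
        for i in range(100):  # stand-in for token streaming
            time.sleep(0.2)
            yield str(i)

    ev = start.click(slow_count, None, out)     # keep the event handle
    stop.click(None, None, None, cancels=[ev])  # fn=None: cancel only

demo.queue(concurrency_count=1).launch()
```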
 