jZoNg committed on
Commit 3360664
1 Parent(s): 3ccf1a4

remove comments

Files changed (2)
  1. app.py +2 -32
  2. model.py +6 -36
app.py CHANGED
@@ -41,7 +41,6 @@ def delete_prev_fn(
 def generate(
     message: str,
     history_with_input: list[tuple[str, str]],
-    system_prompt: str,
     max_new_tokens: int,
     temperature: float,
     top_p: float,
@@ -51,7 +50,7 @@ def generate(
         raise ValueError
 
     history = history_with_input[:-1]
-    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+    generator = run(message, history, max_new_tokens, temperature, top_p, top_k)
     try:
         first_response = next(generator)
         yield history + [(message, first_response)]
@@ -61,22 +60,11 @@ def generate(
         yield history + [(message, response)]
 
 def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
-    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS, 1, 0.95, 5)
+    generator = generate(message, [], DEFAULT_MAX_NEW_TOKENS, 1, 0.95, 5)
     for x in generator:
         pass
     return '', x
 
-def check_input_token_length(
-    message: str,
-    chat_history: list[tuple[str, str]],
-    system_prompt: str
-) -> None:
-    a = 1
-    # input_token_length = get_input_token_length(message, chat_history, system_prompt)
-    # if input_token_length > MAX_INPUT_TOKEN_LENGTH:
-    #     raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
-
-
 with gr.Blocks(css='style.css') as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(
@@ -108,11 +96,6 @@ with gr.Blocks(css='style.css') as demo:
     saved_input = gr.State()
 
     with gr.Accordion(label='Advanced options', open=False):
-        system_prompt = gr.Textbox(
-            label='System prompt',
-            value=DEFAULT_SYSTEM_PROMPT,
-            lines=6
-        )
         max_new_tokens = gr.Slider(
             label='Max new tokens',
             minimum=1,
@@ -170,16 +153,10 @@ with gr.Blocks(css='style.css') as demo:
         api_name=False,
         queue=False,
     ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
         fn=generate,
         inputs=[
             saved_input,
             chatbot,
-            system_prompt,
             max_new_tokens,
             temperature,
             top_p,
@@ -202,16 +179,10 @@ with gr.Blocks(css='style.css') as demo:
         api_name=False,
         queue=False,
     ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
         fn=generate,
         inputs=[
             saved_input,
             chatbot,
-            system_prompt,
             max_new_tokens,
             temperature,
             top_p,
@@ -238,7 +209,6 @@ with gr.Blocks(css='style.css') as demo:
         inputs=[
             saved_input,
             chatbot,
-            system_prompt,
             max_new_tokens,
             temperature,
             top_p,
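
Note: with check_input_token_length removed, each event chain now calls generate directly instead of gating it behind a .success() step. A minimal sketch of the resulting wiring, illustrative only and not part of the commit; the first callback is a stand-in for the app's own input-saving step, and generate is the function shown above:

# Illustrative sketch, not part of the commit: the simplified event chain
# wires generate directly after the input is saved; the .success() guard
# from check_input_token_length no longer exists.
import gradio as gr
from app import generate   # assumption: generate as defined in app.py above

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    saved_input = gr.State()
    max_new_tokens = gr.Slider(label='Max new tokens', minimum=1, maximum=2048, value=1024)
    temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, value=1.0)
    top_p = gr.Slider(label='Top-p', minimum=0.05, maximum=1.0, value=0.95)
    top_k = gr.Slider(label='Top-k', minimum=1, maximum=1000, value=5)

    textbox.submit(
        fn=lambda message: ('', message),   # stand-in for the app's own save step
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=generate,                        # streams updated history into the chatbot
        inputs=[saved_input, chatbot, max_new_tokens, temperature, top_p, top_k],
        outputs=chatbot,
        api_name=False,
    )
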
model.py CHANGED
@@ -24,41 +24,18 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code=True
 )
 
-def get_prompt(
-    message: str,
-    chat_history: list[tuple[str, str]],
-    system_prompt: str
-) -> str:
-    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-    # The first user input is _not_ stripped
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-    message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
-
-def get_input_token_length(
-    message: str,
-    chat_history: list[tuple[str, str]],
-    system_prompt: str
-) -> int:
-    prompt = get_prompt(message, chat_history, system_prompt)
-    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
-    return input_ids.shape[-1]
-
 def run(
     message: str,
     chat_history: list[tuple[str, str]],
-    system_prompt: str,
     max_new_tokens: int = 1024,
     temperature: float = 1.0,
     top_p: float = 0.95,
     top_k: int = 5
 ) -> Iterator[str]:
-    print(chat_history)
+    model.generation_config.max_new_tokens = max_new_tokens
+    model.generation_config.temperature = temperature
+    model.generation_config.top_p = top_p
+    model.generation_config.top_k = top_k
 
     history = []
     result=""
@@ -74,15 +51,8 @@ def run(
     for response in model.chat(
         tokenizer,
         history,
-        # stream=True,
-        # max_new_tokens=max_new_tokens,
-        # temperature=temperature,
-        # top_p=top_p,
-        # top_k=top_k,
+        stream=True,
     ):
-        print(response)
         result = result + response
         yield result
-        # if "content" in response["choices"][0]["delta"]:
-        #     result = result + response["choices"][0]["delta"]["content"]
-        #     yield result
+
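
Note: sampling parameters are now applied by mutating model.generation_config before streaming from model.chat. A minimal usage sketch of the reworked run, illustrative only; it assumes model.py is importable and the model loads as configured above:

# Illustrative sketch, not part of the commit: consume the streaming run().
from model import run   # assumption: run as defined in model.py above

for partial in run(
    message='Hello!',
    chat_history=[],        # no previous turns
    max_new_tokens=1024,    # defaults shown explicitly for clarity
    temperature=1.0,
    top_p=0.95,
    top_k=5,
):
    print(partial)          # each yield is the accumulated response so far
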