louisbrulenaudet committed on
Commit
9bb1d67
1 Parent(s): 6718013

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -88
app.py CHANGED
@@ -87,98 +87,50 @@ model, tokenizer, description = setup(
87
  description=DESCRIPTION
88
  )
89
 
90
-
91
- def preprocess_conversation(
92
- message: str,
93
- chat_history: list,
94
- system_prompt: str
95
- ):
96
  """
97
- Preprocess the conversation history by formatting it appropriately.
98
 
99
  Parameters
100
  ----------
101
  message : str
102
- The user's message.
103
-
104
- chat_history : list
105
- The conversation history, where each element is a tuple (user_message, assistant_response).
106
-
107
- system_prompt : str
108
- The system prompt.
109
 
110
  Returns
111
  -------
112
- list
113
- The formatted conversation history.
114
- """
115
- conversation = []
116
-
117
- if system_prompt:
118
- conversation.append(
119
- {
120
- "role": "system",
121
- "content": system_prompt
122
- }
123
- )
124
-
125
- for user, assistant in chat_history:
126
- conversation.extend(
127
- [
128
- {
129
- "role": "user",
130
- "content": user
131
- },
132
- {
133
- "role": "assistant",
134
- "content": assistant
135
- }
136
- ]
137
- )
138
-
139
- conversation.append(
140
- {
141
- "role": "user",
142
- "content": message
143
- }
144
- )
145
-
146
- return conversation
147
-
148
-
149
- def trim_input_ids(
150
- input_ids,
151
- max_length
152
- ):
153
  """
154
- Trim the input token IDs if they exceed the maximum length.
155
-
156
- Parameters
157
- ----------
158
- input_ids : torch.Tensor
159
- The input token IDs.
160
 
161
- max_length : int
162
- The maximum length allowed.
 
 
163
 
164
- Returns
165
- -------
166
- torch.Tensor
167
- The trimmed input token IDs.
168
- """
169
- if input_ids.shape[1] > max_length:
170
- input_ids = input_ids[:, -max_length:]
171
- print(f"Trimmed input from conversation as it was longer than {max_length} tokens.")
172
-
173
- return input_ids
174
 
175
 
176
  @spaces.GPU
177
  def generate(
178
  message: str,
179
  chat_history: list,
180
- system_prompt: str,
181
- max_new_tokens: int = 1024,
182
  temperature: float = 0.6,
183
  top_p: float = 0.9,
184
  top_k: int = 50,
@@ -196,9 +148,6 @@ def generate(
196
 
197
  chat_history : list
198
  A list containing tuples representing the conversation history. Each tuple should consist of two elements: the user's message and the assistant's response.
199
-
200
- system_prompt : str
201
- The system prompt, if any, to be included in the conversation context.
202
 
203
  max_new_tokens : int, optional
204
  The maximum number of tokens to generate for the response (default is 1024).
@@ -228,10 +177,9 @@ def generate(
228
  global tokenizer
229
  global model
230
 
231
- conversation = preprocess_conversation(
232
  message=message,
233
- chat_history=chat_history,
234
- system_prompt=system_prompt
235
  )
236
 
237
  input_ids = tokenizer.apply_chat_template(
@@ -239,10 +187,6 @@ def generate(
239
  return_tensors="pt",
240
  add_generation_prompt=True
241
  )
242
- input_ids = trim_input_ids(
243
- input_ids=input_ids,
244
- max_length=MAX_INPUT_TOKEN_LENGTH
245
- )
246
 
247
  input_ids = input_ids.to(
248
  torch.device("cuda")
@@ -279,10 +223,16 @@ def generate(
279
 
280
  return "".join(outputs)
281
 
 
 
 
 
 
 
282
  chat_interface = gr.ChatInterface(
283
  fn=generate,
 
284
  additional_inputs=[
285
- gr.Textbox(label="System prompt", lines=6),
286
  gr.Slider(
287
  label="Max new tokens",
288
  minimum=1,
@@ -314,9 +264,9 @@ chat_interface = gr.ChatInterface(
314
  ],
315
  fill_height=True,
316
  examples=[
317
- ["implement snake game using pygame"],
318
  ["Can you explain briefly to me what is the Python programming language?"],
319
- ["write a program to find the factorial of a number"],
320
  ],
321
  )
322
 
 
87
  description=DESCRIPTION
88
  )
89
 
90
def format_prompt(
    message: str,
    history: list
) -> str:
    """
    Format a prompt for dialogue generation using historical conversation data.

    The prompt opens with a single ``<s>`` marker. Each past exchange is
    rendered as ``[INST] user [/INST] bot</s> `` and the current message is
    appended as a final, unanswered ``[INST] ... [/INST]`` segment.

    Parameters
    ----------
    message : str
        The user's current message or prompt.

    history : list of tuple
        A list of tuples representing past interactions, where each tuple
        contains a user prompt and a corresponding bot response.

    Returns
    -------
    str
        Formatted prompt for dialogue generation, including the user's current
        message and historical conversation data.

    Examples
    --------
    >>> message = "How are you?"
    >>> history = [("Hi there!", "Hello!"), ("What's up?", "Not much.")]
    >>> format_prompt(message, history)
    "<s>[INST] Hi there! [/INST] Hello!</s> [INST] What's up? [/INST] Not much.</s> [INST] How are you? [/INST]"
    >>> format_prompt("Hello", [])
    '<s>[INST] Hello [/INST]'
    """
    # One segment per completed exchange; the trailing space after "</s>"
    # separates it from the next "[INST]" segment.
    past_turns = [
        f"[INST] {user_prompt} [/INST] {bot_response}</s> "
        for user_prompt, bot_response in history
    ]

    # "<s>" is emitted exactly once, at the very start of the prompt.
    return "<s>" + "".join(past_turns) + f"[INST] {message} [/INST]"
 
 
 
 
 
 
 
 
 
127
 
128
 
129
  @spaces.GPU
130
  def generate(
131
  message: str,
132
  chat_history: list,
133
+ max_new_tokens: int = 2048,
 
134
  temperature: float = 0.6,
135
  top_p: float = 0.9,
136
  top_k: int = 50,
 
148
 
149
  chat_history : list
150
  A list containing tuples representing the conversation history. Each tuple should consist of two elements: the user's message and the assistant's response.
 
 
 
151
 
152
  max_new_tokens : int, optional
153
 The maximum number of tokens to generate for the response (default is 2048).
 
177
  global tokenizer
178
  global model
179
 
180
# Pass this function's own `chat_history` parameter: the visible signature
# declares no `history` name, so `history=history` would raise NameError.
# NOTE(review): format_prompt returns a str; confirm that the following
# tokenizer.apply_chat_template call accepts a pre-formatted string rather
# than a list of role/content messages.
conversation = format_prompt(
    message=message,
    history=chat_history
)
184
 
185
  input_ids = tokenizer.apply_chat_template(
 
187
  return_tensors="pt",
188
  add_generation_prompt=True
189
  )
 
 
 
 
190
 
191
  input_ids = input_ids.to(
192
  torch.device("cuda")
 
223
 
224
  return "".join(outputs)
225
 
226
+
227
+ chatbot = gr.Chatbot(
228
+ height=400,
229
+ show_copy_button=True
230
+ )
231
+
232
  chat_interface = gr.ChatInterface(
233
  fn=generate,
234
+ chatbot=chatbot,
235
  additional_inputs=[
 
236
  gr.Slider(
237
  label="Max new tokens",
238
  minimum=1,
 
264
  ],
265
  fill_height=True,
266
  examples=[
267
+ ["Implement snake game using pygame"],
268
  ["Can you explain briefly to me what is the Python programming language?"],
269
+ ["Write a program to find the factorial of a number"],
270
  ],
271
  )
272