louisbrulenaudet committed
Commit b1dc71c
1 Parent(s): 6ba5195

Update app.py

Files changed (1)
  1. app.py +68 -28
app.py CHANGED
@@ -87,43 +87,77 @@ model, tokenizer, description = setup(
     description=DESCRIPTION
 )
 
-def format_prompt(
-    message,
-    history
-) -> str:
+def preprocess_conversation(
+    message: str,
+    history: list,
+):
     """
-    Format a prompt for dialogue generation using historical conversation data.
+    Preprocess the conversation history by formatting it appropriately.
 
     Parameters
     ----------
     message : str
-        The user's current message or prompt.
-
-    history : list of tuple
-        A list of tuples representing past interactions, where each tuple
-        contains a user prompt and a corresponding bot response.
+        The user's message.
+
+    history : list
+        The conversation history, where each element is a tuple (user_message, assistant_response).
 
     Returns
    -------
-    str
-        Formatted prompt for dialogue generation, including the user's current
-        message and historical conversation data.
-
-    Examples
-    --------
-    >>> message = "How are you?"
-    >>> history = [("Hi there!", "Hello!"), ("What's up?", "Not much.")]
-    >>> format_prompt(message, history)
-    '<s>[INST] Hi there! [/INST] Hello!</s> <s>[INST] What\'s up? [/INST] Not much.</s> <s>[INST] How are you? [/INST]'
+    list
+        The formatted conversation history.
     """
-    prompt = "<s>"
+    conversation = []
+
+    for user, assistant in history:
+        conversation.extend(
+            [
+                {
+                    "role": "user",
+                    "content": user
+                },
+                {
+                    "role": "assistant",
+                    "content": assistant
+                }
+            ]
+        )
+
+    conversation.append(
+        {
+            "role": "user",
+            "content": message
+        }
+    )
+
+    return conversation
+
+
+def trim_input_ids(
+    input_ids,
+    max_length
+):
+    """
+    Trim the input token IDs if they exceed the maximum length.
+
+    Parameters
+    ----------
+    input_ids : torch.Tensor
+        The input token IDs.
 
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
+    max_length : int
+        The maximum length allowed.
 
-    return prompt
+    Returns
+    -------
+    torch.Tensor
+        The trimmed input token IDs.
+    """
+    if input_ids.shape[1] > max_length:
+        input_ids = input_ids[:, -max_length:]
+        print(f"Trimmed input from conversation as it was longer than {max_length} tokens.")
+
+    return input_ids
 
 
 @spaces.GPU
@@ -177,9 +211,9 @@ def generate(
     global tokenizer
     global model
 
-    conversation = format_prompt(
+    conversation = preprocess_conversation(
         message=message,
-        history=history
+        history=history,
     )
 
     input_ids = tokenizer.apply_chat_template(
@@ -187,6 +221,11 @@ def generate(
         return_tensors="pt",
         add_generation_prompt=True
     )
+
+    input_ids = trim_input_ids(
+        input_ids=input_ids,
+        max_length=MAX_INPUT_TOKEN_LENGTH
+    )
 
     input_ids = input_ids.to(
         torch.device("cuda")
@@ -214,6 +253,7 @@ def generate(
         target=model.generate,
         kwargs=generate_kwargs
    )
+
     t.start()
 
     outputs = []
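
For reference, a minimal usage sketch (not part of this commit) of how the two new helpers behave. It assumes preprocess_conversation and trim_input_ids from the updated app.py are in scope, and uses a placeholder value for MAX_INPUT_TOKEN_LENGTH; the import path and the constant value are assumptions, not taken from the repository.

# Minimal sketch, not part of the commit. Assumes the new helpers from app.py
# are importable and that MAX_INPUT_TOKEN_LENGTH is defined there; the value
# below is a placeholder.
import torch

from app import preprocess_conversation, trim_input_ids  # hypothetical import path

MAX_INPUT_TOKEN_LENGTH = 4096  # placeholder; the real constant lives in app.py

history = [("Hi there!", "Hello!"), ("What's up?", "Not much.")]
message = "How are you?"

conversation = preprocess_conversation(message=message, history=history)
# Produces the role/content list that tokenizer.apply_chat_template expects:
# [{'role': 'user', 'content': 'Hi there!'},
#  {'role': 'assistant', 'content': 'Hello!'},
#  {'role': 'user', 'content': "What's up?"},
#  {'role': 'assistant', 'content': 'Not much.'},
#  {'role': 'user', 'content': 'How are you?'}]

# trim_input_ids keeps only the most recent max_length tokens.
input_ids = torch.randint(0, 32000, (1, 5000))  # stand-in for tokenized input
input_ids = trim_input_ids(input_ids=input_ids, max_length=MAX_INPUT_TOKEN_LENGTH)
assert input_ids.shape[1] == MAX_INPUT_TOKEN_LENGTH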