Update handler.py
Browse files- handler.py +10 -11
handler.py
CHANGED
@@ -16,22 +16,21 @@ class EndpointHandler():
|
|
16 |
messages = request_inputs["messages"]
|
17 |
char_name = request_inputs["char_name"]
|
18 |
user_name = request_inputs["user_name"]
|
|
|
19 |
template = self.default_template
|
20 |
-
user_input =
|
21 |
"{name}: {message}".format(
|
22 |
name = char_name if (id["role"] == "AI") else user_name,
|
23 |
message = id["message"].strip()
|
24 |
) for id in messages
|
25 |
-
]
|
26 |
-
|
27 |
-
char_name = char_name,
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
return_tensors = "pt"
|
34 |
-
).to("cuda")
|
35 |
encoded_output = self.model.generate(
|
36 |
input_ids["input_ids"],
|
37 |
max_new_tokens = 50,
|
|
|
16 |
messages = request_inputs["messages"]
|
17 |
char_name = request_inputs["char_name"]
|
18 |
user_name = request_inputs["user_name"]
|
19 |
+
chats_curled = request_inputs["chats_curled"]
|
20 |
template = self.default_template
|
21 |
+
user_input = [
|
22 |
"{name}: {message}".format(
|
23 |
name = char_name if (id["role"] == "AI") else user_name,
|
24 |
message = id["message"].strip()
|
25 |
) for id in messages
|
26 |
+
]
|
27 |
+
while True:
|
28 |
+
prompt = template.format(char_name = char_name, user_name = user_name, user_input = "\n".join([user_input]))
|
29 |
+
input_ids = self.tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
|
30 |
+
if input_ids.input_ids.size(1) > 2048:
|
31 |
+
chats_curled += 1
|
32 |
+
user_input = user_input[chats_curled*2:]
|
33 |
+
else: break
|
|
|
|
|
34 |
encoded_output = self.model.generate(
|
35 |
input_ids["input_ids"],
|
36 |
max_new_tokens = 50,
|