BlueDice commited on
Commit
f7ed38a
·
1 Parent(s): c3eb599

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -11
handler.py CHANGED
@@ -16,22 +16,21 @@ class EndpointHandler():
16
  messages = request_inputs["messages"]
17
  char_name = request_inputs["char_name"]
18
  user_name = request_inputs["user_name"]
 
19
  template = self.default_template
20
- user_input = "\n".join([
21
  "{name}: {message}".format(
22
  name = char_name if (id["role"] == "AI") else user_name,
23
  message = id["message"].strip()
24
  ) for id in messages
25
- ])
26
- prompt = template.format(
27
- char_name = char_name,
28
- user_name = user_name,
29
- user_input = user_input
30
- )
31
- input_ids = self.tokenizer(
32
- prompt + f"\n{char_name}:",
33
- return_tensors = "pt"
34
- ).to("cuda")
35
  encoded_output = self.model.generate(
36
  input_ids["input_ids"],
37
  max_new_tokens = 50,
 
16
  messages = request_inputs["messages"]
17
  char_name = request_inputs["char_name"]
18
  user_name = request_inputs["user_name"]
19
+ chats_curled = request_inputs["chats_curled"]
20
  template = self.default_template
21
+ user_input = [
22
  "{name}: {message}".format(
23
  name = char_name if (id["role"] == "AI") else user_name,
24
  message = id["message"].strip()
25
  ) for id in messages
26
+ ]
27
+ while True:
28
+ prompt = template.format(char_name = char_name, user_name = user_name, user_input = "\n".join([user_input]))
29
+ input_ids = self.tokenizer(prompt + f"\n{char_name}:", return_tensors = "pt").to("cuda")
30
+ if input_ids.input_ids.size(1) > 2048:
31
+ chats_curled += 1
32
+ user_input = user_input[chats_curled*2:]
33
+ else: break
 
 
34
  encoded_output = self.model.generate(
35
  input_ids["input_ids"],
36
  max_new_tokens = 50,