r3aperdev committed on
Commit 2382cac · 1 Parent(s): d03cba8

koboldai_client.py

Files changed (1)
  1. koboldai_client.py +117 -0
koboldai_client.py ADDED
@@ -0,0 +1,117 @@
import datetime
import logging
import time

import requests

logger = logging.getLogger(__name__)


class KoboldApiServerException(Exception):
    pass


def wait_for_kai_server(koboldai_url: str, max_wait_time_seconds: int) -> None:
    '''Blocks until the KAI server is up.'''
    start_time = datetime.datetime.now()

    while True:
        try:
            requests.head(koboldai_url, timeout=(5, 5))
            break
        except requests.exceptions.ConnectionError as ex:
            if "Connection refused" not in str(ex):
                raise ex

        abort_at = start_time + datetime.timedelta(
            seconds=max_wait_time_seconds)

        if datetime.datetime.now() > abort_at:
            raise TimeoutError(
                f"Waited for {max_wait_time_seconds} seconds but KoboldAI"
                " server is still not up, aborting.")

        time.sleep(1)


def run_raw_inference_on_kai(
    koboldai_url: str,
    prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    typical_p: float,
    repetition_penalty: float,
    **kwargs,
) -> str:
    endpoint = f"{koboldai_url}/api/v1/generate"
    payload = {
        "prompt": prompt,

        # Incredibly low max length for reasons explained in the generation
        # loop below.
        "max_length": 32,

        # Take care of parameters which are named differently between Kobold
        # and HuggingFace.
        "sampler_full_determinism": not do_sample,
        "typical": typical_p,
        "rep_pen": repetition_penalty,

        # Disable any pre- or post-processing on the KoboldAI side; we'd
        # rather take care of things on our own.
        "frmttriminc": False,
        "frmtrmspch": False,
        "frmtrmblln": False,
        "frmtadsnsp": False,

        # Append any other generation parameters that we didn't handle
        # manually.
        **kwargs,
    }
    generated_text = ""

    # Currently, Kobold doesn't support custom stopping criteria, and their
    # chat mode can't handle multi-line responses. To work around both of
    # those, we use the regular adventure mode generation but keep asking for
    # more tokens until the model starts trying to talk as the user, then we
    # stop.
    attempts = 0
    max_extra_attempts = 4
    while attempts < (max_new_tokens /
                      payload["max_length"]) + max_extra_attempts:
        attempts += 1
        response = requests.post(endpoint, json=payload)
        if not response.ok:
            error_message = response.text
            raise KoboldApiServerException(
                "The KoboldAI API server returned an error"
                f" (HTTP status code {response.status_code}): {error_message}")

        inference_result = response.json()["results"][0]["text"]
        generated_text += inference_result

        # Model started to talk as us. Stop generating and return results;
        # the rest of the code will take care of trimming it properly.
        if "\nYou:" in generated_text:
            logger.debug("Hit `\\nYou:`: `%s`", generated_text)
            return generated_text

        # For SFT: hit an EOS token. Trim and return.
        if generated_text.endswith("<|endoftext|>"):
            logger.debug("Got EOS token: `%s`", generated_text)

            # We add a fake generated "\nYou:" here so the trimming code
            # doesn't need to handle SFT and UFT models differently.
            return generated_text.replace("<|endoftext|>", "\nYou:")

        # Hit the configured generation limit.
        if len(generated_text.split()) >= max_new_tokens:
            logger.debug("Hit max length: `%s`", generated_text)
            return generated_text

        # Model still hasn't finished what it had to say. Append its output
        # to the prompt and feed it back in.
        logger.debug("Got another %s tokens, but still not done: `%s`",
                     payload["max_length"], generated_text)
        payload["prompt"] += inference_result

    logger.debug("Exhausted generation attempts: `%s`", generated_text)
    return generated_text
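
For context, a minimal usage sketch of the two functions in this file follows. It is not part of the commit: the server URL, the "You:"/"Bot:" prompt layout, the sampling values, and the downstream trimming step are all assumptions made for illustration.

# Hypothetical usage sketch; names and values below are assumptions.
import logging

from koboldai_client import run_raw_inference_on_kai, wait_for_kai_server

logging.basicConfig(level=logging.DEBUG)

KOBOLDAI_URL = "http://localhost:5000"  # assumed local KoboldAI instance

# Block until the API answers, waiting at most two minutes.
wait_for_kai_server(KOBOLDAI_URL, max_wait_time_seconds=120)

# The client stops once the model emits "\nYou:", so the prompt is assumed
# to follow a "You:" / "Bot:" chat-transcript layout.
prompt = "You: Hello there!\nBot:"

raw_reply = run_raw_inference_on_kai(
    KOBOLDAI_URL,
    prompt,
    max_new_tokens=128,
    do_sample=True,
    typical_p=0.9,
    repetition_penalty=1.1,
    temperature=0.7,  # forwarded to the API untouched via **kwargs
)

# The function may return text that runs past the end of the bot's turn;
# trimming at the first "\nYou:" marker is assumed to happen downstream.
reply = raw_reply.split("\nYou:")[0].strip()
print(reply)

The sketch mirrors how the file is written: generation parameters the client does not map explicitly (such as temperature above) pass straight through to the KoboldAI payload, and the caller remains responsible for cutting the reply at the first "\nYou:" marker.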