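'''Helpers for running inference against a KoboldAI API server.'''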

import datetime
import logging
import time

import requests

logger = logging.getLogger(__name__)


class KoboldApiServerException(Exception):
    pass


def wait_for_kai_server(koboldai_url: str, max_wait_time_seconds: int) -> None:
    '''Blocks until the KAI server is up, or raises TimeoutError.'''
    start_time = datetime.datetime.now()

    while True:
        try:
            requests.head(koboldai_url, timeout=(5, 5))
            break
        except requests.exceptions.ConnectionError as ex:
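            # "Connection refused" just means the server isn't listening
            # yet; anything else is unexpected, so re-raise it.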
            if "Connection refused" not in str(ex):
                raise

        abort_at = start_time + datetime.timedelta(
            seconds=max_wait_time_seconds)

        if datetime.datetime.now() > abort_at:
            raise TimeoutError(
                f"Waited for {max_wait_time_seconds} seconds but the KoboldAI"
                " server is still not up, aborting.")

        time.sleep(1)


def run_raw_inference_on_kai(
    koboldai_url: str,
    prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    typical_p: float,
    repetition_penalty: float,
    **kwargs,
) -> str:
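    '''Runs raw inference against the KAI server and returns the result.

    Text is generated in 32-token chunks which are chained together until
    a stop condition is hit or roughly `max_new_tokens` tokens
    (approximated by word count) have been produced. Extra keyword
    arguments are merged into the request payload as-is.
    '''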
    endpoint = f"{koboldai_url}/api/v1/generate"
    payload = {
        "prompt": prompt,
        "max_length": 32,
        "sampler_full_determinism": not do_sample,
        "typical": typical_p,
        "rep_pen": repetition_penalty,
        "frmttriminc": False,
        "frmtrmspch": False,
        "frmtrmblln": False,
        "frmtadsnsp": False,

        **kwargs,
    }
    generated_text = ""
    attempts = 0
    max_extra_attempts = 4
    while attempts < (max_new_tokens /
                      payload["max_length"]) + max_extra_attempts:
        attempts += 1
        response = requests.post(endpoint, json=payload)
        if not response.ok:
            error_message = response.text
            raise KoboldApiServerException(
                "The KoboldAI API server returned an error"
                f" (HTTP status code {response.status_code}): {error_message}")

        inference_result = response.json()["results"][0]["text"]
        generated_text += inference_result
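
        # First stop condition: the model has started generating the
        # user's side of the conversation (`\nYou:`).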
        if "\nYou:" in generated_text:
            logger.debug("Hit `\\nYou:`: `%s`", generated_text)
            return generated_text
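
        # Second stop condition: the model emitted its end-of-sequence
        # token.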
        if generated_text.endswith("<|endoftext|>"):
            logger.debug("Got EOS token: `%s`", generated_text)
            return generated_text.replace("<|endoftext|>", "\nYou:")
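
        # Third stop condition: enough text was generated. Word count is
        # used as a cheap approximation of the token count here.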
        if len(generated_text.split()) >= max_new_tokens:
            logger.debug("Hit max length: `%s`", generated_text)
            return generated_text
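
        # Not done yet: feed the new chunk back into the prompt and ask
        # the server for more.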
        logger.debug("Got another %s tokens, but still not done: `%s`",
                     payload["max_length"], generated_text)
        payload["prompt"] += inference_result

    logger.debug("Exhausted generation attempts: `%s`", generated_text)
    return generated_text