# petalschatlvn / websocket_api.py
# Last commit 2b58075 by lavanjv: "Upload 7 files"
import json
from traceback import format_exc
import flask_sock
import hivemind
import torch
import config
from app import sock, models
from utils import safe_decode
# Module logger, named after this file's path (__file__).
logger = hivemind.get_logger(__file__)
def _receive_request(ws, expected_type):
    """Receive one JSON message from ``ws`` and check that its ``"type"`` is ``expected_type``.

    Args:
        ws: the websocket connection to read from (waits up to ``config.STEP_TIMEOUT``).
        expected_type: the value the message's ``"type"`` key must have.

    Returns:
        The decoded message (a dict).

    Raises:
        TimeoutError: if ``ws.receive`` returns None. Presumably this is how the websocket
            reports that nothing arrived within ``config.STEP_TIMEOUT`` -- confirm against
            the installed flask_sock / simple-websocket version.
        ValueError: if the message's ``"type"`` is not ``expected_type``.
    """
    message = ws.receive(timeout=config.STEP_TIMEOUT)
    if message is None:
        raise TimeoutError(f"No message received from the client (expected type {expected_type!r})")
    request = json.loads(message)
    # Explicit check rather than `assert`, which is stripped when Python runs with -O.
    if request.get("type") != expected_type:
        raise ValueError(f"Expected a message of type {expected_type!r}, got {request.get('type')!r}")
    return request


def _tokenize(tokenizer, text):
    """Return the ``input_ids`` tensor for ``text``, moved to ``config.DEVICE``."""
    return tokenizer(text, return_tensors="pt")["input_ids"].to(config.DEVICE)


def _serve_generate_request(ws, model, tokenizer, session, request):
    """Answer one ``"generate"`` request by streaming decoded text chunks to ``ws``.

    Returns after sending a chunk whose ``"stop"`` flag is true.

    Raises:
        ValueError: if ``extra_stop_sequences`` is given without a ``stop_sequence``, or if
            ``stop_sequence`` does not tokenize to exactly one token.
    """
    inputs = request.get("inputs")
    logger.info("ws.generate.step(), inputs=%r", inputs)
    if inputs is not None:
        inputs = _tokenize(tokenizer, inputs)
        n_input_tokens = inputs.shape[1]
    else:
        n_input_tokens = 0

    stop_sequence = request.get("stop_sequence")
    extra_stop_sequences = request.get("extra_stop_sequences")
    cont_token = None
    if extra_stop_sequences is not None:
        if stop_sequence is None:
            raise ValueError("extra_stop_sequences require stop_sequence to be set")
        cont_token = _tokenize(tokenizer, stop_sequence)
        if cont_token.shape != (1, 1):
            raise ValueError("extra_stop_sequences require stop_sequence length to be exactly 1 token")
        # A tuple lets str.endswith() test every extra stop sequence in a single call.
        extra_stop_sequences = tuple(extra_stop_sequences)

    all_outputs = ""
    pending_ids = []  # ids of generated tokens whose text has not been sent to the client yet
    stop = False
    while not stop:
        outputs = model.generate(
            inputs=inputs,
            do_sample=request.get("do_sample", False),
            temperature=request.get("temperature", 1.0),
            top_k=request.get("top_k"),
            top_p=request.get("top_p"),
            max_length=request.get("max_length"),
            max_new_tokens=request.get("max_new_tokens"),
            session=session,
        )
        delta = outputs[0, n_input_tokens:].tolist()
        # Token ids are integers; torch.Tensor(...) would have stored them as float32.
        text = safe_decode(tokenizer, torch.tensor(pending_ids + delta, dtype=torch.long))
        inputs = None  # Inputs are passed only for the 1st token of the bot's response
        n_input_tokens = 0

        combined = all_outputs + text
        # Without a stop_sequence, every request yields exactly one chunk.
        stop = stop_sequence is None or combined.endswith(stop_sequence)
        if extra_stop_sequences is not None and combined.endswith(extra_stop_sequences):
            stop = True
            # NOTE(review): presumably makes the session continue from stop_sequence's single
            # token instead of the extra stop sequence just generated -- confirm against the
            # inference-session API.
            session.last_token_id = cont_token

        if not stop and "\ufffd" in text[-10:]:
            # If there's a replacement character, keep getting more tokens
            # until we can decode properly
            pending_ids += delta
            logger.info("ws.generate.append_retry(), all_outputs=%r", combined)
        else:
            all_outputs = combined
            pending_ids = []
            logger.info("ws.generate.step(), all_outputs=%r, stop=%s", all_outputs, stop)
            ws.send(json.dumps({"ok": True, "outputs": text, "stop": stop}))


@sock.route("/api/v2/generate")
def ws_api_generate(ws):
    """Serve streaming text generation over a websocket.

    Protocol, as implemented here:

    1. The client sends ``{"type": "open_inference_session", "max_length": ..., "model": ...}``.
       ``"model"`` is optional and defaults to ``config.DEFAULT_MODEL_NAME``. The server opens
       an inference session with that ``max_length`` and replies ``{"ok": true}``.
    2. The client then sends any number of ``{"type": "generate", ...}`` messages. Each may
       carry ``inputs`` (text to feed the model first), ``stop_sequence``,
       ``extra_stop_sequences`` and the generation options forwarded to ``model.generate``.
       The server streams ``{"ok": true, "outputs": <text>, "stop": <bool>}`` replies until
       one has ``"stop": true``.

    If the client closes the connection, the handler returns quietly. Any other error is
    logged and reported as ``{"ok": false, "traceback": <str>}``.
    NOTE(review): the full server traceback is exposed to the client.
    """
    try:
        request = _receive_request(ws, "open_inference_session")
        model_name = request.get("model")
        if model_name is None:
            model_name = config.DEFAULT_MODEL_NAME
        max_length = request["max_length"]
        logger.info("ws.generate.open(), model=%r, max_length=%r", model_name, max_length)

        model, tokenizer = models[model_name]
        with model.inference_session(max_length=max_length) as session:
            ws.send(json.dumps({"ok": True}))
            while True:
                request = _receive_request(ws, "generate")
                _serve_generate_request(ws, model, tokenizer, session, request)
    except flask_sock.ConnectionClosed:
        pass  # The client disconnected; there is nobody left to report to.
    except Exception:
        logger.warning("ws.generate failed:", exc_info=True)
        try:
            ws.send(json.dumps({"ok": False, "traceback": format_exc()}))
        except flask_sock.ConnectionClosed:
            pass  # The connection dropped before the error could be reported.
    finally:
        logger.info("ws.generate.close()")