# (Scraped hosting-page banner, not part of the program: "Spaces: Build error / Build error")
import base64 | |
import json | |
import os | |
import time | |
import requests | |
import yaml | |
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer | |
from threading import Thread | |
import numpy as np | |
from modules import shared | |
from modules.text_generation import encode, generate_reply | |
# Extension settings. The port can be overridden with the OPENEDAI_PORT
# environment variable.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}

# Verbose request/response printing is enabled by merely defining OPENEDAI_DEBUG.
debug = 'OPENEDAI_DEBUG' in os.environ

# Optional, install the module and download the model to enable
# v1/embeddings
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    # Keep the name bound so that later references never raise NameError.
    SentenceTransformer = None

# Name of the sentence-transformers model, overridable through
# OPENEDAI_EMBEDDING_MODEL.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")

# Set by run_server() once the embedding model has been loaded.
embedding_model = None

# Stop strings that are always appended to the per-request stop strings.
standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
# little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default):
    """Return ``dic[key]`` coerced to the type of *default*, else *default*.

    Args:
        dic: Mapping to read from (only ``.get`` is used).
        key: Key to look up.
        default: Fallback value; its exact type is also the required type
            of the result.

    Returns:
        The stored value when it already has exactly the type of *default*;
        the converted value when converting it to that type and back yields
        the original value (e.g. ``1`` -> ``1.0``); otherwise *default*.
    """
    val = dic.get(key, default)
    # Exact type comparison on purpose: a bool must not pass as an int.
    if type(val) is type(default):
        return val

    # maybe it's just something like 1 instead of 1.0
    try:
        v = type(default)(val)
        if type(val)(v) == val:  # if it's the same value passed in, it's ok.
            return v
    except (TypeError, ValueError, OverflowError):
        # Not convertible (None, text that is not a number, inf, ...).
        pass
    return default
def clamp(value, minvalue, maxvalue):
    """Return *value* limited to the inclusive range [minvalue, maxvalue]."""
    return max(minvalue, min(value, maxvalue))
def deduce_template():
    """Build the prompt template used for edit requests.

    Returns:
        A format string with ``{instruction}`` and ``{input}`` placeholders.
        For the 'Alpaca' / 'Alpaca-Input' instruction templates, or when the
        configured template file cannot be loaded, the built-in Alpaca style
        template is returned. Otherwise the template is derived from the
        ``turn_template`` entry of the configured YAML file, cut off right
        before the bot message placeholder.
    """
    # Alpaca is verbose so a good default prompt
    default_template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    )

    template_name = shared.settings['instruction_template']

    # Use the special instruction/input/response template for anything trained like Alpaca
    if template_name in ['Alpaca', 'Alpaca-Input']:
        return default_template

    try:
        # The context manager closes the file even when parsing fails.
        with open(f"characters/instruction-following/{template_name}.yaml", 'r', encoding='utf-8') as f:
            instruct = yaml.safe_load(f)

        template = instruct['turn_template']
        template = template\
            .replace('<|user|>', instruct.get('user', ''))\
            .replace('<|bot|>', instruct.get('bot', ''))\
            .replace('<|user-message|>', '{instruction}\n{input}')

        # Keep everything before the bot message placeholder. partition()
        # leaves the template intact when the placeholder is absent, whereas
        # slicing with find() == -1 would silently drop the last character.
        prompt_part = template.partition('<|bot-message|>')[0]
        return instruct.get('context', '') + prompt_part.rstrip(' ')
    except Exception:
        # Best effort: a missing or malformed template file falls back to Alpaca.
        return default_template
def float_list_to_base64(float_list):
    """Encode a list of floats as base64 text of their raw float32 bytes.

    Args:
        float_list: Sequence of numbers (one embedding vector).

    Returns:
        ASCII string holding the base64 encoding of the values stored as a
        contiguous float32 array.
    """
    # Convert the list to a float32 array that the OpenAI client expects
    float_array = np.array(float_list, dtype="float32")

    # Get raw bytes
    bytes_array = float_array.tobytes()

    # Encode bytes into base64
    encoded_bytes = base64.b64encode(bytes_array)

    # Turn raw base64 encoded bytes into ASCII
    return encoded_bytes.decode('ascii')
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler emulating a subset of the OpenAI REST API.

    GET serves the model listing. POST dispatches on the request path to
    completions (plain, chat and the legacy ``/generate`` form), edits,
    image generation, embeddings, moderations and a token-count helper.
    """

    @staticmethod
    def _model_entry(model_id):
        """Return the model description object used in /v1/models replies."""
        return {
            "id": model_id,
            "object": "model",
            "owned_by": "user",
            "permission": []
        }

    def _send_json_headers(self):
        """Send a 200 status line followed by a JSON content type header."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

    def _write_json(self, payload):
        """Serialize *payload* to JSON and write it to the response body."""
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    def _write_event(self, payload):
        """Write *payload* as one server-sent event.

        Each event is terminated by a blank line, as the event stream format
        requires; clients that read line by line just skip the empty line.
        """
        self.wfile.write(('data: ' + json.dumps(payload) + '\n\n').encode('utf-8'))

    def do_GET(self):
        """Serve ``/v1/models`` (list) and ``/v1/models/<id>`` (single model)."""
        if not self.path.startswith('/v1/models'):
            self.send_error(404)
            return

        self._send_json_headers()

        if self.path == '/v1/models':
            # TODO: list all models and allow model changes via API? Lora's?
            # This API should list capabilities, limits and pricing...
            model_ids = [
                shared.model_name,  # The real chat/completions model
                st_model,  # The real sentence transformer embeddings model
                # these are expected by so much, so include some here as a dummy
                "gpt-3.5-turbo",  # /v1/chat/completions
                "text-curie-001",  # /v1/completions, 2k context
                "text-davinci-002",  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
            ]
            self._write_json({
                "object": "list",
                "data": [self._model_entry(model_id) for model_id in model_ids],
            })
        else:
            # Any requested model id is echoed back as if it existed.
            the_model_name = self.path[len('/v1/models/'):]
            self._write_json(self._model_entry(the_model_name))

    def do_POST(self):
        """Parse the JSON request body and dispatch on the request path.

        Replies with 400 when the body is not a valid JSON object and with
        404 when no endpoint matches the path.
        """
        if debug:
            print(self.headers)  # did you know... python-openai sends your linux kernel & python version?

        try:
            content_length = int(self.headers.get('Content-Length') or 0)
            raw_body = self.rfile.read(content_length) if content_length > 0 else b''
            body = json.loads(raw_body.decode('utf-8')) if raw_body else {}
        except ValueError:
            # Covers a non-numeric Content-Length, invalid UTF-8 and invalid JSON.
            self.send_error(400, 'Invalid JSON request body')
            return
        if not isinstance(body, dict):
            self.send_error(400, 'Request body must be a JSON object')
            return

        if debug:
            print(body)

        path = self.path
        if '/completions' in path or '/generate' in path:
            self._handle_completions(body)
        elif '/edits' in path:
            self._handle_edits(body)
        elif '/images/generations' in path and 'SD_WEBUI_URL' in os.environ:
            self._handle_image_generations(body)
        elif '/embeddings' in path and embedding_model is not None:
            self._handle_embeddings(body)
        elif '/moderations' in path:
            self._handle_moderations()
        elif path == '/api/v1/token-count':
            self._handle_token_count(body)
        else:
            print(self.path, self.headers)
            self.send_error(404)

    def _handle_completions(self, body):
        """Handle completions, chat completions and the legacy generate call.

        Builds the prompt and generation parameters from *body*, runs
        ``generate_reply`` and writes either a single JSON response or, when
        ``stream`` is requested, a sequence of server-sent events.
        """
        is_legacy = '/generate' in self.path
        is_chat = 'chat' in self.path
        resp_list = 'data' if is_legacy else 'choices'

        # XXX model is ignored for now
        # model = body.get('model', shared.model_name) # ignored, use existing for now
        model = shared.model_name
        created_time = int(time.time())
        cmpl_id = "conv-%d" % (created_time)

        # Try to use openai defaults or map them to something with the same intent
        stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
        stop = body.get('stop')
        if isinstance(stop, str):
            stopping_strings = [stop]
        elif isinstance(stop, list):
            stopping_strings = stop

        # pass with some expected stop strings.
        # some strange cases of "##| Instruction: " sneaking through.
        # A new list is built on purpose: extending in place would modify the
        # list stored in shared.settings (or the request body) on every request.
        stopping_strings = list(stopping_strings) + standard_stopping_strings

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

        default_max_tokens = truncation_length if is_chat else 16  # completions default, chat default is 'inf' so we need to cap it.

        max_tokens_str = 'length' if is_legacy else 'max_tokens'
        max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))

        # hard scale this, assuming the given max is for GPT3/4, perhaps inspect the requested model and lookup the context max
        while truncation_length <= max_tokens:
            max_tokens = max_tokens // 2

        req_params = {
            'max_new_tokens': max_tokens,
            'temperature': default(body, 'temperature', 1.0),
            'top_p': default(body, 'top_p', 1.0),
            'top_k': default(body, 'best_of', 1),
            # XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
            # 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
            'repetition_penalty': 1.18,  # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
            # XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
            'encoder_repetition_penalty': 1.0,  # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
            'suffix': body.get('suffix', None),
            'stream': default(body, 'stream', False),
            'echo': default(body, 'echo', False),
            #####################################################
            'seed': shared.settings.get('seed', -1),
            # int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
            # unofficial, but it needs to get set anyways.
            'truncation_length': truncation_length,
            # no more args.
            'add_bos_token': shared.settings.get('add_bos_token', True),
            'do_sample': True,
            'typical_p': 1.0,
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0.0,
            'length_penalty': 1,
            'early_stopping': False,
            'ban_eos_token': False,
            'skip_special_tokens': True,
            'custom_stopping_strings': stopping_strings,
        }

        # fixup absolute 0.0's
        for par in ['temperature', 'repetition_penalty', 'encoder_repetition_penalty']:
            req_params[par] = clamp(req_params[par], 0.001, 1.999)

        self.send_response(200)
        if req_params['stream']:
            self.send_header('Content-Type', 'text/event-stream')
            self.send_header('Cache-Control', 'no-cache')
            # self.send_header('Connection', 'keep-alive')
        else:
            self.send_header('Content-Type', 'application/json')
        self.end_headers()

        token_count = 0
        completion_token_count = 0
        prompt = ''

        if is_chat:
            stream_object_type = 'chat.completions.chunk'
            object_type = 'chat.completions'

            messages = body['messages']

            system_msg = ''  # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
            if 'prompt' in body:  # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
                system_msg = body['prompt']

            chat_msgs = []
            for m in messages:
                role = m['role']
                content = m['content']
                # name = m.get('name', 'user')
                if role == 'system':
                    system_msg += content
                else:
                    chat_msgs.append(f"\n{role}: {content.strip()}")  # Strip content? linefeed?

            system_token_count = len(encode(system_msg)[0])
            remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count

            # Keep as many of the most recent messages as fit in the budget.
            chat_msg = ''
            dropped = 0
            while chat_msgs:
                new_msg = chat_msgs.pop()
                new_size = len(encode(new_msg)[0])
                if new_size > remaining_tokens:
                    # TODO: clip a message to fit?
                    # ie. user: ...<clipped message>
                    # The message just popped is dropped too, so count it.
                    dropped = len(chat_msgs) + 1
                    break
                chat_msg = new_msg + chat_msg
                remaining_tokens -= new_size

            if dropped:
                print(f"truncating chat messages, dropping {dropped} messages.")

            if system_msg:
                prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant: '
            else:
                prompt = chat_msg + '\nassistant: '

            token_count = len(encode(prompt)[0])
        else:
            stream_object_type = 'text_completion.chunk'
            object_type = 'text_completion'

            # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
            if is_legacy:
                prompt = body['context']  # Older engines.generate API
            else:
                prompt = body['prompt']  # XXX this can be different types

            if isinstance(prompt, list):
                prompt = ''.join(prompt)  # XXX this is wrong... need to split out to multiple calls?

            token_count = len(encode(prompt)[0])
            if token_count >= req_params['truncation_length']:
                # Scale the character length by the share of tokens that fits,
                # using the effective (possibly request-limited) truncation length.
                new_len = int(len(prompt) * (req_params['truncation_length'] - req_params['max_new_tokens']) / token_count)
                # prompt[-0:] would keep the whole prompt, so handle 0 explicitly.
                prompt = prompt[-new_len:] if new_len > 0 else ''
                print(f"truncating prompt to {new_len} characters, was {token_count} tokens. Now: {len(encode(prompt)[0])} tokens.")

        def stream_chunk(finish_reason=None):
            """Return the envelope shared by all streamed chunks."""
            return {
                "id": cmpl_id,
                "object": stream_object_type,
                "created": created_time,
                "model": model,  # TODO: add Lora info?
                resp_list: [{
                    "index": 0,
                    "finish_reason": finish_reason,
                }],
            }

        def send_content(new_content):
            """Stream one chunk carrying newly generated text."""
            chunk = stream_chunk()
            if is_chat:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]['message'] = {'content': new_content}
                chunk[resp_list][0]['delta'] = {'content': new_content}
            else:
                chunk[resp_list][0]['text'] = new_content
            self._write_event(chunk)

        if req_params['stream']:
            # NOTE(review): switches the global chat flag on while streaming;
            # the reason is not visible here - confirm it is still required.
            shared.args.chat = True
            # begin streaming
            chunk = stream_chunk()
            if is_chat:
                # This is coming back as "system" to the openapi cli, not sure why.
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
                # { "role": "assistant" }
            else:
                chunk[resp_list][0]["text"] = ""
            self._write_event(chunk)

        # generate reply #######################################
        if debug:
            print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

        answer = ''
        seen_content = ''
        longest_stop_len = max(len(x) for x in stopping_strings)

        for a in generator:
            answer = a

            stop_string_found = False
            len_seen = len(seen_content)
            # Only the unseen tail (plus room for a stop string that started
            # earlier) needs to be searched.
            search_start = max(len_seen - longest_stop_len, 0)

            for string in stopping_strings:
                idx = answer.find(string, search_start)
                if idx != -1:
                    answer = answer[:idx]  # clip it.
                    stop_string_found = True

            if stop_string_found:
                break

            # If something like "\nYo" is generated just before "\nYou:"
            # is completed, buffer and generate more, don't send it
            buffer_and_continue = any(
                answer.endswith(string[:j])
                for string in stopping_strings
                for j in range(1, len(string))
            )
            if buffer_and_continue:
                continue

            if req_params['stream']:
                # Streaming
                new_content = answer[len_seen:]

                if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                    continue

                seen_content = answer
                send_content(new_content)
                completion_token_count += len(encode(new_content)[0])

        if req_params['stream']:
            # Flush text that was held back (buffered, or clipped right before
            # a stop string) so the stream carries the same text as `answer`.
            tail = answer[len(seen_content):]
            if tail:
                send_content(tail)
                completion_token_count += len(encode(tail)[0])

            chunk = stream_chunk(finish_reason="stop")
            chunk["usage"] = {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
            if is_chat:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]['message'] = {'content': ''}
                chunk[resp_list][0]['delta'] = {}
            else:
                chunk[resp_list][0]['text'] = ''
            self._write_event(chunk)
            self.wfile.write('data: [DONE]\n\n'.encode('utf-8'))
            # Finished if streaming.
            if debug:
                print({'response': answer})
            return

        if debug:
            print({'response': answer})

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= req_params['truncation_length']:
            stop_reason = "length"

        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": model,  # TODO: add Lora info?
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if is_chat:
            resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
        else:
            resp[resp_list][0]["text"] = answer

        self._write_json(resp)

    def _handle_edits(self, body):
        """Handle an edit request by filling the instruction template."""
        self._send_json_headers()

        created_time = int(time.time())

        # Using Alpaca format, this may work with other models too.
        instruction = body['instruction']
        input_text = body.get('input', '')

        instruction_template = deduce_template()
        edit_task = instruction_template.format(instruction=instruction, input=input_text)

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        token_count = len(encode(edit_task)[0])
        max_tokens = truncation_length - token_count

        req_params = {
            'max_new_tokens': max_tokens,
            'temperature': clamp(default(body, 'temperature', 1.0), 0.001, 1.999),
            'top_p': clamp(default(body, 'top_p', 1.0), 0.001, 1.0),
            'top_k': 1,
            'repetition_penalty': 1.18,
            'encoder_repetition_penalty': 1.0,
            'suffix': None,
            'stream': False,
            'echo': False,
            'seed': shared.settings.get('seed', -1),
            # 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
            'truncation_length': truncation_length,
            'add_bos_token': shared.settings.get('add_bos_token', True),
            'do_sample': True,
            'typical_p': 1.0,
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0.0,
            'length_penalty': 1,
            'early_stopping': False,
            'ban_eos_token': False,
            'skip_special_tokens': True,
            'custom_stopping_strings': [],
        }

        if debug:
            print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

        generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False)

        # Only the last value yielded by the generator is used.
        answer = ''
        for a in generator:
            answer = a

        # some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
        if not edit_task.endswith('\n') and answer.startswith(' '):
            answer = answer[1:]

        completion_token_count = len(encode(answer)[0])

        resp = {
            "object": "edit",
            "created": created_time,
            "choices": [{
                "text": answer,
                "index": 0,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if debug:
            print({'answer': answer, 'completion_token_count': completion_token_count})

        self._write_json(resp)

    def _handle_image_generations(self, body):
        """Forward an image generation request to the txt2img endpoint under SD_WEBUI_URL."""
        # Stable Diffusion callout wrapper for txt2img
        # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
        # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
        # If you want high quality tailored results you should just use the Stable Diffusion API directly.
        # it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
        # Will probably work best with the stock SD models.
        # SD configuration is beyond the scope of this API.
        # At this point I will not add the edits and variations endpoints (ie. img2img) because they
        # require changing the form data handling to accept multipart form data, also to properly support
        # url return types will require file management and a web serving files... Perhaps later!
        self._send_json_headers()

        width, height = [int(x) for x in default(body, 'size', '1024x1024').split('x')]  # ignore the restrictions on size
        response_format = default(body, 'response_format', 'url')  # or b64_json

        payload = {
            'prompt': body['prompt'],  # ignore prompt limit of 1000 characters
            'width': width,
            'height': height,
            'batch_size': default(body, 'n', 1)  # ignore the batch limits of max 10
        }

        resp = {
            'created': int(time.time()),
            'data': []
        }

        # TODO: support SD_WEBUI_AUTH username:password pair.
        sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"

        sd_response = requests.post(url=sd_url, json=payload)
        r = sd_response.json()
        # r['parameters']...
        for b64_json in r['images']:
            if response_format == 'b64_json':
                resp['data'].append({'b64_json': b64_json})
            else:
                resp['data'].append({'url': f'data:image/png;base64,{b64_json}'})  # yeah it's lazy. requests.get() will not work with this

        self._write_json(resp)

    def _handle_embeddings(self, body):
        """Return embeddings for ``input`` (or ``text``) from the loaded embedding model."""
        self._send_json_headers()

        text_input = body['input'] if 'input' in body else body['text']
        if isinstance(text_input, str):
            text_input = [text_input]

        embeddings = embedding_model.encode(text_input).tolist()

        # If base64 is specified, encode. Otherwise, do nothing.
        as_base64 = body.get("encoding_format", "") == "base64"

        data = [
            {
                "object": "embedding",
                "embedding": float_list_to_base64(emb) if as_base64 else emb,
                "index": n,
            }
            for n, emb in enumerate(embeddings)
        ]

        response = {
            "object": "list",
            "data": data,
            "model": st_model,  # return the real model
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            }
        }

        if debug:
            size = len(embeddings[0]) if embeddings else 0
            print(f"Embeddings return size: {size}, number: {len(embeddings)}")
        self._write_json(response)

    def _handle_moderations(self):
        """Reply with a fixed, never-flagged moderation result."""
        # for now do nothing, just don't error.
        self._send_json_headers()

        categories = (
            "hate",
            "hate/threatening",
            "self-harm",
            "sexual",
            "sexual/minors",
            "violence",
            "violence/graphic",
        )
        self._write_json({
            "id": "modr-5MWoLO",
            "model": "text-moderation-001",
            "results": [{
                "categories": dict.fromkeys(categories, False),
                "category_scores": dict.fromkeys(categories, 0.0),
                "flagged": False
            }]
        })

    def _handle_token_count(self, body):
        """Return the number of tokens ``encode`` produces for ``prompt``."""
        # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
        self._send_json_headers()

        tokens = encode(body['prompt'])[0]
        self._write_json({
            'results': [{
                'tokens': len(tokens)
            }]
        })
def run_server():
    """Load the embedding model (best effort) and serve the API forever.

    The server binds to 0.0.0.0 when ``shared.args.listen`` is set, else to
    127.0.0.1, on ``params['port']``. With ``shared.args.share`` a cloudflared
    tunnel is started when flask_cloudflared is importable. This call blocks
    in ``serve_forever()``.
    """
    global embedding_model
    try:
        embedding_model = SentenceTransformer(st_model)
        print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
    except Exception:
        # Embeddings are optional: this also covers the case where the
        # sentence_transformers import failed and the class is unavailable.
        print(f"\nFailed to load embedding model: {st_model}")

    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

    server.serve_forever()
def setup():
    """Start the API server in a daemon thread so the caller is not blocked."""
    Thread(target=run_server, daemon=True).start()