"""OpenAI-compatible HTTP API served from a background thread.

Endpoints handled (matched by substring of the request path, in this order):
completions / chat completions / legacy generate, edits, images/generations
(only when SD_WEBUI_URL is set), embeddings (only when the embedding model
loaded), moderations (stub), and the non-standard /api/v1/token-count.
"""
import base64
import json
import os
import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

import numpy as np
import requests
import yaml

from modules import shared
from modules.text_generation import encode, generate_reply

params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}

debug = 'OPENEDAI_DEBUG' in os.environ

# Optional, install the module and download the model to enable
# v1/embeddings
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    # run_server() reports the failure when it tries to load the model.
    pass

st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
embedding_model = None

standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]


def default(dic, key, default):
    """Return ``dic[key]`` if it has the same type as ``default``.

    Little helper to get defaults if the arg is present but None (or another
    unexpected type). A value of a different type is accepted only when it
    converts to the default's type and back without changing (e.g. ``1``
    where ``1.0`` is expected); otherwise ``default`` is returned.
    """
    val = dic.get(key, default)
    if type(val) is type(default):
        return val

    # maybe it's just something like 1 instead of 1.0
    try:
        converted = type(default)(val)
        if type(val)(converted) == val:  # same value passed in, it's ok.
            return converted
    except Exception:
        # Not convertible (e.g. None, or a non-numeric string): use the default.
        pass

    return default


def clamp(value, minvalue, maxvalue):
    """Limit ``value`` to the inclusive range ``[minvalue, maxvalue]``."""
    return max(minvalue, min(value, maxvalue))


def deduce_template():
    """Build the prompt template used by the edits endpoint.

    Returns a ``str.format`` template with ``{instruction}`` and ``{input}``
    placeholders. The Alpaca template is used for Alpaca-style settings and
    as the fallback whenever the configured instruction file cannot be
    loaded or lacks the expected keys.
    """
    # Alpaca is verbose so a good default prompt
    default_template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    )

    template_name = shared.settings['instruction_template']

    # Use the special instruction/input/response template for anything trained like Alpaca
    if template_name in ['Alpaca', 'Alpaca-Input']:
        return default_template

    try:
        with open(f"characters/instruction-following/{template_name}.yaml", 'r') as f:
            instruct = yaml.safe_load(f)

        template = instruct['turn_template']
        template = (
            template
            .replace('<|user|>', instruct.get('user', ''))
            .replace('<|bot|>', instruct.get('bot', ''))
            .replace('<|user-message|>', '{instruction}\n{input}')
        )

        # Keep everything up to where the bot's reply starts. If the marker is
        # absent keep the whole template (find() == -1 would chop a character).
        bot_idx = template.find('<|bot-message|>')
        if bot_idx != -1:
            template = template[:bot_idx]

        return instruct.get('context', '') + template.rstrip(' ')
    except Exception:
        # Missing file, bad YAML or missing keys: best-effort fallback.
        return default_template


def float_list_to_base64(float_list):
    """Encode a list of floats as base64 of their raw float32 bytes."""
    # Convert the list to a float32 array that the OpenAPI client expects
    float_array = np.array(float_list, dtype="float32")
    # Raw bytes -> base64 -> ASCII text
    return base64.b64encode(float_array.tobytes()).decode('ascii')


def _base_chunk(cmpl_id, object_type, created_time, model, resp_list, finish_reason=None):
    """Build the common skeleton of a completion response or stream chunk."""
    return {
        "id": cmpl_id,
        "object": object_type,
        "created": created_time,
        "model": model,  # TODO: add Lora info?
        resp_list: [{
            "index": 0,
            "finish_reason": finish_reason,
        }],
    }


def _usage(prompt_tokens, completion_tokens):
    """Build the ``usage`` object reported with a completed response."""
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens
    }


class Handler(BaseHTTPRequestHandler):
    """Request handler implementing the OpenAI-style endpoints."""

    def _send_json_headers(self):
        """Send a 200 status line with a JSON content type."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

    def _write_json(self, obj):
        """Serialize ``obj`` to JSON and write it to the response body."""
        self.wfile.write(json.dumps(obj).encode('utf-8'))

    def do_GET(self):
        """Serve the model list (/v1/models) or a single model stub."""
        if not self.path.startswith('/v1/models'):
            self.send_error(404)
            return

        self._send_json_headers()

        if self.path == '/v1/models':
            # TODO: list all models and allow model changes via API? Lora's?
            # This API should list capabilities, limits and pricing...
            models = [{
                "id": shared.model_name,  # The real chat/completions model
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": st_model,  # The real sentence transformer embeddings model
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {  # these are expected by so much, so include some here as a dummy
                "id": "gpt-3.5-turbo",  # /v1/chat/completions
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": "text-curie-001",  # /v1/completions, 2k context
                "object": "model",
                "owned_by": "user",
                "permission": []
            }, {
                "id": "text-davinci-002",  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
                "object": "model",
                "owned_by": "user",
                "permission": []
            }]
            self._write_json({
                "object": "list",
                "data": models,
            })
        else:
            # Any other name is echoed back as if it existed.
            the_model_name = self.path[len('/v1/models/'):]
            self._write_json({
                "id": the_model_name,
                "object": "model",
                "owned_by": "user",
                "permission": []
            })

    def do_POST(self):
        """Parse the JSON body and dispatch to the matching endpoint handler.

        Responds 400 when the body is not a JSON object and 404 when no
        endpoint matches the path.
        """
        if debug:
            print(self.headers)  # did you know... python-openai sends your linux kernel & python version?

        try:
            content_length = max(int(self.headers.get('Content-Length') or 0), 0)
            body = json.loads(self.rfile.read(content_length).decode('utf-8'))
        except ValueError:
            # Covers a bad Content-Length, invalid UTF-8 and invalid JSON.
            self.send_error(400, 'Request body must be valid JSON')
            return

        if not isinstance(body, dict):
            self.send_error(400, 'Request body must be a JSON object')
            return

        if debug:
            print(body)

        if '/completions' in self.path or '/generate' in self.path:
            self._handle_completions(body)
        elif '/edits' in self.path:
            self._handle_edits(body)
        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            self._handle_image_generations(body)
        elif '/embeddings' in self.path and embedding_model is not None:
            self._handle_embeddings(body)
        elif '/moderations' in self.path:
            self._handle_moderations()
        elif self.path == '/api/v1/token-count':
            self._handle_token_count(body)
        else:
            print(self.path, self.headers)
            self.send_error(404)

    def _build_chat_prompt(self, body, req_params):
        """Flatten chat ``messages`` into one prompt string.

        System messages are concatenated into a single leading system line.
        Remaining messages are added newest-first until the token budget
        (truncation length minus max new tokens minus system tokens) is used
        up; older messages that do not fit are dropped.
        """
        system_msg = ''
        if 'prompt' in body:  # Maybe they sent both? Not documented in the API, but some clients seem to do this.
            system_msg = body['prompt']

        chat_msgs = []
        for m in body['messages']:
            role = m['role']
            content = m['content']
            if role == 'system':
                system_msg += content
            else:
                chat_msgs.append(f"\n{role}: {content.strip()}")  # Strip content? linefeed?

        system_token_count = len(encode(system_msg)[0])
        remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
        chat_msg = ''

        # Walk from the newest message backwards; only remove a message from
        # the pending list once it is known to fit, so what is left over is
        # exactly the set of dropped messages.
        while chat_msgs:
            new_msg = chat_msgs[-1]
            new_size = len(encode(new_msg)[0])
            if new_size > remaining_tokens:
                # TODO: clip a message to fit?
                break
            chat_msgs.pop()
            chat_msg = new_msg + chat_msg
            remaining_tokens -= new_size

        if chat_msgs:
            print(f"truncating chat messages, dropping {len(chat_msgs)} messages.")

        if system_msg:
            return 'system: ' + system_msg + '\n' + chat_msg + '\nassistant: '
        return chat_msg + '\nassistant: '

    def _handle_completions(self, body):
        """Handle completions, chat completions and the legacy generate API.

        Writes either a single JSON response or, when ``stream`` is requested,
        a sequence of ``data: ...`` event lines terminated by ``data: [DONE]``.
        """
        is_legacy = '/generate' in self.path
        is_chat = 'chat' in self.path
        resp_list = 'data' if is_legacy else 'choices'

        # XXX model is ignored for now, the loaded model is always used.
        model = shared.model_name
        created_time = int(time.time())
        cmpl_id = "conv-%d" % (created_time)

        # Try to use openai defaults or map them to something with the same intent
        stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
        stop = body.get('stop')
        if isinstance(stop, str):
            stopping_strings = [stop]
        elif isinstance(stop, list):
            stopping_strings = stop

        # Build a NEW list: extending in place would permanently grow the list
        # stored in shared.settings on every request. Empty / non-string
        # entries are skipped (an empty stop string matches everywhere).
        # The standard strings catch strange cases of "##| Instruction: "
        # sneaking through.
        stopping_strings = [s for s in stopping_strings if isinstance(s, str) and s]
        stopping_strings += standard_stopping_strings

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

        # completions default is 16; chat default is 'inf' so we need to cap it.
        default_max_tokens = truncation_length if is_chat else 16

        max_tokens_str = 'length' if is_legacy else 'max_tokens'
        max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))

        # hard scale this, assuming the given max is for GPT3/4, perhaps inspect
        # the requested model and lookup the context max
        while truncation_length <= max_tokens:
            max_tokens = max_tokens // 2

        req_params = {
            'max_new_tokens': max_tokens,
            'temperature': default(body, 'temperature', 1.0),
            'top_p': default(body, 'top_p', 1.0),
            'top_k': default(body, 'best_of', 1),
            # XXX presence_penalty / frequency_penalty are not mapped: the
            # openai range is (-2..2.0) with 0 as default, the local one is
            # 0..2 with 1.0 as default. 1.2 is the model default for
            # repetition_penalty, but 1.18 works better.
            'repetition_penalty': 1.18,
            'encoder_repetition_penalty': 1.0,
            'suffix': body.get('suffix', None),
            'stream': default(body, 'stream', False),
            'echo': default(body, 'echo', False),
            #####################################################
            'seed': shared.settings.get('seed', -1),
            # 'n' doesn't have a direct map (num_beams? chat_generation_attempts?)
            # unofficial, but it needs to get set anyways.
            'truncation_length': truncation_length,
            # no more args.
            'add_bos_token': shared.settings.get('add_bos_token', True),
            'do_sample': True,
            'typical_p': 1.0,
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0.0,
            'length_penalty': 1,
            'early_stopping': False,
            'ban_eos_token': False,
            'skip_special_tokens': True,
            'custom_stopping_strings': stopping_strings,
        }

        # fixup absolute 0.0's
        for par in ['temperature', 'repetition_penalty', 'encoder_repetition_penalty']:
            req_params[par] = clamp(req_params[par], 0.001, 1.999)

        streaming = req_params['stream']

        self.send_response(200)
        if streaming:
            self.send_header('Content-Type', 'text/event-stream')
            self.send_header('Cache-Control', 'no-cache')
        else:
            self.send_header('Content-Type', 'application/json')
        self.end_headers()

        if is_chat:
            stream_object_type = 'chat.completions.chunk'
            object_type = 'chat.completions'
            prompt = self._build_chat_prompt(body, req_params)
            token_count = len(encode(prompt)[0])
        else:
            stream_object_type = 'text_completion.chunk'
            object_type = 'text_completion'

            # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
            if is_legacy:
                prompt = body['context']  # Older engines.generate API
            else:
                prompt = body['prompt']  # XXX this can be different types

            if isinstance(prompt, list):
                prompt = ''.join(prompt)  # XXX this is wrong... need to split out to multiple calls?

            token_count = len(encode(prompt)[0])
            if token_count >= req_params['truncation_length']:
                # Keep the tail of the prompt, scaled by characters per token.
                new_len = int(len(prompt) * (float(shared.settings['truncation_length']) - req_params['max_new_tokens']) / token_count)
                prompt = prompt[-new_len:]
                print(f"truncating prompt to {new_len} characters, was {token_count} tokens. Now: {len(encode(prompt)[0])} tokens.")

        def write_event(chunk, suffix='\n'):
            # One event line per chunk, in the format the original clients parse.
            self.wfile.write(('data: ' + json.dumps(chunk) + suffix).encode('utf-8'))

        def send_content(text):
            # Stream one piece of newly generated text.
            chunk = _base_chunk(cmpl_id, stream_object_type, created_time, model, resp_list)
            if is_chat:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]['message'] = {'content': text}
                chunk[resp_list][0]['delta'] = {'content': text}
            else:
                chunk[resp_list][0]['text'] = text
            write_event(chunk)

        if streaming:
            # NOTE: this flips a process-wide flag, as the original code did.
            shared.args.chat = True
            # begin streaming
            chunk = _base_chunk(cmpl_id, stream_object_type, created_time, model, resp_list)
            if is_chat:
                # This is coming back as "system" to the openapi cli, not sure why.
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
            else:
                chunk[resp_list][0]["text"] = ""
            write_event(chunk)

        # generate reply #######################################
        if debug:
            print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

        answer = ''
        seen_content = ''
        completion_token_count = 0
        longest_stop_len = max(len(s) for s in stopping_strings)

        for a in generator:
            answer = a
            len_seen = len(seen_content)
            search_start = max(len_seen - longest_stop_len, 0)

            stop_string_found = False
            for string in stopping_strings:
                idx = answer.find(string, search_start)
                if idx != -1:
                    answer = answer[:idx]  # clip it.
                    stop_string_found = True

            if stop_string_found:
                break

            # If something like "\nYo" is generated just before "\nYou:"
            # is completed, buffer and generate more, don't send it
            if any(answer.endswith(string[:j])
                   for string in stopping_strings
                   for j in range(1, len(string))):
                continue

            if not streaming:
                continue

            new_content = answer[len_seen:]
            if not new_content or chr(0xfffd) in new_content:
                # partial unicode character, don't send it yet.
                continue

            seen_content = answer
            send_content(new_content)
            completion_token_count += len(encode(new_content)[0])

        if debug:
            print({'response': answer})

        if streaming:
            # Flush text that was withheld (possible stop-string prefix, or the
            # part before a clipped stop string) so the streamed output matches
            # the non-streamed answer.
            tail = answer[len(seen_content):]
            if tail:
                send_content(tail)
                completion_token_count += len(encode(tail)[0])

            chunk = _base_chunk(cmpl_id, stream_object_type, created_time, model, resp_list, "stop")
            chunk["usage"] = _usage(token_count, completion_token_count)
            if is_chat:
                # So yeah... do both methods? delta and messages.
                chunk[resp_list][0]['message'] = {'content': ''}
                chunk[resp_list][0]['delta'] = {}
            else:
                chunk[resp_list][0]['text'] = ''
            write_event(chunk, '\ndata: [DONE]\n')
            # Finished if streaming.
            return

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= req_params['truncation_length']:
            stop_reason = "length"

        resp = _base_chunk(cmpl_id, object_type, created_time, model, resp_list, stop_reason)
        resp["usage"] = _usage(token_count, completion_token_count)
        if is_chat:
            resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
        else:
            resp[resp_list][0]["text"] = answer

        self._write_json(resp)

    def _handle_edits(self, body):
        """Handle the edits endpoint using an instruction-style template."""
        self._send_json_headers()

        created_time = int(time.time())

        # Using Alpaca format, this may work with other models too.
        instruction = body['instruction']
        input_text = body.get('input', '')

        instruction_template = deduce_template()
        edit_task = instruction_template.format(instruction=instruction, input=input_text)

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        token_count = len(encode(edit_task)[0])
        max_tokens = truncation_length - token_count

        req_params = {
            'max_new_tokens': max_tokens,
            'temperature': clamp(default(body, 'temperature', 1.0), 0.001, 1.999),
            'top_p': clamp(default(body, 'top_p', 1.0), 0.001, 1.0),
            'top_k': 1,
            'repetition_penalty': 1.18,
            'encoder_repetition_penalty': 1.0,
            'suffix': None,
            'stream': False,
            'echo': False,
            'seed': shared.settings.get('seed', -1),
            # 'n' doesn't have a direct map
            'truncation_length': truncation_length,
            'add_bos_token': shared.settings.get('add_bos_token', True),
            'do_sample': True,
            'typical_p': 1.0,
            'min_length': 0,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'penalty_alpha': 0.0,
            'length_penalty': 1,
            'early_stopping': False,
            'ban_eos_token': False,
            'skip_special_tokens': True,
            'custom_stopping_strings': [],
        }

        if debug:
            print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

        generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False)

        # Only the last value yielded by the generator is used as the answer.
        answer = ''
        for a in generator:
            answer = a

        # some reply's have an extra leading space to fit the instruction
        # template, just clip it off from the reply.
        if not edit_task.endswith('\n') and answer.startswith(' '):
            answer = answer[1:]

        completion_token_count = len(encode(answer)[0])

        resp = {
            "object": "edit",
            "created": created_time,
            "choices": [{
                "text": answer,
                "index": 0,
            }],
            "usage": _usage(token_count, completion_token_count)
        }

        if debug:
            print({'answer': answer, 'completion_token_count': completion_token_count})

        self._write_json(resp)

    def _handle_image_generations(self, body):
        """Forward an image request to the Stable Diffusion txt2img API.

        Low effort implementation for compatibility. Only the prompt, size
        and count are passed on, so results will be limited; use the Stable
        Diffusion API directly for tailored results. The edits and variations
        endpoints (img2img) are not implemented because they need multipart
        form handling, and real URL results would need file serving.
        """
        self._send_json_headers()

        # ignore the restrictions on size
        width, height = [int(x) for x in default(body, 'size', '1024x1024').split('x')]
        response_format = default(body, 'response_format', 'url')  # or b64_json

        payload = {
            'prompt': body['prompt'],  # ignore prompt limit of 1000 characters
            'width': width,
            'height': height,
            'batch_size': default(body, 'n', 1)  # ignore the batch limits of max 10
        }

        resp = {
            'created': int(time.time()),
            'data': []
        }

        # TODO: support SD_WEBUI_AUTH username:password pair.
        sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"

        response = requests.post(url=sd_url, json=payload)
        r = response.json()
        for b64_json in r['images']:
            if response_format == 'b64_json':
                resp['data'].append({'b64_json': b64_json})
            else:
                # yeah it's lazy. requests.get() will not work with this
                resp['data'].append({'url': f'data:image/png;base64,{b64_json}'})

        self._write_json(resp)

    def _handle_embeddings(self, body):
        """Return sentence-transformer embeddings for ``input`` (or ``text``)."""
        self._send_json_headers()

        texts = body['input'] if 'input' in body else body['text']
        if isinstance(texts, str):
            texts = [texts]

        embeddings = embedding_model.encode(texts).tolist()

        # If base64 is specified, encode. Otherwise, return the float lists.
        use_base64 = body.get("encoding_format", "") == "base64"

        data = [
            {
                "object": "embedding",
                "embedding": float_list_to_base64(emb) if use_base64 else emb,
                "index": n,
            }
            for n, emb in enumerate(embeddings)
        ]

        if debug and embeddings:
            print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")

        self._write_json({
            "object": "list",
            "data": data,
            "model": st_model,  # return the real model
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            }
        })

    def _handle_moderations(self):
        """Return a fixed "nothing flagged" result; no moderation is done."""
        # for now do nothing, just don't error.
        self._send_json_headers()
        self._write_json({
            "id": "modr-5MWoLO",
            "model": "text-moderation-001",
            "results": [{
                "categories": {
                    "hate": False,
                    "hate/threatening": False,
                    "self-harm": False,
                    "sexual": False,
                    "sexual/minors": False,
                    "violence": False,
                    "violence/graphic": False
                },
                "category_scores": {
                    "hate": 0.0,
                    "hate/threatening": 0.0,
                    "self-harm": 0.0,
                    "sexual": 0.0,
                    "sexual/minors": 0.0,
                    "violence": 0.0,
                    "violence/graphic": 0.0
                },
                "flagged": False
            }]
        })

    def _handle_token_count(self, body):
        """Return the token count of ``prompt``.

        NOT STANDARD. Lifted from the api extension, but it's still very
        useful to calculate tokenized length client side.
        """
        self._send_json_headers()
        tokens = encode(body['prompt'])[0]
        self._write_json({
            'results': [{
                'tokens': len(tokens)
            }]
        })


def run_server():
    """Load the optional embedding model, then serve requests forever."""
    global embedding_model
    try:
        embedding_model = SentenceTransformer(st_model)
        print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
    except Exception as err:
        # Also reached when sentence_transformers is not installed (the name
        # is then undefined). Embeddings are optional, so keep going.
        print(f"\nFailed to load embedding model: {st_model}")
        if debug:
            print(f"Embedding model load error: {err!r}")

    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

    server.serve_forever()


def setup():
    """Start the API server in a daemon thread."""
    Thread(target=run_server, daemon=True).start()