"""OpenAI-compatible HTTP API served from a background thread.

Exposes a subset of the OpenAI REST endpoints (models, completions, chat
completions, edits, image generations, embeddings, moderations) plus a few
non-standard token helper endpoints, all backed by the ``extensions.openai``
helper modules.
"""

import functools
import json
import os
import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

from modules import shared

from extensions.openai.tokens import token_count, token_encode, token_decode
import extensions.openai.models as OAImodels
import extensions.openai.edits as OAIedits
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.moderations as OAImoderations
import extensions.openai.completions as OAIcompletions
from extensions.openai.errors import *
from extensions.openai.utils import debug_msg
from extensions.openai.defaults import (get_default_req_params, default, clamp)

# Listening port; overridable through the OPENEDAI_PORT environment variable.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}


class Handler(BaseHTTPRequestHandler):
    """Request handler implementing the OpenAI-style REST endpoints."""

    def send_access_control_headers(self):
        """Emit the CORS headers attached to every response."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def do_OPTIONS(self):
        """Answer CORS preflight requests with a 200 and the CORS headers."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write("OK".encode('utf-8'))

    def start_sse(self):
        """Send the status line and headers that open a server-sent-event stream."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'text/event-stream')
        self.send_header('Cache-Control', 'no-cache')
        # self.send_header('Connection', 'keep-alive')
        self.end_headers()

    def send_sse(self, chunk: dict):
        """Serialise ``chunk`` as JSON and write it as one ``data:`` event."""
        response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
        debug_msg(response)
        self.wfile.write(response.encode('utf-8'))

    def end_sse(self):
        """Write the ``[DONE]`` sentinel event that terminates a stream."""
        self.wfile.write('data: [DONE]\r\n\r\n'.encode('utf-8'))

    def return_json(self, ret: dict, code: int = 200, no_debug=False):
        """Send ``ret`` as a complete JSON response.

        Args:
            ret: JSON-serialisable payload.
            code: HTTP status code to send.
            no_debug: When true, the encoded payload is not passed to
                ``debug_msg``.
        """
        self.send_response(code)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

        r_utf8 = json.dumps(ret).encode('utf-8')
        self.wfile.write(r_utf8)
        if not no_debug:
            debug_msg(r_utf8)

    def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
        """Send an error payload wrapped in an ``error`` object.

        Args:
            message: Text placed in the ``message`` field.
            code: HTTP status code, also echoed in the payload.
            error_type: Value of the ``type`` field.
            param: Name of the offending request parameter, if any.
            internal_message: Extra detail printed to stdout only; it is not
                included in the response.
        """
        error_resp = {
            'error': {
                'message': message,
                'code': code,
                'type': error_type,
                'param': param,
            }
        }
        if internal_message:
            print(internal_message)
            # error_resp['internal_message'] = internal_message

        self.return_json(error_resp, code)

    def openai_error_handler(func):
        """Decorator turning exceptions raised by a ``do_*`` method into JSON errors.

        The wrapped method's return value is discarded.
        """
        @functools.wraps(func)
        def wrapper(self):
            try:
                func(self)
            except ServiceUnavailableError as e:
                self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
            except InvalidRequestError as e:
                self.openai_error(e.message, e.code, e.error_type, e.param, internal_message=e.internal_message)
            except OpenAIError as e:
                self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
            except Exception as e:
                # Top-level boundary: report anything unexpected as a 500 and
                # print the traceback server-side.
                self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())

        return wrapper

    @openai_error_handler
    def do_GET(self):
        """Serve the model/engine endpoints and a stub billing-usage endpoint."""
        debug_msg(self.requestline)
        debug_msg(self.headers)

        if self.path.startswith(('/v1/engines', '/v1/models')):
            is_legacy = 'engines' in self.path
            is_list = self.path in ['/v1/engines', '/v1/models']

            if is_legacy and not is_list:
                # GET /v1/engines/<name> asks OAImodels to load that model.
                model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
                resp = OAImodels.load_model(model_name)
            elif is_list:
                resp = OAImodels.list_models(is_legacy)
            else:
                # NOTE(review): model_name is extracted but model_info() is
                # called without it -- confirm whether it should be passed.
                model_name = self.path[len('/v1/models/'):]
                resp = OAImodels.model_info()

            self.return_json(resp)

        elif '/billing/usage' in self.path:
            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self.return_json({"total_usage": 0}, no_debug=True)

        else:
            self.send_error(404)

    @openai_error_handler
    def do_POST(self):
        """Dispatch a JSON POST request to the matching endpoint implementation.

        The request body is read using the ``Content-Length`` header and parsed
        as JSON. Unknown paths get a 404.
        """
        debug_msg(self.requestline)
        debug_msg(self.headers)

        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        debug_msg(body)

        if '/completions' in self.path or '/generate' in self.path:

            if not shared.model:
                self.openai_error("No model loaded.")
                return

            is_legacy = '/generate' in self.path
            is_chat = 'chat' in self.path
            is_streaming = body.get('stream', False)

            if is_streaming:
                # NOTE(review): once start_sse() has sent the 200 headers, an
                # exception raised mid-stream can no longer be reported as a
                # clean JSON error response.
                self.start_sse()

                if is_chat:
                    response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
                else:
                    response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)

                for resp in response:
                    self.send_sse(resp)

                self.end_sse()

            else:
                if is_chat:
                    response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
                else:
                    response = OAIcompletions.completions(body, is_legacy=is_legacy)

                self.return_json(response)

        elif '/edits' in self.path:
            # deprecated

            if not shared.model:
                self.openai_error("No model loaded.")
                return

            req_params = get_default_req_params()

            instruction = body['instruction']
            edit_input = body.get('input', '')
            temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
            top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)

            response = OAIedits.edits(instruction, edit_input, temperature, top_p)

            self.return_json(response)

        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            prompt = body['prompt']
            size = default(body, 'size', '1024x1024')
            response_format = default(body, 'response_format', 'url')  # or b64_json
            n = default(body, 'n', 1)  # ignore the batch limits of max 10

            response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)

            self.return_json(response, no_debug=True)

        elif '/embeddings' in self.path:
            encoding_format = body.get('encoding_format', '')

            # 'text' is accepted as a fallback key for 'input'.
            texts = body.get('input', body.get('text', ''))
            if not texts:
                # The error handler reads ``e.param``, so the keyword is ``param``.
                raise InvalidRequestError("Missing required argument input", param='input')

            if isinstance(texts, str):
                texts = [texts]

            response = OAIembeddings.embeddings(texts, encoding_format)

            self.return_json(response, no_debug=True)

        elif '/moderations' in self.path:
            # .get() so a missing key reaches the explicit check below instead
            # of surfacing as a KeyError.
            moderation_input = body.get('input')
            if not moderation_input:
                raise InvalidRequestError("Missing required argument input", param='input')

            response = OAImoderations.moderations(moderation_input)

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token-count':
            # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
            response = token_count(body['prompt'])

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token/encode':
            # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
            encoding_format = body.get('encoding_format', '')

            response = token_encode(body['input'], encoding_format)

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token/decode':
            # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
            encoding_format = body.get('encoding_format', '')

            response = token_decode(body['input'], encoding_format)

            self.return_json(response, no_debug=True)

        else:
            self.send_error(404)


def run_server():
    """Create the HTTP server and serve requests until the process exits.

    Binds to ``0.0.0.0`` when ``shared.args.listen`` is set, otherwise to
    ``127.0.0.1``, on ``params['port']``. When ``shared.args.share`` is set, a
    tunnel is started through ``flask_cloudflared`` if that package can be
    imported.
    """
    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)

    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            # Best effort: the server still starts, just without a tunnel.
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

    server.serve_forever()


def setup():
    """Start ``run_server`` in a daemon thread."""
    Thread(target=run_server, daemon=True).start()