Spaces:
Runtime error
Runtime error
import json | |
import ssl | |
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer | |
from threading import Thread | |
from extensions.api.util import build_parameters, try_start_cloudflared | |
from modules import shared | |
from modules.chat import generate_chat_reply | |
from modules.LoRA import add_lora_to_model | |
from modules.models import load_model, unload_model | |
from modules.models_settings import get_model_metadata, update_model_parameters | |
from modules.text_generation import ( | |
encode, | |
generate_reply, | |
stop_everything_event | |
) | |
from modules.utils import get_available_models | |
from modules.logging_colors import logger | |
def get_model_info():
    """Describe the loaded model and dump the shared configuration.

    Returns:
        A dict with keys ``'model_name'`` and ``'lora_names'`` (read from
        ``shared``), followed by ``'shared.settings'`` (the live settings dict)
        and ``'shared.args'`` (``vars()`` of the shared args namespace).
    """
    info = dict(model_name=shared.model_name, lora_names=shared.lora_names)
    # Raw dump of the shared state, returned by reference (not copied).
    info['shared.settings'] = shared.settings
    info['shared.args'] = vars(shared.args)
    return info
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler for the blocking (non-streaming) API.

    Routes:
        GET  /api/v1/model        -> ``{'result': <loaded model name>}``
        POST /api/v1/generate     -> completion for the required ``prompt`` field
        POST /api/v1/chat         -> chat history for the required ``user_input`` field
        POST /api/v1/stop-stream  -> calls ``stop_everything_event()``
        POST /api/v1/model        -> model actions: info, load, list, unload
        POST /api/v1/token-count  -> token count of the required ``prompt`` field

    Malformed request bodies and missing required fields are answered with a
    400 and a JSON payload of the form ``{'error': {'message': ...}}``. Unknown
    paths get a plain 404. Every response carries the CORS / no-cache headers
    added in ``end_headers``.
    """

    def _send_json(self, payload, status=200):
        """Send ``payload`` as one complete JSON response with ``status``.

        The payload is serialized before anything is written, so a
        serialization error cannot leave a half-sent response behind.
        """
        data = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _send_error_json(self, status, message):
        """Send an error using the ``{'error': {'message': ...}}`` shape."""
        self._send_json({'error': {'message': message}}, status=status)

    def _read_json_body(self):
        """Read the request body and decode it as a JSON object.

        A missing or zero ``Content-Length`` yields an empty dict.

        Returns:
            The decoded dict, or ``None`` if the body is unusable. In that case
            a 400 response has already been sent and the caller must stop.
        """
        raw_length = self.headers.get('Content-Length')
        try:
            length = int(raw_length) if raw_length else 0
        except ValueError:
            length = -1
        if length < 0:
            self._send_error_json(400, 'Invalid Content-Length header')
            return None
        if length == 0:
            return {}

        try:
            body = json.loads(self.rfile.read(length).decode('utf-8'))
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            self._send_error_json(400, f'Invalid JSON body: {e}')
            return None
        # Every route below calls dict methods / indexing on the body.
        if not isinstance(body, dict):
            self._send_error_json(400, 'Request body must be a JSON object')
            return None
        return body

    def _missing_field(self, body, field):
        """Return True (after sending a 400) if ``field`` is absent from ``body``."""
        if field in body:
            return False
        self._send_error_json(400, f"Missing required field '{field}'")
        return True

    def do_GET(self):
        """Serve GET requests; only /api/v1/model is supported."""
        if self.path == '/api/v1/model':
            self._send_json({'result': shared.model_name})
        else:
            self.send_error(404)

    def do_POST(self):
        """Dispatch POST requests by path.

        Each response is computed first and then sent, so a missing field or
        a failure can still be reported with a proper status code.
        """
        body = self._read_json_body()
        if body is None:
            return  # a 400 response has already been sent

        if self.path == '/api/v1/generate':
            if self._missing_field(body, 'prompt'):
                return
            prompt = body['prompt']
            generate_params = build_parameters(body)
            stopping_strings = generate_params.pop('stopping_strings')
            generate_params['stream'] = False
            generator = generate_reply(
                prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

            # Only the last value yielded by the generator is returned.
            answer = ''
            for a in generator:
                answer = a
            self._send_json({'results': [{'text': answer}]})

        elif self.path == '/api/v1/chat':
            if self._missing_field(body, 'user_input'):
                return
            user_input = body['user_input']
            regenerate = body.get('regenerate', False)
            _continue = body.get('_continue', False)

            generate_params = build_parameters(body, chat=True)
            generate_params['stream'] = False
            generator = generate_chat_reply(
                user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

            # Fall back to the incoming history if the generator yields nothing.
            answer = generate_params['history']
            for a in generator:
                answer = a
            self._send_json({'results': [{'history': answer}]})

        elif self.path == '/api/v1/stop-stream':
            stop_everything_event()
            self._send_json({'results': 'success'})

        elif self.path == '/api/v1/model':
            # Actions: info, load, list, unload
            action = body.get('action', '')
            if action == 'load':
                if self._missing_field(body, 'model_name'):
                    return
                model_name = body['model_name']
                args = body.get('args', {})
                if not isinstance(args, dict):
                    self._send_error_json(400, "'args' must be a JSON object")
                    return
                logger.info(f'Loading model {model_name} with args {args}')
                # SECURITY NOTE(review): any client can overwrite arbitrary
                # attributes of shared.args here; consider an allow-list.
                for k in args:
                    setattr(shared.args, k, args[k])

                shared.model_name = model_name
                unload_model()

                model_settings = get_model_metadata(shared.model_name)
                shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
                update_model_parameters(model_settings, initial=True)

                if shared.settings['mode'] != 'instruct':
                    shared.settings['instruction_template'] = None

                try:
                    shared.model, shared.tokenizer = load_model(shared.model_name)
                    if shared.args.lora:
                        add_lora_to_model(shared.args.lora)  # list
                except Exception as e:
                    # Report the failure to the client, then re-raise so that
                    # socketserver's handle_error prints the traceback.
                    self._send_error_json(500, repr(e))
                    raise

                shared.args.model = shared.model_name
                result = get_model_info()
            elif action == 'unload':
                unload_model()
                shared.model_name = None
                shared.args.model = None
                result = get_model_info()
            elif action == 'list':
                result = get_available_models()
            elif action == 'info':
                result = get_model_info()
            else:
                # Missing/unknown action: same answer as GET /api/v1/model.
                result = shared.model_name

            self._send_json({'result': result})

        elif self.path == '/api/v1/token-count':
            if self._missing_field(body, 'prompt'):
                return
            tokens = encode(body['prompt'])[0]
            self._send_json({'results': [{'tokens': len(tokens)}]})

        else:
            self.send_error(404)

    def do_OPTIONS(self):
        """Answer OPTIONS (CORS preflight) with 200; the CORS headers come from end_headers."""
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        """Add permissive CORS and no-cache headers to every response."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
        super().end_headers()
def _run_server(port: int, share: bool = False, tunnel_id=None):
    """Serve the blocking API on ``port``; blocks in ``serve_forever()``.

    Binds to 0.0.0.0 when ``shared.args.listen`` is set, otherwise to
    127.0.0.1. The listening socket is wrapped in TLS when both
    ``shared.args.ssl_certfile`` and ``shared.args.ssl_keyfile`` are set.

    Args:
        port: TCP port to listen on.
        share: if True, call ``try_start_cloudflared`` to expose the server
            publicly. A failure there is logged and the local server still runs.
        tunnel_id: forwarded unchanged to ``try_start_cloudflared``. Defaults
            to ``None``; the previous default was the ``str`` type object
            itself, an annotation typo.
    """
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    server = ThreadingHTTPServer((address, port), Handler)

    ssl_certfile = shared.args.ssl_certfile
    ssl_keyfile = shared.args.ssl_keyfile
    # TLS is enabled only when both files are configured.
    use_ssl = bool(ssl_keyfile and ssl_certfile)
    if use_ssl:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ssl_certfile, ssl_keyfile)
        server.socket = context.wrap_socket(server.socket, server_side=True)

    def on_start(public_url: str):
        logger.info(f'Blocking API URL: \n\n{public_url}/api\n')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception as e:
            # Best effort: the tunnel is optional, so keep serving locally,
            # but make the reason for the missing public URL visible.
            logger.error(f'Failed to start the public tunnel for the blocking API: {e!r}')
    else:
        scheme = 'https' if use_ssl else 'http'
        logger.info(f'Blocking API URL: \n\n{scheme}://{address}:{port}/api\n')

    server.serve_forever()
def start_server(port: int, share: bool = False, tunnel_id=None):
    """Start the blocking API server in a background daemon thread.

    Returns immediately. The thread runs ``_run_server(port, share, tunnel_id)``
    and, being a daemon thread, does not keep the process alive on its own.

    Args:
        port: TCP port to listen on.
        share: whether ``_run_server`` should try to open a public tunnel.
        tunnel_id: forwarded to ``_run_server``. Defaults to ``None``; the
            previous default was the ``str`` type object, an annotation typo.
    """
    Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()