|
""" |
|
This module is responsible for the VectorDB API. It currently supports: |
|
* DELETE api/v1/clear |
|
- Clears the whole DB. |
|
* POST api/v1/add |
|
- Add some corpus to the DB. You can also specify metadata to be added alongside it. |
|
* POST api/v1/delete |
|
- Delete specific records with given metadata. |
|
* POST api/v1/get |
|
- Get results from chromaDB. |
|
""" |
|
|
|
import json |
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer |
|
from urllib.parse import urlparse, parse_qs |
|
from threading import Thread |
|
|
|
from modules import shared |
|
from modules.logging_colors import logger |
|
|
|
from .chromadb import ChromaCollector |
|
from .data_processor import process_and_add_to_collector |
|
|
|
import extensions.superboogav2.parameters as parameters |
|
|
|
|
|
class CustomThreadingHTTPServer(ThreadingHTTPServer): |
|
def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True): |
|
self.collector = collector |
|
super().__init__(server_address, RequestHandlerClass, bind_and_activate) |
|
|
|
def finish_request(self, request, client_address): |
|
self.RequestHandlerClass(request, client_address, self, self.collector) |
|
|
|
|
|
class Handler(BaseHTTPRequestHandler): |
|
def __init__(self, request, client_address, server, collector: ChromaCollector): |
|
self.collector = collector |
|
super().__init__(request, client_address, server) |
|
|
|
|
|
def _send_412_error(self, message): |
|
self.send_response(412) |
|
self.send_header("Content-type", "application/json") |
|
self.end_headers() |
|
response = json.dumps({"error": message}) |
|
self.wfile.write(response.encode('utf-8')) |
|
|
|
|
|
def _send_404_error(self): |
|
self.send_response(404) |
|
self.send_header("Content-type", "application/json") |
|
self.end_headers() |
|
response = json.dumps({"error": "Resource not found"}) |
|
self.wfile.write(response.encode('utf-8')) |
|
|
|
|
|
def _send_400_error(self, error_message: str): |
|
self.send_response(400) |
|
self.send_header("Content-type", "application/json") |
|
self.end_headers() |
|
response = json.dumps({"error": error_message}) |
|
self.wfile.write(response.encode('utf-8')) |
|
|
|
|
|
def _send_200_response(self, message: str): |
|
self.send_response(200) |
|
self.send_header("Content-type", "application/json") |
|
self.end_headers() |
|
|
|
if isinstance(message, str): |
|
response = json.dumps({"message": message}) |
|
else: |
|
response = json.dumps(message) |
|
|
|
self.wfile.write(response.encode('utf-8')) |
|
|
|
|
|
def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str): |
|
if sort_param == parameters.SORT_DISTANCE: |
|
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count) |
|
elif sort_param == parameters.SORT_ID: |
|
results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count) |
|
else: |
|
results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count) |
|
|
|
return { |
|
"results": results |
|
} |
|
|
|
|
|
def do_GET(self): |
|
self._send_404_error() |
|
|
|
|
|
def do_POST(self): |
|
try: |
|
content_length = int(self.headers['Content-Length']) |
|
body = json.loads(self.rfile.read(content_length).decode('utf-8')) |
|
|
|
parsed_path = urlparse(self.path) |
|
path = parsed_path.path |
|
query_params = parse_qs(parsed_path.query) |
|
|
|
if path in ['/api/v1/add', '/api/add']: |
|
corpus = body.get('corpus') |
|
if corpus is None: |
|
self._send_412_error("Missing parameter 'corpus'") |
|
return |
|
|
|
clear_before_adding = body.get('clear_before_adding', False) |
|
metadata = body.get('metadata') |
|
process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata) |
|
self._send_200_response("Data successfully added") |
|
|
|
elif path in ['/api/v1/delete', '/api/delete']: |
|
metadata = body.get('metadata') |
|
if corpus is None: |
|
self._send_412_error("Missing parameter 'metadata'") |
|
return |
|
|
|
self.collector.delete(ids_to_delete=None, where=metadata) |
|
self._send_200_response("Data successfully deleted") |
|
|
|
elif path in ['/api/v1/get', '/api/get']: |
|
search_strings = body.get('search_strings') |
|
if search_strings is None: |
|
self._send_412_error("Missing parameter 'search_strings'") |
|
return |
|
|
|
n_results = body.get('n_results') |
|
if n_results is None: |
|
n_results = parameters.get_chunk_count() |
|
|
|
max_token_count = body.get('max_token_count') |
|
if max_token_count is None: |
|
max_token_count = parameters.get_max_token_count() |
|
|
|
sort_param = query_params.get('sort', ['distance'])[0] |
|
|
|
results = self._handle_get(search_strings, n_results, max_token_count, sort_param) |
|
self._send_200_response(results) |
|
|
|
else: |
|
self._send_404_error() |
|
except Exception as e: |
|
self._send_400_error(str(e)) |
|
|
|
|
|
def do_DELETE(self): |
|
try: |
|
parsed_path = urlparse(self.path) |
|
path = parsed_path.path |
|
query_params = parse_qs(parsed_path.query) |
|
|
|
if path in ['/api/v1/clear', '/api/clear']: |
|
self.collector.clear() |
|
self._send_200_response("Data successfully cleared") |
|
else: |
|
self._send_404_error() |
|
except Exception as e: |
|
self._send_400_error(str(e)) |
|
|
|
|
|
def do_OPTIONS(self): |
|
self.send_response(200) |
|
self.end_headers() |
|
|
|
|
|
def end_headers(self): |
|
self.send_header('Access-Control-Allow-Origin', '*') |
|
self.send_header('Access-Control-Allow-Methods', '*') |
|
self.send_header('Access-Control-Allow-Headers', '*') |
|
self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate') |
|
super().end_headers() |
|
|
|
|
|
class APIManager: |
|
def __init__(self, collector: ChromaCollector): |
|
self.server = None |
|
self.collector = collector |
|
self.is_running = False |
|
|
|
def start_server(self, port: int): |
|
if self.server is not None: |
|
print("Server already running.") |
|
return |
|
|
|
address = '0.0.0.0' if shared.args.listen else '127.0.0.1' |
|
self.server = CustomThreadingHTTPServer((address, port), Handler, self.collector) |
|
|
|
logger.info(f'Starting chromaDB API at http://{address}:{port}/api') |
|
|
|
Thread(target=self.server.serve_forever, daemon=True).start() |
|
|
|
self.is_running = True |
|
|
|
def stop_server(self): |
|
if self.server is not None: |
|
logger.info(f'Stopping chromaDB API.') |
|
self.server.shutdown() |
|
self.server.server_close() |
|
self.server = None |
|
self.is_running = False |
|
|
|
def is_server_running(self): |
|
return self.is_running |