from flask import (
    Flask,
    jsonify,
    request,
    render_template_string,
    abort,
)
from flask_cors import CORS
import unicodedata
import markdown
import time
import os
import gc
import hashlib
import chromadb
import posthog
import torch
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from werkzeug.middleware.proxy_fix import ProxyFix
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import BlipForConditionalGeneration
from PIL import Image
from colorama import init as colorama_init

colorama_init()

port = 7860
host = "0.0.0.0"

summarization_model = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
classification_model = "joeddav/distilbert-base-uncased-go-emotions-student"
captioning_model = "Salesforce/blip-image-captioning-large"
embedding_model = "sentence-transformers/all-mpnet-base-v2"

device_string = "cpu"
device = torch.device(device_string)
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16

print("Initializing a text summarization model...")
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
    summarization_model, torch_dtype=torch_dtype
).to(device)

print("Initializing a sentiment classification pipeline...")
classification_pipe = pipeline(
    "text-classification",
    model=classification_model,
    top_k=None,
    device=device,
    torch_dtype=torch_dtype,
)

print("Initializing an image captioning model...")
captioning_processor = AutoProcessor.from_pretrained(captioning_model)
if "blip" in captioning_model:
    captioning_transformer = BlipForConditionalGeneration.from_pretrained(
        captioning_model, torch_dtype=torch_dtype
    ).to(device)
else:
    captioning_transformer = AutoModelForCausalLM.from_pretrained(
        captioning_model, torch_dtype=torch_dtype
    ).to(device)

print("Initializing ChromaDB...")
# Disable ChromaDB telemetry
posthog.capture = lambda *args, **kwargs: None
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
chromadb_embedder = SentenceTransformer(embedding_model)
chromadb_embed_fn = chromadb_embedder.encode

# Flask init
app = Flask(__name__)
CORS(app)  # allow cross-domain requests
app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1)


def get_real_ip():
    return request.remote_addr


def caption_image(raw_image: Image.Image, max_new_tokens: int = 20) -> str:
    inputs = captioning_processor(
        raw_image.convert("RGB"), return_tensors="pt"
    ).to(device, torch_dtype)
    outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
    caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
    return caption


def classify_text(text: str) -> list:
    output = classification_pipe(
        text,
        truncation=True,
        max_length=classification_pipe.model.config.max_position_embeddings,
    )[0]
    return sorted(output, key=lambda x: x["score"], reverse=True)


def summarize_chunks(text: str, params: dict) -> str:
    try:
        return summarize(text, params)
    except IndexError:
        print(
            "Sequence length too large for model, cutting text in half and calling again"
        )
        new_params = params.copy()
        new_params["max_length"] = new_params["max_length"] // 2
        new_params["min_length"] = new_params["min_length"] // 2
        return summarize_chunks(
            text[: (len(text) // 2)], new_params
        ) + summarize_chunks(text[(len(text) // 2) :], new_params)
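# Illustrative use of summarize_chunks (not executed; the text source and file
# name are hypothetical). The params mirror the defaults of /api/summarize
# below: if the full text overflows the model's context window, the function
# recursively halves the text and concatenates the partial summaries.
#
#   text = open("transcript.txt", encoding="utf8").read()
#   summary = summarize_chunks(text, {
#       "temperature": 1.0,
#       "repetition_penalty": 1.0,
#       "max_length": 500,
#       "min_length": 200,
#       "length_penalty": 1.5,
#       "bad_words": ["\n", '"'],
#   })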
def summarize(text: str, params: dict) -> str:
    # Tokenize input
    inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
    token_count = len(inputs[0])

    bad_words_ids = [
        summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
        for bad_word in params["bad_words"]
    ]
    summary_ids = summarization_transformer.generate(
        inputs["input_ids"],
        num_beams=2,
        max_new_tokens=max(token_count, int(params["max_length"])),
        min_new_tokens=min(token_count, int(params["min_length"])),
        repetition_penalty=float(params["repetition_penalty"]),
        temperature=float(params["temperature"]),
        length_penalty=float(params["length_penalty"]),
        bad_words_ids=bad_words_ids,
    )
    summary = summarization_tokenizer.batch_decode(
        summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    summary = normalize_string(summary)
    return summary


def normalize_string(input: str) -> str:
    output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
    return output


@app.before_request
def before_request():
    # Request time measuring
    request.start_time = time.time()

    # Check that an API key is present and valid, otherwise return unauthorized.
    # The OPTIONS exemption is required so CORS preflight requests aren't rejected.
    try:
        if request.method != "OPTIONS" and getattr(
            request.authorization, "token", ""
        ) != os.environ["sekrit_password"]:
            print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
            response = jsonify({"error": "401: Invalid API key"})
            response.status_code = 401
            return response
    except Exception as e:
        print(f"API key check error: {e}")
        return "401 Unauthorized\n{}\n\n".format(e), 401


@app.after_request
def after_request(response):
    duration = time.time() - request.start_time
    response.headers["X-Request-Duration"] = str(duration)
    return response


@app.route("/", methods=["GET"])
def index():
    with open("./README.md", "r", encoding="utf8") as f:
        content = f.read()
    return render_template_string(markdown.markdown(content, extensions=["tables"]))


@app.route("/api/modules", methods=["GET"])
def get_modules():
    return jsonify({"modules": ["chromadb", "summarize", "classify", "caption"]})


@app.route("/api/chromadb", methods=["POST"])
def chromadb_add_messages():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "messages" not in data or not isinstance(data["messages"], list):
        abort(400, '"messages" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [m["content"] for m in data["messages"]]
    ids = [m["id"] for m in data["messages"]]
    metadatas = [
        {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
        for m in data["messages"]
    ]

    if len(ids) > 0:
        collection.upsert(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )

    return jsonify({"count": len(ids)})
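# Illustrative client call for the endpoint above (not executed; the host,
# token, and message values are hypothetical). The payload shape mirrors the
# validation in chromadb_add_messages: "chat_id" is a string and "messages" is
# a list of objects with "id", "role", "date", "content", and optional "meta".
#
#   import requests
#   requests.post(
#       "http://localhost:7860/api/chromadb",
#       headers={"Authorization": "Bearer <sekrit_password value>"},
#       json={
#           "chat_id": "example-chat",
#           "messages": [
#               {"id": "msg-1", "role": "user", "date": 1700000000,
#                "content": "Hello there!"},
#           ],
#       },
#   )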
@app.route("/api/chromadb/query", methods=["POST"])
def chromadb_query():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "query" not in data or not isinstance(data["query"], str):
        abort(400, '"query" is required')

    if "n_results" not in data or not isinstance(data["n_results"], int):
        n_results = 1
    else:
        n_results = data["n_results"]

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    n_results = min(collection.count(), n_results)

    messages = []
    if n_results > 0:
        query_result = collection.query(
            query_texts=[data["query"]],
            n_results=n_results,
        )

        documents = query_result["documents"][0]
        ids = query_result["ids"][0]
        metadatas = query_result["metadatas"][0]
        distances = query_result["distances"][0]

        messages = [
            {
                "id": ids[i],
                "date": metadatas[i]["date"],
                "role": metadatas[i]["role"],
                "meta": metadatas[i]["meta"],
                "content": documents[i],
                "distance": distances[i],
            }
            for i in range(len(ids))
        ]

    return jsonify(messages)


@app.route("/api/chromadb/purge", methods=["POST"])
def chromadb_purge():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # Count before deleting: Collection.delete() is not guaranteed to return
    # the deleted ids across chromadb versions.
    deleted = collection.count()
    collection.delete()
    print("ChromaDB embeddings deleted", deleted)

    return "Ok", 200


@app.route("/api/summarize", methods=["POST"])
def api_summarize():
    data = request.get_json()
    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    params = {
        "temperature": 1.0,
        "repetition_penalty": 1.0,
        "max_length": 500,
        "min_length": 200,
        "length_penalty": 1.5,
        "bad_words": [
            "\n",
            '"',
            "*",
            "[",
            "]",
            "{",
            "}",
            ":",
            "(",
            ")",
            "<",
            ">",
            "Â",
            "The text ends",
            "The story ends",
            "The text is",
            "The story is",
        ],
    }

    if "params" in data and isinstance(data["params"], dict):
        params.update(data["params"])

    print("Summary input:", data["text"], sep="\n")
    summary = summarize_chunks(data["text"], params)
    print("Summary output:", summary, sep="\n")

    gc.collect()
    return jsonify({"summary": summary})


@app.route("/api/classify", methods=["POST"])
def api_classify():
    data = request.get_json()
    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    print("Classification input:", data["text"], sep="\n")
    classification = classify_text(data["text"])
    print("Classification output:", classification, sep="\n")

    gc.collect()
    return jsonify({"classification": classification})


@app.route("/api/classify/labels", methods=["GET"])
def api_classify_labels():
    classification = classify_text("")
    labels = [x["label"] for x in classification]
    return jsonify({"labels": labels})


app.run(host=host, port=port)
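# Example requests (illustrative; the payloads are hypothetical, and the bearer
# token must match the sekrit_password environment variable checked in
# before_request):
#
#   curl http://localhost:7860/api/summarize \
#       -H "Authorization: Bearer $SEKRIT" \
#       -H "Content-Type: application/json" \
#       -d '{"text": "A long conversation to be summarized..."}'
#
#   curl http://localhost:7860/api/classify \
#       -H "Authorization: Bearer $SEKRIT" \
#       -H "Content-Type: application/json" \
#       -d '{"text": "I am so glad you are here!"}'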