| | """ |
| | AIFinder Flask API |
| | Serves the trained sklearn ensemble via the AIFinder inference class. |
| | """ |
| |
|
| | import os |
| | import re |
| |
|
| | import joblib |
| | import numpy as np |
| | from sklearn.ensemble import RandomForestClassifier |
| | from flask import Flask, jsonify, request, send_from_directory, render_template |
| | from flask_cors import CORS |
| | from flask_limiter import Limiter |
| | from flask_limiter.util import get_remote_address |
| |
|
| | from config import MODEL_DIR |
| | from inference import AIFinder |
| |
|
# Tunables and on-disk locations for the community (user-corrected) model.
DEFAULT_TOP_N = 4  # ranked providers returned by classify when the request omits top_n
COMMUNITY_DIR = os.path.join(MODEL_DIR, "community")
CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib")

# Flask app with flask-cors applied using its defaults, and a rate limiter
# keyed on the client's remote address.
app = Flask(__name__)
CORS(app)
limiter = Limiter(get_remote_address, app=app)

# Mutable server state, filled in by load_models() and updated by the
# correction / toggle endpoints below.
finder: AIFinder | None = None
community_finder: AIFinder | None = None
using_community = False
corrections: list[dict] = []
| |
|
| |
|
def load_models():
    """Load the base model, persisted corrections and, if present, the community model.

    Side effects: sets the module globals ``finder``, ``corrections`` (only
    when ``CORRECTIONS_FILE`` exists) and ``community_finder``, and creates
    ``COMMUNITY_DIR`` if it is missing.

    A community model that fails to load is logged and skipped, so the server
    still starts on the base model. Failures loading the base model or the
    corrections file propagate to the caller.
    """
    global finder, community_finder, corrections
    finder = AIFinder(model_dir=MODEL_DIR)
    os.makedirs(COMMUNITY_DIR, exist_ok=True)
    if os.path.exists(CORRECTIONS_FILE):
        # joblib.load unpickles; this file is written only by this server.
        corrections = joblib.load(CORRECTIONS_FILE)
    if os.path.exists(os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")):
        try:
            community_finder = AIFinder(model_dir=COMMUNITY_DIR)
        except Exception:
            # Best-effort: keep serving the base model, but leave a trace of why
            # the community model is unavailable instead of failing silently.
            app.logger.exception("Failed to load community model from %s", COMMUNITY_DIR)
            community_finder = None
| |
|
| |
|
def _active_finder():
    """Return the model that should answer predictions right now.

    The community model wins only when it has been toggled on *and* one is
    loaded; otherwise the base ``finder`` is returned (``None`` if
    ``load_models`` has not run).
    """
    if using_community and community_finder:
        return community_finder
    return finder
| |
|
| |
|
| | def _strip_think_tags(text): |
| | text = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", text, flags=re.DOTALL) |
| | return text.strip() |
| |
|
| |
|
@app.route("/")
def index():
    """Render the ``index.html`` template for the root URL."""
    page = "index.html"
    return render_template(page)
| |
|
| |
|
@app.route("/api/classify", methods=["POST"])
@app.route("/v1/classify", methods=["POST"])
@limiter.limit("60/minute")
def v1_classify():
    """Predict which provider most likely wrote the submitted text.

    Request body: a JSON object with a string ``text`` field and an optional
    positive integer ``top_n`` (defaults to ``DEFAULT_TOP_N``; capped at the
    number of known providers). ``<think>``/``<thinking>`` blocks are removed
    before classification.

    Returns:
        200 with ``provider`` and ``confidence`` (percentage, 2 decimals) for
        the best match plus the ranked ``top_providers`` list; 400 for a
        malformed body or text shorter than 20 characters after stripping;
        503 when no model has been loaded.
    """
    data = request.get_json(silent=True)
    # Require a JSON object: a non-empty array such as ["text"] would pass the
    # membership test and then crash on data["text"].
    if not isinstance(data, dict) or "text" not in data:
        return jsonify({"error": "Request body must be JSON with a 'text' field."}), 400

    raw_text = data["text"]
    if not isinstance(raw_text, str):
        return jsonify({"error": "'text' must be a string."}), 400

    af = _active_finder()
    if af is None:
        # load_models() has not run; it is only invoked from the __main__ block.
        return jsonify({"error": "Model not loaded."}), 503

    # Validate before clamping: min() against a str/None/list raises TypeError,
    # and bool is an int subclass that must not be treated as a count.
    requested = data.get("top_n", DEFAULT_TOP_N)
    if isinstance(requested, bool) or not isinstance(requested, int) or requested < 1:
        requested = DEFAULT_TOP_N
    top_n = min(requested, len(af.le.classes_))

    text = _strip_think_tags(raw_text)
    if len(text) < 20:
        return jsonify(
            {
                "error": "Text too short (minimum 20 characters after stripping think tags)."
            }
        ), 400

    proba = af.predict_proba(text)
    ranked = sorted(proba.items(), key=lambda item: item[1], reverse=True)[:top_n]
    top_providers = [
        {"name": name, "confidence": round(float(conf * 100), 2)}
        for name, conf in ranked
    ]

    return jsonify(
        {
            "provider": top_providers[0]["name"],
            "confidence": top_providers[0]["confidence"],
            "top_providers": top_providers,
        }
    )
| |
|
| |
|
@app.route("/api/correct", methods=["POST"])
def correct():
    """Record a user-supplied label and retrain the community model.

    Request body: a JSON object with ``text`` (string) and
    ``correct_provider`` (one of the base model's provider labels). The text
    is stripped of think tags. A RandomForest is then refit on every stored
    correction plus the new one, using the base model's feature pipeline and
    label encoder. The forest, pipeline, encoder and corrections list are
    written to ``COMMUNITY_DIR``, and ``community_finder`` is reloaded from
    there.

    Returns:
        200 with ``status``, a constant ``loss`` of 0.0 and the number of
        stored ``corrections``; 400 for a malformed body, unknown provider or
        text shorter than 20 characters after stripping; 503 when the base
        model is not loaded.
    """
    global community_finder
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "text" not in data or "correct_provider" not in data:
        return jsonify({"error": "Need 'text' and 'correct_provider'."}), 400
    if finder is None:
        # load_models() has not run; it is only invoked from the __main__ block.
        return jsonify({"error": "Model not loaded."}), 503

    provider = data["correct_provider"]
    if provider not in list(finder.le.classes_):
        return jsonify({"error": f"Unknown provider: {provider}"}), 400

    raw_text = data["text"]
    if not isinstance(raw_text, str):
        return jsonify({"error": "'text' must be a string."}), 400
    text = _strip_think_tags(raw_text)
    # Same minimum as the classify endpoint, so the community model is never
    # trained on samples that endpoint would refuse to classify.
    if len(text) < 20:
        return jsonify(
            {
                "error": "Text too short (minimum 20 characters after stripping think tags)."
            }
        ), 400

    # NOTE(review): every call refits on *all* corrections and this endpoint
    # has no rate limit or auth — consider an @limiter.limit(...) here.
    texts = [c["text"] for c in corrections] + [text]
    labels = [c["provider"] for c in corrections] + [provider]
    X = finder.pipeline.transform(texts)
    y = finder.le.transform(labels)

    rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
    rf.fit(X, y)
    # NOTE(review): rf.classes_ holds only the labels present in `labels`, so
    # predict_proba has fewer columns than finder.le.classes_ until every
    # provider has been corrected at least once — confirm AIFinder maps
    # probability columns through rf.classes_.

    # Commit the correction only after a successful fit, so a sample that
    # breaks training is not replayed into every later retrain.
    corrections.append({"text": text, "provider": provider})

    joblib.dump([rf], os.path.join(COMMUNITY_DIR, "rf_4provider.joblib"))
    joblib.dump(finder.pipeline, os.path.join(COMMUNITY_DIR, "pipeline_4provider.joblib"))
    joblib.dump(finder.le, os.path.join(COMMUNITY_DIR, "enc_4provider.joblib"))
    joblib.dump(corrections, CORRECTIONS_FILE)

    community_finder = AIFinder(model_dir=COMMUNITY_DIR)

    # "loss" is always 0.0; presumably kept for client compatibility.
    return jsonify({"status": "ok", "loss": 0.0, "corrections": len(corrections)})
| |
|
| |
|
@app.route("/api/save", methods=["POST"])
def save_model():
    """Report the download name of the community model, if one exists.

    Nothing is written here; the returned name carries the ``community_``
    prefix that ``/models/<filename>`` maps into ``COMMUNITY_DIR``.
    """
    if community_finder is not None:
        return jsonify({"status": "ok", "filename": "community_rf_4provider.joblib"})
    return jsonify({"error": "No community model trained yet."}), 400
| |
|
| |
|
@app.route("/api/toggle_community", methods=["POST"])
def toggle_community():
    """Turn community-model routing on/off, or flip it when ``enabled`` is absent.

    Accepts an optional JSON object ``{"enabled": <bool-like>}``. String values
    are parsed ("true"/"1"/"yes"/"on" mean on) because ``bool("false")`` is
    True. Any non-object body is treated as empty, which toggles.

    Returns the new flag and whether a community model is currently loaded.
    """
    global using_community
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A non-empty JSON array would otherwise crash on .get().
        data = {}
    enabled = data.get("enabled", not using_community)
    if isinstance(enabled, str):
        enabled = enabled.strip().lower() in {"1", "true", "yes", "on"}
    using_community = bool(enabled)
    return jsonify({"using_community": using_community, "available": community_finder is not None})
| |
|
| |
|
@app.route("/models/<filename>")
def download_model(filename):
    """Serve a model artifact by file name.

    Names prefixed with ``community_`` are served from ``COMMUNITY_DIR`` with
    the prefix removed; everything else comes from ``MODEL_DIR``.
    ``send_from_directory`` rejects paths that escape the directory. The
    corrections file is refused because it stores raw user-submitted text,
    not a model.
    """
    if filename.startswith("community_"):
        inner = filename.removeprefix("community_")
        # casefold() also blocks e.g. "CORRECTIONS.joblib" on case-insensitive filesystems.
        if inner.casefold() == os.path.basename(CORRECTIONS_FILE).casefold():
            return jsonify({"error": "Not found."}), 404
        return send_from_directory(COMMUNITY_DIR, inner)
    return send_from_directory(MODEL_DIR, filename)
| |
|
| |
|
@app.route("/api/status", methods=["GET"])
def status():
    """Describe the model currently serving predictions and the community state."""
    active = _active_finder()
    labels = list(active.le.classes_) if active else []
    payload = {
        "loaded": active is not None,
        "device": "cpu",
        "providers": labels,
        "num_providers": len(labels),
        "using_community": using_community,
        "community_available": community_finder is not None,
        "corrections_count": len(corrections),
    }
    return jsonify(payload)
| |
|
| |
|
@app.route("/api/providers", methods=["GET"])
def providers():
    """List the base model's provider labels (empty before models are loaded)."""
    names = [] if not finder else list(finder.le.classes_)
    return jsonify({"providers": names})
| |
|
| |
|
if __name__ == "__main__":
    # NOTE: models are loaded only when this file is run directly; a WSGI
    # server importing the module must call load_models() itself.
    print("Loading models...")
    load_models()
    provider_names = finder.le.classes_
    print(f"Ready on cpu — {len(provider_names)} providers: {', '.join(provider_names)}")
    app.run(host="0.0.0.0", port=7860)
| |
|