AIFinder / app.py
CompactAI's picture
Upload 8 files
17ef86f verified
"""
AIFinder Flask API
Serves the trained sklearn ensemble via the AIFinder inference class.
"""
import os
import re
import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from flask import Flask, jsonify, request, send_from_directory, render_template
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from config import MODEL_DIR
from inference import AIFinder
# Flask application object; CORS is enabled on it and a rate limiter keyed on
# the caller's remote address is attached.
app = Flask(__name__)
CORS(app)
limiter = Limiter(get_remote_address, app=app)
# Base model wrapper; stays None until load_models() runs.
finder: AIFinder | None = None
# Model retrained from user corrections; None when no community model exists
# or it failed to load.
community_finder: AIFinder | None = None
# When True (and a community model is available) requests are served by
# community_finder instead of finder.
using_community = False
# Number of top providers returned by the classify endpoint when the request
# does not supply a usable "top_n".
DEFAULT_TOP_N = 4
# Directory holding the community model artifacts and the corrections file.
COMMUNITY_DIR = os.path.join(MODEL_DIR, "community")
CORRECTIONS_FILE = os.path.join(COMMUNITY_DIR, "corrections.joblib")
# In-memory list of {"text": ..., "provider": ...} corrections; persisted to
# CORRECTIONS_FILE by the correction endpoint.
corrections: list[dict] = []
def load_models():
    """Load the base model and any previously saved community artifacts.

    Populates the module-level ``finder``, ``community_finder`` and
    ``corrections`` globals and makes sure ``COMMUNITY_DIR`` exists.

    A community model that fails to load is treated as "not available":
    the failure is logged and ``community_finder`` is left as ``None`` so
    the server can still start with the base model.
    """
    global finder, community_finder, corrections
    finder = AIFinder(model_dir=MODEL_DIR)
    os.makedirs(COMMUNITY_DIR, exist_ok=True)
    if os.path.exists(CORRECTIONS_FILE):
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only safe while COMMUNITY_DIR is writable solely
        # by this application.
        corrections = joblib.load(CORRECTIONS_FILE)
    if os.path.exists(os.path.join(COMMUNITY_DIR, "rf_4provider.joblib")):
        try:
            community_finder = AIFinder(model_dir=COMMUNITY_DIR)
        except Exception:
            # Best-effort load: keep serving with the base model, but leave a
            # trace so a broken community model is not silently ignored.
            app.logger.exception(
                "Could not load community model from %s; "
                "continuing with the base model only.",
                COMMUNITY_DIR,
            )
            community_finder = None
def _active_finder():
    """Return the model that should serve the current request.

    The community model is chosen only when it has been switched on and is
    actually available; otherwise the base ``finder`` is returned.
    """
    if using_community and community_finder:
        return community_finder
    return finder
def _strip_think_tags(text):
    """Drop ``<think>``/``<thinking>`` blocks from *text* and trim whitespace.

    Matching is non-greedy and spans newlines, so each closed block is
    removed on its own; an opening tag without a closing tag is left as is.
    """
    think_block = re.compile(r"<think(?:ing)?>.*?</think(?:ing)?>", re.DOTALL)
    without_blocks = think_block.sub("", text)
    return without_blocks.strip()
@app.route("/")
def index():
    """Render the ``index.html`` template for the site root."""
    page = render_template("index.html")
    return page
@app.route("/api/classify", methods=["POST"])
@app.route("/v1/classify", methods=["POST"])
@limiter.limit("60/minute")
def v1_classify():
    """Classify a text and return the most likely providers.

    Request body: a JSON object with a string ``text`` and an optional
    integer ``top_n`` (how many providers to return). A missing or unusable
    ``top_n`` falls back to ``DEFAULT_TOP_N``; the value is capped at the
    number of classes known to the active model.

    Returns:
        200 with ``provider``, ``confidence`` and ``top_providers``
        (confidences are scaled by 100 and rounded to two decimals);
        400 when the body is malformed or the text is too short;
        503 when no model has been loaded.
    """
    data = request.get_json(silent=True)
    # A JSON array or scalar must not reach the key lookups below.
    if not isinstance(data, dict) or "text" not in data:
        return jsonify({"error": "Request body must be JSON with a 'text' field."}), 400

    raw_text = data["text"]
    if not isinstance(raw_text, str):
        # The regex in _strip_think_tags would raise TypeError on non-strings.
        return jsonify({"error": "'text' must be a string."}), 400

    af = _active_finder()
    if af is None:
        return jsonify({"error": "Model not loaded."}), 503

    text = _strip_think_tags(raw_text)
    if len(text) < 20:
        return jsonify(
            {
                "error": "Text too short (minimum 20 characters after stripping think tags)."
            }
        ), 400

    # Validate before min(): comparing a str/None/list against an int raises
    # TypeError. bool is an int subclass, so it is rejected explicitly.
    requested = data.get("top_n", DEFAULT_TOP_N)
    if isinstance(requested, bool) or not isinstance(requested, int) or requested < 1:
        requested = DEFAULT_TOP_N
    top_n = min(requested, len(af.le.classes_))

    proba = af.predict_proba(text)
    ranked = sorted(proba.items(), key=lambda item: item[1], reverse=True)[:top_n]
    top_providers = [
        {"name": name, "confidence": round(float(conf * 100), 2)}
        for name, conf in ranked
    ]
    return jsonify(
        {
            "provider": top_providers[0]["name"],
            "confidence": top_providers[0]["confidence"],
            "top_providers": top_providers,
        }
    )
@app.route("/api/correct", methods=["POST"])
def correct():
    """Record a user correction and retrain the community model.

    Request body: a JSON object with a string ``text`` and a
    ``correct_provider`` that must be one of the base model's classes.

    The correction is added to all earlier ones, a random forest is fitted
    on the full set using the base model's feature pipeline and label
    encoder, the artifacts are written to ``COMMUNITY_DIR`` and
    ``community_finder`` is reloaded from there.

    Returns:
        200 with ``status``, ``loss`` and the total ``corrections`` count;
        400 when the body is malformed, the provider is unknown or the text
        is empty; 503 when the base model has not been loaded.
    """
    # NOTE(review): every call retrains and rewrites the model and this route
    # has no rate limit, unlike the classify endpoint - consider adding one.
    global community_finder
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "text" not in data or "correct_provider" not in data:
        return jsonify({"error": "Need 'text' and 'correct_provider'."}), 400
    if finder is None:
        return jsonify({"error": "Model not loaded."}), 503

    provider = data["correct_provider"]
    # Type check first so lists/dicts are never compared against class labels.
    if not isinstance(provider, str) or provider not in list(finder.le.classes_):
        return jsonify({"error": f"Unknown provider: {provider}"}), 400

    raw_text = data["text"]
    if not isinstance(raw_text, str):
        return jsonify({"error": "'text' must be a string."}), 400
    text = _strip_think_tags(raw_text)
    if not text:
        # An empty sample carries no signal and would only pollute training.
        return jsonify({"error": "Text is empty after stripping think tags."}), 400

    # Train on a candidate list and commit to the shared list only after
    # everything has been persisted, so a failure cannot leave the in-memory
    # corrections ahead of what is on disk.
    entry = {"text": text, "provider": provider}
    pending = [*corrections, entry]
    texts = [c["text"] for c in pending]
    providers = [c["provider"] for c in pending]
    X = finder.pipeline.transform(texts)
    y = finder.le.transform(providers)
    rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
    rf.fit(X, y)

    joblib.dump([rf], os.path.join(COMMUNITY_DIR, "rf_4provider.joblib"))
    joblib.dump(finder.pipeline, os.path.join(COMMUNITY_DIR, "pipeline_4provider.joblib"))
    joblib.dump(finder.le, os.path.join(COMMUNITY_DIR, "enc_4provider.joblib"))
    joblib.dump(pending, CORRECTIONS_FILE)
    corrections.append(entry)

    community_finder = AIFinder(model_dir=COMMUNITY_DIR)
    # "loss" is a fixed placeholder; no loss is computed for the forest.
    return jsonify({"status": "ok", "loss": 0.0, "corrections": len(corrections)})
@app.route("/api/save", methods=["POST"])
def save_model():
    """Report the download filename of the community model.

    Nothing is written here; the handler only answers with the filename, or
    with a 400 error when no community model is available.
    """
    if community_finder is not None:
        return jsonify({"status": "ok", "filename": "community_rf_4provider.joblib"})
    return jsonify({"error": "No community model trained yet."}), 400
@app.route("/api/toggle_community", methods=["POST"])
def toggle_community():
    """Switch between the base model and the community model.

    Request body (optional): a JSON object with ``enabled``. When the key is
    absent, or the body is missing or not a JSON object, the current setting
    is flipped.

    Returns:
        JSON with the new ``using_community`` flag and whether a community
        model is ``available``.
    """
    global using_community
    data = request.get_json(silent=True)
    # A JSON array or scalar has no .get(); treat it like an empty body
    # instead of failing with AttributeError.
    if not isinstance(data, dict):
        data = {}
    using_community = bool(data.get("enabled", not using_community))
    return jsonify(
        {"using_community": using_community, "available": community_finder is not None}
    )
@app.route("/models/<filename>")
def download_model(filename):
    """Serve a model artifact by filename.

    Names starting with ``community_`` are looked up in ``COMMUNITY_DIR``
    with that prefix removed; every other name is served from ``MODEL_DIR``.
    """
    community_prefix = "community_"
    if not filename.startswith(community_prefix):
        return send_from_directory(MODEL_DIR, filename)
    return send_from_directory(COMMUNITY_DIR, filename.removeprefix(community_prefix))
@app.route("/api/status", methods=["GET"])
def status():
    """Report which model is active and what it knows about.

    The provider list and count come from the currently active model and
    are empty/zero when none is loaded.
    """
    af = _active_finder()
    if af:
        provider_names = list(af.le.classes_)
    else:
        provider_names = []
    payload = {
        "loaded": af is not None,
        "device": "cpu",
        "providers": provider_names,
        "num_providers": len(provider_names),
        "using_community": using_community,
        "community_available": community_finder is not None,
        "corrections_count": len(corrections),
    }
    return jsonify(payload)
@app.route("/api/providers", methods=["GET"])
def providers():
    """List the base model's provider labels (empty if it is not loaded)."""
    names = []
    if finder:
        names = list(finder.le.classes_)
    return jsonify({"providers": names})
if __name__ == "__main__":
    # Load the models first, then announce the known providers and start
    # the server on all interfaces, port 7860.
    print("Loading models...")
    load_models()
    provider_names = finder.le.classes_
    summary = f"Ready on cpu — {len(provider_names)} providers: " + ", ".join(provider_names)
    print(summary)
    app.run(host="0.0.0.0", port=7860)