# Text-Api / api.py
# Author: um41r — "Update api.py" (commit b25b1b9, verified)
from __future__ import annotations
import os
import re
import secrets
from contextlib import asynccontextmanager
from typing import Annotated
import torch
from fastapi import FastAPI, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field
from transformers import pipeline
# ─── Config ────────────────────────────────────────────────────────────────────
MODEL_ID = "openai-community/roberta-base-openai-detector"

# Read from HuggingFace Space secret (Settings → Variables and secrets)
API_KEY = os.environ.get("API_KEY", "")
if not API_KEY:
    # Fail fast at startup: without a key every request would be rejected anyway.
    raise RuntimeError(
        "API_KEY environment variable is not set. "
        "Add it in your HuggingFace Space → Settings → Variables and secrets."
    )

# Header scheme — clients send: X-API-Key: <your-key>
# auto_error=False so a missing header reaches verify_api_key(), which returns
# a consistent 401 instead of FastAPI's default 403.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
def verify_api_key(key: str | None = Security(api_key_header)) -> str:
    """Dependency: reject any request whose X-API-Key header is absent or wrong."""
    supplied = key or ""
    # secrets.compare_digest performs a constant-time comparison, which avoids
    # leaking key prefixes through response-timing differences.
    if not (supplied and secrets.compare_digest(supplied, API_KEY)):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key. Pass it as the X-API-Key header.",
        )
    return key
# ─── Lifespan ──────────────────────────────────────────────────────────────────
classifier = None  # set once by lifespan(); None until the model is loaded


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the HF text-classification pipeline once at startup, then serve."""
    global classifier
    print(f"Loading model {MODEL_ID} …")
    # device 0 = first CUDA GPU, -1 = CPU (transformers pipeline convention)
    device_index = 0 if torch.cuda.is_available() else -1
    classifier = pipeline(
        "text-classification",
        model=MODEL_ID,
        device=device_index,
    )
    print("Model ready.")
    yield
# ─── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="AI Text Detector API",
    description="Detects whether text is human-written or AI-generated. Requires X-API-Key header.",
    version="2.0.0",
    lifespan=lifespan,  # loads the classifier on startup
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # lock this down to your domain in production
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)
# ─── Helpers ───────────────────────────────────────────────────────────────────
def split_into_chunks(text: str) -> list[str]:
    """Split *text* into roughly 80-word chunks.

    Paragraphs (newline-separated) are processed independently; within a
    paragraph, sentences are accumulated until adding the next one would
    exceed the word budget. A single over-long sentence becomes its own
    chunk. Always returns at least one element.
    """
    max_words = 80
    pieces: list[str] = []
    paragraphs = [p.strip() for p in text.split("\n") if p.strip()]
    if not paragraphs:
        paragraphs = [text.strip()]
    for paragraph in paragraphs:
        buffer = ""
        # Sentence boundary: whitespace preceded by ., ! or ?
        for sentence in re.split(r"(?<=[.!?])\s+", paragraph):
            candidate = (buffer + " " + sentence).strip()
            if len(candidate.split()) > max_words:
                if buffer.strip():
                    pieces.append(buffer.strip())
                buffer = sentence  # start a new chunk with this sentence
            else:
                buffer = candidate
        if buffer.strip():
            pieces.append(buffer.strip())
    return pieces if pieces else [text.strip()]
# ─── Schemas ───────────────────────────────────────────────────────────────────
class DetectRequest(BaseModel):
    """Request body for POST /detect."""

    # Text to analyse; length bounds are enforced by pydantic (422 on violation)
    # before the route handler runs.
    text: Annotated[
        str,
        Field(
            min_length=1,
            max_length=10_000,
            description="Text to analyse (max 10,000 characters)",
        ),
    ]
class ChunkResult(BaseModel):
    """Classification result for one chunk of the input text."""

    text: str  # the chunk's text, as produced by split_into_chunks
    ai_probability: float  # probability in [0, 1] that the chunk is AI-generated
    human_probability: float  # 1 - ai_probability
    label: str  # "AI" | "Human"
    confidence: float  # probability of the winning label (always >= 0.5)
class DetectResponse(BaseModel):
    """Aggregate classification result for the whole input text."""

    label: str  # overall verdict: "AI" | "Human"
    ai_probability: float  # word-count-weighted average of per-chunk AI probabilities
    human_probability: float  # 1 - ai_probability
    confidence: float  # probability of the winning overall label
    chunks: list[ChunkResult]  # per-chunk breakdown, in document order
    total_chunks: int  # len(chunks)
    ai_chunks: int  # chunks with ai_probability >= 0.5
    human_chunks: int  # total_chunks - ai_chunks
# ─── Routes ────────────────────────────────────────────────────────────────────
@app.get("/", tags=["health"])
async def health():
    """Public health-check — no API key required."""
    # Returns a static payload; does NOT verify that the model is loaded.
    return {"status": "ok", "model": MODEL_ID}
@app.post(
    "/detect",
    response_model=DetectResponse,
    tags=["detection"],
    dependencies=[Security(verify_api_key)],
)
async def detect(body: DetectRequest):
    """Classify text as human-written or AI-generated.

    The text is split into ~80-word chunks, each chunk is scored by the
    model, and the overall verdict is the word-count-weighted average of
    the per-chunk AI probabilities.

    Raises:
        HTTPException 503: the model has not finished loading yet.
    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet — try again shortly.")
    chunks = split_into_chunks(body.text)
    raw = classifier(chunks, truncation=True, max_length=512, batch_size=8)

    chunk_results: list[ChunkResult] = []
    ai_probs: list[float] = []
    word_counts: list[int] = []
    for chunk, res in zip(chunks, raw):
        # This detector's labels are "Fake" (AI-generated) vs "Real" (human).
        ai_prob = res["score"] if res["label"] == "Fake" else 1.0 - res["score"]
        human_prob = 1.0 - ai_prob
        is_ai = ai_prob >= 0.5
        label = "AI" if is_ai else "Human"
        conf = ai_prob if is_ai else human_prob
        chunk_results.append(
            ChunkResult(
                text=chunk,
                ai_probability=round(ai_prob, 4),
                human_probability=round(human_prob, 4),
                label=label,
                confidence=round(conf, 4),
            )
        )
        ai_probs.append(ai_prob)
        word_counts.append(len(chunk.split()))

    total_words = sum(word_counts)
    if total_words > 0:
        # Weight each chunk's probability by its share of the total words.
        avg_ai = sum(p * w for p, w in zip(ai_probs, word_counts)) / total_words
    else:
        # Whitespace-only input (e.g. a single space passes min_length=1)
        # yields zero-word chunks; fall back to an unweighted mean instead
        # of dividing by zero. ai_probs is never empty — split_into_chunks
        # always returns at least one chunk.
        avg_ai = sum(ai_probs) / len(ai_probs)
    avg_human = 1.0 - avg_ai
    overall_label = "AI" if avg_ai >= 0.5 else "Human"
    overall_conf = avg_ai if overall_label == "AI" else avg_human
    ai_chunks = sum(1 for p in ai_probs if p >= 0.5)
    return DetectResponse(
        label=overall_label,
        ai_probability=round(avg_ai, 4),
        human_probability=round(avg_human, 4),
        confidence=round(overall_conf, 4),
        chunks=chunk_results,
        total_chunks=len(chunks),
        ai_chunks=ai_chunks,
        human_chunks=len(chunks) - ai_chunks,
    )