print("Starting app...")
import os
os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"
os.environ["NUMBA_DISABLE_JIT"] = "1"

print("Importing FastAPI and BERTopic...")
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from bertopic import BERTopic
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.cluster import KMeans

print("Setting up BERTopic model...")
vectorizer_model = CountVectorizer()
dimensionality_model = TruncatedSVD(n_components=5)
clustering_model = KMeans(n_clusters=5, random_state=42)

topic_model = BERTopic(
    vectorizer_model=vectorizer_model,
    umap_model=dimensionality_model,
    hdbscan_model=clustering_model
)
print("BERTopic model ready.")

app = FastAPI()

@app.post("/predict")
async def predict(request: Request):
    data = await request.json()
    # Accept either a {"text": "..."} or a {"data": ["..."]} payload.
    if "text" in data:
        text = data["text"]
    elif "data" in data and isinstance(data["data"], list):
        text = data["data"][0]
    else:
        return JSONResponse({"error": "No input text provided."}, status_code=400)
    # Treat each non-empty line as one document.
    documents = [doc.strip() for doc in text.split("\n") if doc.strip()]
    if not documents:
        return JSONResponse({"error": "No valid input."}, status_code=400)
    # KMeans needs at least as many documents as clusters.
    if len(documents) < clustering_model.n_clusters:
        return JSONResponse(
            {"error": f"Need at least {clustering_model.n_clusters} documents (one per line)."},
            status_code=400,
        )
    # The model is re-fit on every request, so topics reflect only the
    # submitted documents.
    topics, _probs = topic_model.fit_transform(documents)
    topic_info = topic_model.get_topic_info()
    return {
        "topics": topic_info.to_dict(orient="records"),
        # Cast defensively: cluster labels may arrive as numpy ints,
        # which are not JSON-serializable.
        "topic_assignments": [int(t) for t in topics],
    }

@app.get("/")
async def root():
    return {"message": "BERTopic FastAPI is running! POST to /predict with {'text': '...'} where each line of text is one document."}