# Author: Ayush0708
# Last change: "Update app.py" (commit f1cf5fe, verified)
print("Starting app...")
import os

# Configure numba through environment variables *before* the third-party
# imports below; presumably numba (pulled in indirectly by bertopic) reads
# these when it is first imported — confirm if the import order ever changes.
os.environ.update(
    NUMBA_CACHE_DIR="/tmp/numba_cache",  # cache dir under /tmp (presumably chosen because it is writable)
    NUMBA_DISABLE_JIT="1",  # turn off numba JIT compilation
)
print("Importing FastAPI and BERTopic...")
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from bertopic import BERTopic
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.cluster import KMeans
print("Setting up BERTopic model...")
# One seed for every randomized component, so the reduction and clustering
# steps are reproducible from run to run (KMeans was already seeded; the
# randomized TruncatedSVD solver previously was not).
RANDOM_STATE = 42
vectorizer_model = CountVectorizer()
# TruncatedSVD is passed as BERTopic's umap_model, i.e. it replaces the
# default dimensionality-reduction step.
dimensionality_model = TruncatedSVD(n_components=5, random_state=RANDOM_STATE)
# KMeans is passed as BERTopic's hdbscan_model, i.e. it replaces the default
# clustering step. sklearn's KMeans rejects inputs with fewer samples than
# n_clusters, so each fit needs at least 5 documents.
clustering_model = KMeans(n_clusters=5, random_state=RANDOM_STATE)
topic_model = BERTopic(
    vectorizer_model=vectorizer_model,
    umap_model=dimensionality_model,
    hdbscan_model=clustering_model,
)
print("BERTopic model ready.")
app = FastAPI()
def _extract_text(payload):
    """Pull the raw input value out of a decoded /predict JSON body.

    Two shapes are accepted: ``{"text": ...}`` and ``{"data": [...]}``, where
    only the first element of ``data`` is used. Returns None when neither
    shape is present (including non-dict bodies and an empty ``data`` list).
    The returned value is not type-checked here.
    """
    if not isinstance(payload, dict):
        return None
    if "text" in payload:
        return payload["text"]
    items = payload.get("data")
    if isinstance(items, list) and items:
        return items[0]
    return None


@app.post("/predict")
async def predict(request: Request):
    """Fit the shared topic model on newline-separated documents and report topics.

    The JSON body is either ``{"text": <str>}`` or ``{"data": [<str>, ...]}``.
    The string is split on newlines, and every non-blank, stripped line is one
    document.

    Returns a dict with ``topics`` (the rows of ``get_topic_info()``) and
    ``topic_assignments`` (one topic id per document). On malformed or unusable
    input it returns a ``{"error": ...}`` JSON response with status 400.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError
        return JSONResponse({"error": "Request body must be valid JSON."}, status_code=400)

    text = _extract_text(data)
    if text is None:
        return JSONResponse({"error": "No input text provided."}, status_code=400)
    if not isinstance(text, str):
        return JSONResponse({"error": "Input text must be a string."}, status_code=400)

    documents = [doc.strip() for doc in text.split("\n") if doc.strip()]
    if not documents:
        return JSONResponse({"error": "No valid input."}, status_code=400)

    # KMeans cannot form more clusters than there are samples; reject early
    # with a clear message instead of letting fit_transform surface a 500.
    min_docs = clustering_model.n_clusters
    if len(documents) < min_docs:
        return JSONResponse(
            {"error": f"At least {min_docs} non-empty lines (one document per line) are required."},
            status_code=400,
        )

    # NOTE: fit_transform is synchronous and there is no await between it and
    # get_topic_info(), so the event loop is blocked for the whole fit. That
    # stalls other requests, but it also guarantees (within this process) that
    # no other request refits the shared topic_model mid-handler. Preserve
    # that guarantee (e.g. with a lock) before moving this into a thread pool.
    try:
        topics, _probs = topic_model.fit_transform(documents)
    except ValueError as exc:
        # Data-dependent fit failures (e.g. the vectorizer finding no usable
        # vocabulary) are client input problems, not server errors.
        return JSONResponse(
            {"error": f"Could not model topics for this input: {exc}"},
            status_code=400,
        )
    topic_info = topic_model.get_topic_info()
    return {
        "topics": topic_info.to_dict(orient="records"),
        "topic_assignments": topics,
    }
@app.get("/")
async def root():
    """Landing endpoint: confirm the service is up and point callers at POST /predict."""
    status_message = "BERTopic FastAPI is running! Use POST /predict with {'text': '...'}."
    payload = {"message": status_message}
    return payload