from fastapi import FastAPI
from sentence_transformers import SentenceTransformer
import pickle
import os
from pydantic import BaseModel
from typing import List

app = FastAPI(
    title="SBERT Embedding API",
    description="API for generating sentence embeddings using SBERT",
    version="1.0"
)

# Load the model once at startup (weights are downloaded and cached on first use)
model_name = 'taghyan/model'
model = SentenceTransformer(model_name)

# Embedding cache setup
embedding_file = 'embeddings_sbert.pkl'


class TextRequest(BaseModel):
    text: str


class TextsRequest(BaseModel):
    texts: List[str]


class EmbeddingResponse(BaseModel):
    embedding: List[float]


class EmbeddingsResponse(BaseModel):
    embeddings: List[List[float]]


@app.get("/")
def read_root():
    return {"message": "SBERT Embedding Service"}


# Plain `def` (not `async def`) so FastAPI runs the blocking encode() call
# in its threadpool instead of stalling the event loop.
@app.post("/embed", response_model=EmbeddingResponse)
def embed_text(request: TextRequest):
    """Generate embedding for a single text"""
    embedding = model.encode(request.text, convert_to_numpy=True).tolist()
    return {"embedding": embedding}


@app.post("/embed_batch", response_model=EmbeddingsResponse)
def embed_texts(request: TextsRequest):
    """Generate embeddings for multiple texts"""
    embeddings = model.encode(
        request.texts, show_progress_bar=True, convert_to_numpy=True
    ).tolist()
    return {"embeddings": embeddings}


@app.post("/update_cache")
def update_cache(request: TextsRequest):
    """Update the embedding cache with new texts"""
    # Load any previously cached embeddings from disk
    if os.path.exists(embedding_file):
        with open(embedding_file, 'rb') as f:
            existing_embeddings = pickle.load(f)
    else:
        existing_embeddings = []

    # Encode the new texts and append them to the cache
    new_embeddings = model.encode(
        request.texts, show_progress_bar=True, convert_to_numpy=True
    )
    updated_embeddings = existing_embeddings + new_embeddings.tolist()

    with open(embedding_file, 'wb') as f:
        pickle.dump(updated_embeddings, f)

    return {
        "message": f"Cache updated with {len(request.texts)} new embeddings",
        "total_embeddings": len(updated_embeddings),
    }
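

# --- Usage sketch (assumptions: this module is saved as `main.py`, `uvicorn`
# and `requests` are installed, and host/port are illustrative, not prescribed
# by the original code) ---
if __name__ == "__main__":
    # Run the service directly for local testing; in practice you would
    # typically launch it with `uvicorn main:app --host 0.0.0.0 --port 8000`.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client calls against a running instance (hypothetical URL):
#
#   import requests
#
#   resp = requests.post("http://localhost:8000/embed",
#                        json={"text": "hello world"})
#   print(len(resp.json()["embedding"]))   # embedding dimensionality
#
#   resp = requests.post("http://localhost:8000/embed_batch",
#                        json={"texts": ["first sentence", "second sentence"]})
#   print(len(resp.json()["embeddings"]))  # -> 2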