# FastAPI service exposing SBERT sentence-embedding endpoints.
from fastapi import FastAPI
from sentence_transformers import SentenceTransformer
import pickle
import os
from pydantic import BaseModel
import numpy as np
from typing import List
app = FastAPI(
    title="SBERT Embedding API",
    description="API for generating sentence embeddings using SBERT",
    version="1.0"
)
# Load model (this will be cached after first load)
# NOTE(review): 'taghyan/model' looks like a Hugging Face hub id; the download
# and model construction happen here at import time, so process startup blocks
# until the model is available — confirm this is acceptable for deployment.
model_name = 'taghyan/model'
model = SentenceTransformer(model_name)
# Embedding cache setup
# Path of the pickle file that /update_cache reads and appends to.
embedding_file = 'embeddings_sbert.pkl'
class TextRequest(BaseModel):
    """Request body for /embed: a single input string."""
    # The text to embed.
    text: str
class TextsRequest(BaseModel):
    """Request body for /embed_batch and /update_cache: a batch of strings."""
    # The texts to embed, in order.
    texts: List[str]
class EmbeddingResponse(BaseModel):
    """Response body for /embed: one embedding vector."""
    # The embedding vector as a plain list of floats (JSON-serializable).
    embedding: List[float]
class EmbeddingsResponse(BaseModel):
    """Response body for /embed_batch: one embedding vector per input text."""
    # Embeddings in the same order as the request's texts.
    embeddings: List[List[float]]
@app.get("/")
def read_root():
    """Landing endpoint identifying the service."""
    payload = {"message": "SBERT Embedding Service"}
    return payload
@app.post("/embed", response_model=EmbeddingResponse)
def embed_text(request: TextRequest):
    """Generate an embedding for a single text.

    Declared as a plain ``def`` (not ``async def``): ``model.encode`` is
    blocking CPU-bound work, and inside an async handler it would stall the
    entire event loop. FastAPI runs sync handlers in a worker thread, so the
    server stays responsive while encoding.

    Returns a dict matching EmbeddingResponse: the vector as a list of floats.
    """
    embedding = model.encode(request.text, convert_to_numpy=True).tolist()
    return {"embedding": embedding}
@app.post("/embed_batch", response_model=EmbeddingsResponse)
def embed_texts(request: TextsRequest):
    """Generate embeddings for multiple texts in one model call.

    Fixes vs. original:
    - plain ``def`` so the blocking ``model.encode`` runs in FastAPI's
      threadpool instead of stalling the async event loop;
    - ``show_progress_bar=True`` removed — a tqdm bar written to stderr on
      every HTTP request is log noise and adds overhead in a server.

    Returns a dict matching EmbeddingsResponse: one vector per input text,
    in request order.
    """
    embeddings = model.encode(request.texts, convert_to_numpy=True).tolist()
    return {"embeddings": embeddings}
@app.post("/update_cache")
def update_cache(request: TextsRequest):
    """Append embeddings for the given texts to the on-disk pickle cache.

    Fixes vs. original:
    - plain ``def`` so the blocking encode + file I/O runs in FastAPI's
      threadpool rather than stalling the event loop;
    - the updated cache is written to a temp file and atomically swapped in
      with ``os.replace``, so a crash mid-write can no longer truncate or
      corrupt the cache;
    - an unreadable/corrupt existing cache falls back to an empty list
      instead of returning a 500 forever;
    - ``show_progress_bar`` removed (server log noise).

    NOTE(review): ``pickle.load`` on a file is arbitrary code execution if an
    attacker can write ``embeddings_sbert.pkl`` — consider a safer format
    (e.g. ``np.save``/JSON). Left as pickle to stay compatible with existing
    cache files.
    NOTE(review): concurrent requests can still interleave read-modify-write
    and lose updates; add a lock if this endpoint is hit concurrently.
    """
    existing_embeddings = []
    if os.path.exists(embedding_file):
        try:
            with open(embedding_file, 'rb') as f:
                existing_embeddings = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, OSError):
            # Corrupt or unreadable cache: start fresh rather than fail.
            existing_embeddings = []
    new_embeddings = model.encode(request.texts, convert_to_numpy=True)
    updated_embeddings = existing_embeddings + new_embeddings.tolist()
    # Write to a sibling temp file, then atomically replace the cache.
    tmp_path = embedding_file + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(updated_embeddings, f)
    os.replace(tmp_path, embedding_file)
    return {"message": f"Cache updated with {len(request.texts)} new embeddings", "total_embeddings": len(updated_embeddings)}