File size: 2,049 Bytes
cbd9add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from fastapi import FastAPI
from sentence_transformers import SentenceTransformer
import pickle
import os
from pydantic import BaseModel
import numpy as np
from typing import List

# FastAPI application instance; metadata shows up in the auto-generated docs (/docs).
app = FastAPI(
    title="SBERT Embedding API",
    description="API for generating sentence embeddings using SBERT",
    version="1.0"
)

# Load model (this will be cached after first load)
# NOTE(review): this runs at import time, so the first startup blocks while the
# model is fetched/loaded — presumably downloaded from the Hugging Face hub; confirm.
model_name = 'taghyan/model'
model = SentenceTransformer(model_name)

# Embedding cache setup
# Path of the on-disk pickle cache appended to by /update_cache (relative to CWD).
embedding_file = 'embeddings_sbert.pkl'

class TextRequest(BaseModel):
    """Request body for /embed: a single text to embed."""
    text: str

class TextsRequest(BaseModel):
    """Request body for /embed_batch and /update_cache: a list of texts."""
    texts: List[str]

class EmbeddingResponse(BaseModel):
    """Response for /embed: one embedding vector as a flat list of floats."""
    embedding: List[float]

class EmbeddingsResponse(BaseModel):
    """Response for /embed_batch: one embedding vector per input text, in order."""
    embeddings: List[List[float]]

@app.get("/")
def read_root():
    """Landing endpoint: identifies the service so a GET / confirms it is up."""
    payload = {"message": "SBERT Embedding Service"}
    return payload

@app.post("/embed", response_model=EmbeddingResponse)
async def embed_text(request: TextRequest):
    """Generate embedding for a single text.

    Returns the vector as a plain list of floats so it serializes to JSON.
    NOTE(review): model.encode is CPU-bound and this handler is async, so the
    call runs on the event loop — acceptable for single-worker/light use, but
    confirm whether a threadpool offload is needed under load.
    """
    vector = model.encode(request.text, convert_to_numpy=True)
    return {"embedding": vector.tolist()}

@app.post("/embed_batch", response_model=EmbeddingsResponse)
async def embed_texts(request: TextsRequest):
    """Generate embeddings for multiple texts.

    Output order matches the input order of request.texts.
    NOTE(review): show_progress_bar=True writes a progress bar to the server's
    stderr on every request — presumably intentional for debugging; confirm.
    """
    vectors = model.encode(
        request.texts,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
    return {"embeddings": vectors.tolist()}

@app.post("/update_cache")
async def update_cache(request: TextsRequest):
    """Append embeddings for the given texts to the on-disk pickle cache.

    Loads the existing cache (if any), encodes the new texts, and rewrites
    the whole file with old + new vectors concatenated.
    NOTE(review): the read-modify-write is not atomic — concurrent requests
    (or multiple workers) could lose updates; confirm single-worker deployment.
    NOTE(review): pickle.load must only ever see this service's own file —
    unpickling untrusted data executes arbitrary code.
    """
    cached = []
    if os.path.exists(embedding_file):
        with open(embedding_file, 'rb') as cache_in:
            cached = pickle.load(cache_in)

    fresh = model.encode(request.texts, show_progress_bar=True)
    combined = cached + fresh.tolist()

    with open(embedding_file, 'wb') as cache_out:
        pickle.dump(combined, cache_out)

    return {"message": f"Cache updated with {len(request.texts)} new embeddings", "total_embeddings": len(combined)}