|
import logging
import re

import torch
import torch.nn.functional as F
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi_health import health
from pinecone_text.sparse import SpladeEncoder
from pydantic import BaseModel
|
|
|
# Request body schema: a JSON object carrying the text to embed.
class TextPayload(BaseModel):
    # Raw input text; forwarded unchanged to the sparse encoder by the endpoint.
    text: str
|
|
|
def get_session() -> bool:
    """Return the placeholder session flag, which is always ``True``.

    Used as a FastAPI dependency by ``is_database_online``; no real
    session or connection is opened here.
    """
    session_ok = True
    return session_ok
|
|
|
def is_database_online(session: bool = Depends(get_session)):
    """Health condition that echoes the injected session flag.

    Args:
        session: Value produced by the ``get_session`` dependency.

    Returns:
        The ``session`` value, unchanged.
    """
    online = session
    return online
|
|
|
# Application instance. The /healthz route is built by fastapi_health from
# the list of condition callables (currently only is_database_online).
app = FastAPI()
app.add_api_route(
    path="/healthz",
    endpoint=health([is_database_online]),
)
|
|
|
class Load_EmbeddingModels:
    """Holder for the SPLADE sparse encoder used by this service.

    The encoder is created on ``"cuda"`` when torch reports that CUDA is
    available, and on ``"cpu"`` otherwise. The chosen device string is kept
    on ``self.device``.
    """

    def __init__(self):
        if torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"
        self.device = target_device
        self.sparse_model = SpladeEncoder(device=target_device)

    def get_single_sparse_text_embedding(self, df_chunk):
        """Return the encoder's ``encode_documents`` output for ``df_chunk``.

        Args:
            df_chunk: Text to encode; handed to the encoder unchanged.
                (The endpoint in this file passes a single string.)

        Returns:
            Whatever ``SpladeEncoder.encode_documents`` returns for the input.
        """
        encoder = self.sparse_model
        return encoder.encode_documents(df_chunk)
|
|
|
|
|
# Single shared encoder wrapper, constructed once at import time and reused
# by every request handled by the endpoint below.
model = Load_EmbeddingModels()
|
|
|
@app.post("/embed-text-sparse/")
async def embed_text(payload: TextPayload):
    """Encode the request text into a sparse embedding.

    Args:
        payload: Request body whose ``text`` field is passed to the encoder.

    Returns:
        The value returned by ``model.get_single_sparse_text_embedding``.

    Raises:
        HTTPException: With status 500 when encoding fails. Previously the
            error was only printed and the handler fell through, so clients
            received a successful response with a ``null`` body.
    """
    try:
        return model.get_single_sparse_text_embedding(payload.text)
    except Exception as exc:
        # Boundary handler: record the full traceback server-side, but do not
        # leak internal error text to the client.
        logging.getLogger(__name__).exception("Sparse embedding failed")
        raise HTTPException(
            status_code=500,
            detail="Failed to compute sparse embedding",
        ) from exc