| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer, AutoModel |
| import torch |
| import torch.nn.functional as F |
| from typing import List |
|
|
# ASGI application object; the route decorators further down register on it.
app = FastAPI(
    title="GoToday Vector AI Service",
)
|
|
| |
# Checkpoint identifier passed to both `from_pretrained` calls below and
# reported by the health-check route.
MODEL_NAME = "snunlp/KR-SBERT-V40K-klueNLI-augSTS"

# Loaded once when the module is imported and reused by every request
# handled by `get_embeddings`.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
|
|
| |
# Request body accepted by POST /embedding.
class EmbeddingRequest(BaseModel):
    # Strings to embed; the handler passes the whole list to the tokenizer
    # in a single call.
    texts: List[str]
|
|
| |
def mean_pooling(model_output, attention_mask):
    """Average token vectors along dimension 1, counting only unmasked positions.

    Args:
        model_output: Indexable model result; element 0 is used as the
            per-token embeddings (presumably shaped (batch, seq_len, hidden)
            -- confirm against the model in use).
        attention_mask: Mask tensor that is given a trailing axis and expanded
            to the size of the token embeddings; positions holding 0
            contribute nothing to the average.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total, with
        the divisor clamped to at least 1e-9 so an all-zero mask yields zeros
        instead of a division by zero.
    """
    hidden_states = model_output[0]

    # Broadcast the mask over the last axis so it lines up element-wise with
    # the token embeddings.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    weighted_total = (hidden_states * weights).sum(1)
    weight_total = weights.sum(1).clamp(min=1e-9)
    return weighted_total / weight_total
|
|
# Health-check route: reports a fixed "healthy" status together with the
# configured model name.
@app.get("/")
def read_root():
    payload = {
        "status": "healthy",
        "model": MODEL_NAME,
    }
    return payload
|
|
@app.post("/embedding")
def get_embeddings(request: EmbeddingRequest):
    """Return one mean-pooled embedding vector per input text.

    The texts are tokenized as a single padded, truncated batch, run through
    the model with gradient tracking disabled, and mean-pooled over the
    attention mask.

    Args:
        request: Body carrying the list of texts to embed.

    Returns:
        A dict of the form ``{"embeddings": [[float, ...], ...]}``.

    Raises:
        HTTPException: 400 when ``request.texts`` is empty; 500, carrying the
            underlying error message, when tokenization, inference or pooling
            fails.
    """
    if not request.texts:
        # Message reads "The text list is empty." The literal had been
        # corrupted into mojibake (UTF-8 Korean decoded with a Thai codepage)
        # and split across two lines; it is restored to the intended text.
        raise HTTPException(status_code=400, detail="텍스트 리스트가 비어있습니다.")

    try:
        encoded_input = tokenizer(
            request.texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            model_output = model(**encoded_input)

        embeddings = mean_pooling(model_output, encoded_input["attention_mask"])

        # Convert the tensor to nested Python lists so it can be serialized.
        return {"embeddings": embeddings.tolist()}
    except Exception as e:
        # Endpoint boundary: surface any failure as a 500 while keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|