vectorAI / app.py
coder118
vector_meaningPool
fe60f1c
Raw
History Blame Contribute Delete
2.1 kB
import logging
from typing import List

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
app = FastAPI(title="GoToday Vector AI Service")

# 1. Load the tokenizer and model globally. This runs once, when the module is
#    imported at server start-up, so the download and in-memory load are not
#    repeated for each request.
MODEL_NAME = "snunlp/KR-SBERT-V40K-klueNLI-augSTS"
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModel.from_pretrained(MODEL_NAME),
)
# Request-body validation schema, declared with Pydantic.
class EmbeddingRequest(BaseModel):
    """Request body accepted by ``POST /embedding``.

    Attributes:
        texts: The texts to embed. Pydantic rejects bodies where this is not
            a list of strings; an empty list is rejected by the endpoint itself.
    """
    texts: List[str]
# Mean pooling: average the token embeddings, weighting each by its attention mask.
def mean_pooling(model_output, attention_mask):
    """Return the attention-mask-weighted mean of the token embeddings.

    Args:
        model_output: Any object indexable at ``[0]``. Element 0 is used as the
            token embeddings (``last_hidden_state`` for a Hugging Face model).
            It is expected to be (batch, seq_len, hidden), since the mask is
            broadcast to its shape and the sum runs over dim 1.
        attention_mask: (batch, seq_len) tensor of per-token weights. Hugging
            Face tokenizers emit 1 for real tokens and 0 for padding.

    Returns:
        A (batch, hidden) float tensor of mean embeddings. The token count is
        clamped to at least 1e-9, so a fully masked row gives a zero vector
        rather than NaN.
    """
    hidden_states = model_output[0]
    # Broadcast the mask across the hidden dimension so it can weight each vector.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * weights).sum(dim=1)
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / token_counts
@app.get("/")
def read_root():
    """Health check: report that the service is up and which model it serves."""
    return dict(status="healthy", model=MODEL_NAME)
@app.post("/embedding")
def get_embeddings(request: EmbeddingRequest):
    """Embed every text in the request into one sentence vector each.

    The texts are tokenized as a single padded batch. The model runs without
    gradient tracking, and the token embeddings are mean-pooled over the
    attention mask.

    Returns:
        ``{"embeddings": [[float, ...], ...]}``: one vector per input text, in
        input order, as plain Python floats.

    Raises:
        HTTPException: 400 if ``request.texts`` is empty. 500 if tokenization,
            inference or pooling fails; the original exception is logged with
            its traceback and chained as ``__cause__``.
    """
    if not request.texts:
        raise HTTPException(status_code=400, detail="텍스트 리스트가 비어있습니다.")
    try:
        # 2. Tokenize the whole batch at once. Padding aligns the lengths, and
        #    truncation caps each text at the tokenizer's maximum length.
        #    NOTE(review): no explicit max_length is passed, so this relies on
        #    tokenizer.model_max_length being set for this checkpoint — confirm.
        encoded_input = tokenizer(
            request.texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        # 3. Inference only: disable autograd so no gradient state is built,
        #    which keeps CPU inference light.
        with torch.no_grad():
            model_output = model(**encoded_input)
        # 4. Mean-pool the token embeddings, ignoring padding via attention_mask.
        embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        # 5. Convert to nested Python float lists so the JSON response is easy
        #    for the Spring Boot client to consume.
        return {"embeddings": embeddings.tolist()}
    except Exception as e:
        # Without this log the failure would only reach the client as a 500
        # body, and the server would keep no traceback to debug from.
        logging.getLogger(__name__).exception("Embedding inference failed")
        # NOTE(review): str(e) exposes internal error text to clients. It is
        # kept for compatibility with existing callers; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e