hiddenFront's picture
Update app.py
e66afc2 verified
raw
history blame
2.35 kB
from fastapi import FastAPI, Request
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import psutil
import sys
# --- Application and inference device -------------------------------------
app = FastAPI()
device = torch.device("cpu")  # Space runs on CPU; keep all tensors there

# Model artifact location on the Hugging Face Hub.
HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"


def _rss_mb() -> float:
    """Return this process's resident set size in megabytes."""
    return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)


# --- Label map: category.pkl must sit next to this script ------------------
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("βœ… category.pkl λ‘œλ“œ 성곡.")
except FileNotFoundError:
    print("❌ Error: category.pkl νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€. ν”„λ‘œμ νŠΈ λ£¨νŠΈμ— μžˆλŠ”μ§€ ν™•μΈν•˜μ„Έμš”.")
    sys.exit(1)

# --- Tokenizer -------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("βœ… ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ 성곡.")

# --- Model download + load, with RSS reported at each stage ----------------
mem_before = _rss_mb()
print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ μ „ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_before:.2f} MB")

try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"βœ… λͺ¨λΈ 파일 λ‹€μš΄λ‘œλ“œ 성곡: {model_path}")

    mem_after_dl = _rss_mb()
    print(f"πŸ“¦ λͺ¨λΈ λ‹€μš΄λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_dl:.2f} MB")

    # NOTE(review): this unpickles a full model object from the Hub —
    # torch.load on a pickled checkpoint executes arbitrary code if the
    # file is tampered with; only load artifacts you trust.
    model = torch.load(model_path, map_location=device)
    model.eval()  # inference mode: disable dropout/batch-norm updates

    mem_after_load = _rss_mb()
    print(f"πŸ“¦ λͺ¨λΈ λ‘œλ“œ ν›„ λ©”λͺ¨λ¦¬ μ‚¬μš©λŸ‰: {mem_after_load:.2f} MB")
    print("βœ… λͺ¨λΈ λ‘œλ“œ 성곡")
except Exception as e:
    print(f"❌ Error: λͺ¨λΈ λ‹€μš΄λ‘œλ“œ λ˜λŠ” λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
    sys.exit(1)
# 예츑 API
@app.post("/predict")
async def predict_api(request: Request):
    """Classify the JSON body's "text" field with the loaded model.

    Expects ``{"text": "..."}``; responds with the input text and the
    predicted category label, or an error object when no text is given.
    """
    payload = await request.json()
    text = payload.get("text")

    # Guard clause: missing or empty text short-circuits with an error.
    if not text:
        return {"error": "No text provided", "classification": "null"}

    # Tokenize to a fixed 64-token window (padded/truncated as needed).
    encoded = tokenizer.encode_plus(
        text,
        max_length=64,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    # Inference only — no gradient bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits

    probabilities = torch.nn.functional.softmax(logits, dim=1)
    predicted_index = torch.argmax(probabilities, dim=1).item()

    # Labels come from the insertion order of the category mapping.
    label = list(category)[predicted_index]
    return {"text": text, "classification": label}