from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
import sys  # needed for sys.exit() below (fixes a NameError)
# Use transformers.AutoTokenizer instead of KoBERTTokenizer
from transformers import BertModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import logging
from huggingface_hub import hf_hub_download  # used to fetch the model weights
import collections  # used for OrderedDict when remapping state_dict keys

# --- 1. BERTClassifier model class (moved here from model.py) ---
class BERTClassifier(torch.nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=5,  # number of classes; must match the size of the category dict
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        _, pooler = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids.long(),
            attention_mask=attention_mask.float().to(token_ids.device),
            return_dict=False,
        )
        if self.dr_rate:
            out = self.dropout(pooler)
        else:
            out = pooler
        return self.classifier(out)
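
# Tensor shape sketch for BERTClassifier.forward (B = batch size, L = max_len):
#   token_ids:    (B, L)  int64 token ids
#   valid_length: (B,)    count of non-padding tokens per row
#   segment_ids:  (B, L)  token type ids
#   returns:      (B, num_classes) raw logits (no softmax applied)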

# --- 2. BERTDataset class (moved here from dataset.py) ---
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        # nlp.data.BERTSentenceTransform expects a tokenizer *function*,
        # so we pass the AutoTokenizer's tokenize method directly.
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer.tokenize, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i],))

    def __len__(self):
        return len(self.labels)
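
# Each item is a 4-tuple (token_ids: np.ndarray of shape (max_len,),
# valid_length: int32, segment_ids: np.ndarray of shape (max_len,),
# label: np.int32), which is the layout the DataLoader below unpacks per batch.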

# --- 3. FastAPI app and global setup ---
app = FastAPI()
device = torch.device("cpu")  # Render's free tier is CPU-only

# Load the category dict (must sit at the repository root)
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("Loaded category.pkl.")
except FileNotFoundError:
    print("Error: category.pkl not found. Make sure it is in the project root.")
    sys.exit(1)  # do not start the service without it

# Load the vocab (must sit at the repository root)
try:
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    print("Loaded vocab.pkl.")
except FileNotFoundError:
    print("Error: vocab.pkl not found. Make sure it is in the project root.")
    sys.exit(1)  # do not start the service without it

# Load the tokenizer via transformers.AutoTokenizer instead of KoBERTTokenizer;
# this avoids the XLNetTokenizer warning and kobert_tokenizer install issues.
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
print("Loaded tokenizer.")

# Build the model; num_classes must match the size of the category dict.
bertmodel = BertModel.from_pretrained('skt/kobert-base-v1')
model = BERTClassifier(
    bertmodel,
    dr_rate=0.5,  # set this to the dr_rate used during training
    num_classes=len(category)
)

# Download and load textClassifierModel.pt
try:
    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"  # your actual Hugging Face repo ID
    HF_MODEL_FILENAME = "textClassifierModel.pt"
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"Model file downloaded to '{model_path}'.")

    loaded_state_dict = torch.load(model_path, map_location=device)
    # Strip the 'module.' prefix that DataParallel adds to state_dict keys
    new_state_dict = collections.OrderedDict()
    for k, v in loaded_state_dict.items():
        name = k
        if name.startswith('module.'):
            name = name[7:]
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    model.to(device)  # move the model to the target device
    model.eval()      # inference mode
    print("Model loaded.")
except Exception as e:
    print(f"Error while downloading or loading the model: {e}")
    sys.exit(1)  # do not start the service if the model fails to load
# โœ… ๋ฐ์ดํ„ฐ์…‹ ์ƒ์„ฑ์— ํ•„์š”ํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ
max_len = 64
batch_size = 32
# โœ… ์˜ˆ์ธก ํ•จ์ˆ˜
def predict(predict_sentence):
data = [predict_sentence, '0']
dataset_another = [data]
# num_workers๋Š” ๋ฐฐํฌ ํ™˜๊ฒฝ์—์„œ 0์œผ๋กœ ์„ค์ • ๊ถŒ์žฅ
another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False) # tokenizer ๊ฐ์ฒด ์ง์ ‘ ์ „๋‹ฌ
test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)
model.eval() # ์˜ˆ์ธก ์‹œ ๋ชจ๋ธ์„ ํ‰๊ฐ€ ๋ชจ๋“œ๋กœ ์„ค์ •
with torch.no_grad(): # ๊ทธ๋ผ๋””์–ธํŠธ ๊ณ„์‚ฐ ๋น„ํ™œ์„ฑํ™”
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
token_ids = token_ids.long().to(device)
segment_ids = segment_ids.long().to(device)
out = model(token_ids, valid_length, segment_ids)
logits = out
logits = logits.detach().cpu().numpy()
predicted_category_index = np.argmax(logits)
predicted_category_name = list(category.keys())[predicted_category_index]
return predicted_category_name
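
# Usage sketch (hypothetical input): predict("<some Korean sentence>") returns
# one key of the `category` dict as the predicted label name.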

# Endpoint definitions
class InputText(BaseModel):
    text: str

@app.get("/")
def root():
    return {"message": "Text Classification API (KoBERT)"}

@app.post("/predict")
async def predict_route(item: InputText):
    result = predict(item.text)
    return {"text": item.text, "classification": result}