# stp-classifier-4-4 / code /inference.py
# srpsrpsrp's picture
# Upload folder using huggingface_hub
# e96bc27 verified
import os
import json
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ €๋ฅผ ๋กœ๋“œํ•˜๋Š” ํ•จ์ˆ˜
def model_fn(model_dir):
    """Load everything needed for inference; SageMaker calls this once at startup.

    Args:
        model_dir (str): Directory containing ``config.json``, the model weights,
            the tokenizer files and, optionally, ``label_map.json``.

    Returns:
        dict: ``model`` (in eval mode, on ``device``), ``tokenizer``, ``config``,
        ``device`` (a ``torch.device``) and ``label_map`` (the parsed contents of
        ``label_map.json``, or ``{}`` when that file is absent).
    """
    # Disable tokenizer-internal parallelism for the serving process.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    config = AutoConfig.from_pretrained(os.path.join(model_dir, "config.json"))

    use_cuda = torch.cuda.is_available()
    device_name = "cuda" if use_cuda else "cpu"
    print(f"Loading model from {model_dir}")
    print(f"Device: {device_name}")

    # Optional index -> label-name mapping; predict_fn looks labels up by str(index).
    label_map = {}
    label_map_path = os.path.join(model_dir, "label_map.json")
    if not os.path.exists(label_map_path):
        print("No label map found. Using numeric indices as labels.")
    else:
        with open(label_map_path, 'r', encoding='utf-8') as handle:
            label_map = json.load(handle)
        print(f"Loaded label map from {label_map_path}")

    # Load weights in half precision on GPU, full precision on CPU.
    weight_dtype = torch.float16 if use_cuda else torch.float32
    device = torch.device(device_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        config=config,
        torch_dtype=weight_dtype,
    ).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return dict(
        model=model,
        tokenizer=tokenizer,
        config=config,
        device=device,
        label_map=label_map,
    )
# ์ž…๋ ฅ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
def input_fn(request_body, request_content_type):
    """Deserialize a request body; SageMaker calls this before ``predict_fn``.

    The media type is matched case-insensitively and its parameters are ignored
    for matching, so ``"application/json; charset=utf-8"`` is treated as JSON.

    Args:
        request_body (bytes | str): Raw request payload.
        request_content_type (str): The request's Content-Type value.

    Returns:
        For ``application/json``: ``{"text": value}`` when the JSON document is a
        string, otherwise the decoded JSON value unchanged (normally a dict).
        For ``text/plain``: ``{"text": body}``, decoded with the declared
        ``charset`` parameter (default UTF-8); a ``str`` body is used as-is.

    Raises:
        ValueError: If the media type is neither ``application/json`` nor
            ``text/plain`` (including a missing content type).
    """
    # Split "type/subtype; key=value; ..." into the bare media type and its parameters.
    media_type, *raw_params = (request_content_type or "").split(";")
    media_type = media_type.strip().lower()
    params = {}
    for raw_param in raw_params:
        key, sep, value = raw_param.partition("=")
        if sep:
            params[key.strip().lower()] = value.strip().strip('"')

    if media_type == "application/json":
        # json.loads accepts both str and bytes (UTF-8/16/32 are auto-detected).
        input_data = json.loads(request_body)
        # A bare JSON string is shorthand for {"text": ...}.
        if isinstance(input_data, str):
            return {"text": input_data}
        return input_data

    if media_type == "text/plain":
        if isinstance(request_body, str):
            return {"text": request_body}
        return {"text": request_body.decode(params.get("charset", "utf-8"))}

    # NOTE(review): this message looks mis-encoded (mojibake of a Korean
    # "unsupported content type" message); left byte-identical - verify the file encoding.
    raise ValueError(f"์ง€์›๋˜์ง€ ์•Š๋Š” ์ฝ˜ํ…์ธ  ํƒ€์ž…: {request_content_type}")
# Run prediction
def predict_fn(input_data, model_dict):
    """Classify the text(s) in ``input_data``; SageMaker calls this after ``input_fn``.

    Args:
        input_data (dict): Parsed request. Must contain ``"text"``: a string, or a
            list of strings for a batch. The optional keys ``"max_length"``
            (default 512), ``"padding"`` (default ``"max_length"``) and
            ``"truncation"`` (default True) are forwarded to the tokenizer.
        model_dict (dict): The dict built by ``model_fn`` (keys ``model``,
            ``tokenizer``, ``device``, ``label_map``).

    Returns:
        dict | list[dict]: One result dict for a single string. For a list of
        strings, a list with one result dict per input, in input order (``[]``
        for an empty list). See ``_format_prediction`` for the dict layout.

    Raises:
        ValueError: If ``input_data`` is not a dict or lacks a ``"text"`` key.
    """
    model = model_dict["model"]
    tokenizer = model_dict["tokenizer"]
    device = model_dict["device"]
    label_map = model_dict["label_map"]

    if not isinstance(input_data, dict) or "text" not in input_data:
        # NOTE(review): message text looks mis-encoded (mojibake of a Korean
        # "input data has no 'text' field" message); left byte-identical.
        raise ValueError("์ž…๋ ฅ ๋ฐ์ดํ„ฐ์— 'text' ํ•„๋“œ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค")
    text = input_data["text"]

    # A list/tuple is tokenized as a batch; report every row, not only the first.
    is_batch = isinstance(text, (list, tuple))
    if is_batch and not text:
        return []

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=input_data.get("padding", "max_length"),
        truncation=input_data.get("truncation", True),
        max_length=input_data.get("max_length", 512),
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # model_fn loads float16 weights on GPU; softmax in float32 so reported
    # probabilities are not rounded to half precision (e.g. 0.99999 -> 1.0).
    probabilities = torch.softmax(logits.float(), dim=-1).cpu()

    results = [_format_prediction(row, label_map) for row in probabilities]
    return results if is_batch else results[0]


def _format_prediction(probs, label_map):
    """Build the response dict for one input from its 1-D tensor of class probabilities.

    Two classes: ``prediction`` (1 when P(class 1) > 0.5, else 0),
    ``positive_probability`` and ``negative_probability``.
    Otherwise: ``prediction`` (argmax index) and ``probabilities`` (list), plus
    ``label_probabilities`` (label name or str(index) -> probability) when
    ``label_map`` is non-empty.
    In both cases ``label`` is added when ``label_map`` has an entry for
    ``str(prediction)``.
    """
    if probs.shape[0] == 2:
        negative_prob = probs[0].item()
        positive_prob = probs[1].item()
        prediction = 1 if positive_prob > 0.5 else 0
        result = {
            "prediction": prediction,
            "positive_probability": positive_prob,
            "negative_probability": negative_prob,
        }
        if label_map and str(prediction) in label_map:
            result["label"] = label_map[str(prediction)]
        return result

    prediction = int(torch.argmax(probs).item())
    prob_list = probs.tolist()
    result = {
        "prediction": prediction,
        "probabilities": prob_list,
    }
    if label_map:
        if str(prediction) in label_map:
            result["label"] = label_map[str(prediction)]
        # Indices missing from the label map fall back to their string index.
        result["label_probabilities"] = {
            label_map.get(str(idx), str(idx)): prob
            for idx, prob in enumerate(prob_list)
        }
    return result
# ์ถœ๋ ฅ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
def output_fn(prediction, response_content_type):
    """Serialize a prediction for the response; SageMaker calls this after ``predict_fn``.

    Args:
        prediction: JSON-serializable result returned by ``predict_fn``.
        response_content_type (str): Desired response media type. It is matched
            case-insensitively and its parameters are ignored for matching, so
            ``"application/json; charset=utf-8"`` is accepted.

    Returns:
        str: JSON text. ``ensure_ascii=False`` keeps non-ASCII label names readable
        instead of escaping them as ``\\uXXXX``.

    Raises:
        ValueError: If the media type is not ``application/json`` (including a
            missing content type).
    """
    media_type = (response_content_type or "").split(";", 1)[0].strip().lower()
    if media_type == "application/json":
        return json.dumps(prediction, ensure_ascii=False)
    # NOTE(review): this message looks mis-encoded (mojibake of a Korean
    # "unsupported content type" message); left byte-identical - verify the file encoding.
    raise ValueError(f"์ง€์›๋˜์ง€ ์•Š๋Š” ์ฝ˜ํ…์ธ  ํƒ€์ž…: {response_content_type}")