|
import os
|
|
import json
|
|
import torch
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
|
|
|
|
|
|
def model_fn(model_dir):
    """
    Load the model and its companion artifacts for inference.

    Called by SageMaker to load the model.

    Args:
        model_dir (str): Path of the directory where the model files are
            stored. ``config.json`` is read from it explicitly and
            ``label_map.json`` is read when present.

    Returns:
        dict: Dictionary holding the model, tokenizer, config, device and
            label map under the keys ``model``, ``tokenizer``, ``config``,
            ``device`` and ``label_map``.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    config = AutoConfig.from_pretrained(os.path.join(model_dir, "config.json"))

    # Decide once where the model will live; every later step reuses this.
    has_gpu = torch.cuda.is_available()
    device_name = "cuda" if has_gpu else "cpu"

    print(f"Loading model from {model_dir}")
    print(f"Device: {device_name}")

    # The label map is optional; without it predictions carry no label names.
    label_map_path = os.path.join(model_dir, "label_map.json")
    if not os.path.exists(label_map_path):
        label_map = {}
        print("No label map found. Using numeric indices as labels.")
    else:
        with open(label_map_path, 'r', encoding='utf-8') as handle:
            label_map = json.load(handle)
        print(f"Loaded label map from {label_map_path}")

    # Half precision on GPU, full precision on CPU.
    weights_dtype = torch.float16 if has_gpu else torch.float32
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        config=config,
        torch_dtype=weights_dtype,
    )

    device = torch.device(device_name)
    model = model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    bundle = {}
    bundle["model"] = model
    bundle["tokenizer"] = tokenizer
    bundle["config"] = config
    bundle["device"] = device
    bundle["label_map"] = label_map
    return bundle
|
|
|
|
|
|
def input_fn(request_body, request_content_type):
    """
    Deserialize the request body into the payload used for prediction.

    Called by SageMaker to process the request data.

    Args:
        request_body (bytes | str): Raw request body.
        request_content_type (str): Content type of the request. Only
            ``application/json`` and ``text/plain`` are supported.

    Returns:
        dict: The processed input data. A ``text/plain`` body and a bare
            JSON string are both wrapped as ``{"text": <string>}``. Any
            other JSON document is returned exactly as decoded, so a
            non-object JSON value (e.g. an array) is passed through
            unchanged.

    Raises:
        ValueError: If ``request_content_type`` is not supported.
    """
    if request_content_type == "application/json":
        payload = json.loads(request_body)

        # A bare JSON string is shorthand for {"text": "<string>"}.
        if isinstance(payload, str):
            return {"text": payload}

        return payload

    if request_content_type == "text/plain":
        # Accept already-decoded text as well as raw bytes so that a str
        # body does not fail on ``.decode``.
        if isinstance(request_body, (bytes, bytearray)):
            request_body = request_body.decode("utf-8")

        return {"text": request_body}

    raise ValueError(f"지원되지 않는 콘텐츠 타입: {request_content_type}")
|
|
|
|
|
|
def predict_fn(input_data, model_dict):
    """
    Run the classifier on one request and build the response payload.

    Called by SageMaker to perform the model prediction.

    Args:
        input_data (dict): The processed input data. Must contain ``text``;
            may contain ``max_length`` (default 512), ``padding`` (default
            ``"max_length"``) and ``truncation`` (default ``True``), which
            are forwarded to the tokenizer.
        model_dict (dict): Model information returned by ``model_fn``; the
            keys ``model``, ``tokenizer``, ``device`` and ``label_map`` are
            read.

    Returns:
        dict: The prediction result, always describing the first row of the
            model output.

            * Two-logit models: ``prediction`` (0 or 1),
              ``positive_probability`` and ``negative_probability``.
            * Other models: ``prediction`` (arg-max index) and
              ``probabilities`` (list of floats), plus
              ``label_probabilities`` when a label map is available.

            In both cases ``label`` is added when the label map has an
            entry for ``str(prediction)``.

    Raises:
        ValueError: If ``input_data`` is not a dict or has no ``text`` field.
    """
    # Reject non-dict payloads (e.g. a JSON array) with the same error as a
    # missing field instead of failing later on ``.get``.
    if not isinstance(input_data, dict) or "text" not in input_data:
        raise ValueError("입력 데이터에 'text' 필드가 없습니다")

    model = model_dict["model"]
    tokenizer = model_dict["tokenizer"]
    device = model_dict["device"]
    label_map = model_dict["label_map"]

    text = input_data["text"]

    # Tokenizer options may be overridden per request.
    max_length = input_data.get("max_length", 512)
    padding = input_data.get("padding", "max_length")
    truncation = input_data.get("truncation", True)

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=padding,
        truncation=truncation,
        max_length=max_length,
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits
        # Upcast before softmax so half-precision logits do not cost
        # precision in the reported probabilities.
        probabilities = torch.softmax(logits.float(), dim=1)

    if logits.shape[1] == 2:
        positive_prob = probabilities[0, 1].item()
        negative_prob = probabilities[0, 0].item()
        prediction = 1 if positive_prob > 0.5 else 0
        class_probabilities = None

        result = {
            "prediction": prediction,
            "positive_probability": positive_prob,
            "negative_probability": negative_prob,
        }
    else:
        prediction = torch.argmax(probabilities, dim=1)[0].item()
        class_probabilities = probabilities[0].tolist()

        result = {
            "prediction": prediction,
            "probabilities": class_probabilities,
        }

    if label_map:
        # Label map keys are the stringified class indices.
        pred_label = str(prediction)
        if pred_label in label_map:
            result["label"] = label_map[pred_label]

        # Only the multi-class response carries a per-label breakdown;
        # indices missing from the map fall back to their string form.
        if class_probabilities is not None:
            result["label_probabilities"] = {
                label_map.get(str(idx), str(idx)): prob
                for idx, prob in enumerate(class_probabilities)
            }

    return result
|
|
|
|
|
|
def output_fn(prediction, response_content_type):
    """
    Serialize the prediction result into the response format.

    Called by SageMaker to convert the prediction result into the response.

    Args:
        prediction: The prediction result returned by ``predict_fn``; must
            be JSON-serializable.
        response_content_type (str): The desired response content type.
            Only ``application/json`` is supported.

    Returns:
        str: The serialized prediction result. Non-ASCII characters are
            emitted as-is rather than ``\\u``-escaped.

    Raises:
        ValueError: If ``response_content_type`` is not supported.
    """
    if response_content_type != "application/json":
        raise ValueError(f"지원되지 않는 콘텐츠 타입: {response_content_type}")

    # ensure_ascii=False keeps non-ASCII label text readable in the response.
    return json.dumps(prediction, ensure_ascii=False)