import os import json import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig # 모델과 토크나이저를 로드하는 함수 def model_fn(model_dir): """ SageMaker가 모델을 로드하기 위해 호출하는 함수 Args: model_dir (str): 모델 파일이 저장된 디렉토리 경로 Returns: dict: 모델, 토크나이저, 설정 등을 포함한 딕셔너리 """ # 환경 변수 설정 (선택 사항) os.environ["TOKENIZERS_PARALLELISM"] = "false" # 설정 파일 로드 config_path = os.path.join(model_dir, "config.json") config = AutoConfig.from_pretrained(config_path) print(f"Loading model from {model_dir}") print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}") # 레이블 매핑 로드 (있는 경우) label_map = {} label_map_path = os.path.join(model_dir, "label_map.json") if os.path.exists(label_map_path): with open(label_map_path, 'r', encoding='utf-8') as f: label_map = json.load(f) print(f"Loaded label map from {label_map_path}") else: print("No label map found. Using numeric indices as labels.") # 모델 로드 model = AutoModelForSequenceClassification.from_pretrained( model_dir, config=config, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 ) # GPU 사용 가능한 경우 모델을 GPU로 이동 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) model.eval() # 토크나이저 로드 tokenizer = AutoTokenizer.from_pretrained(model_dir) return { "model": model, "tokenizer": tokenizer, "config": config, "device": device, "label_map": label_map } # 입력 데이터 처리 함수 def input_fn(request_body, request_content_type): """ SageMaker가 요청 데이터를 처리하기 위해 호출하는 함수 Args: request_body: 요청 본문 데이터 request_content_type (str): 요청 콘텐츠 타입 Returns: dict: 처리된 입력 데이터 """ if request_content_type == "application/json": input_data = json.loads(request_body) # 문자열인 경우 텍스트로 처리 if isinstance(input_data, str): return {"text": input_data} return input_data elif request_content_type == "text/plain": # 일반 텍스트 처리 return {"text": request_body.decode('utf-8')} else: raise ValueError(f"지원되지 않는 콘텐츠 타입: {request_content_type}") # 예측 함수 def predict_fn(input_data, model_dict): """ SageMaker가 모델 예측을 수행하기 위해 호출하는 함수 Args: input_data (dict): 처리된 입력 데이터 model_dict (dict): model_fn에서 반환한 모델 정보 Returns: dict: 예측 결과 """ model = model_dict["model"] tokenizer = model_dict["tokenizer"] device = model_dict["device"] label_map = model_dict["label_map"] # 입력 텍스트 가져오기 if "text" in input_data: text = input_data["text"] else: raise ValueError("입력 데이터에 'text' 필드가 없습니다") # 토큰화 옵션 max_length = input_data.get("max_length", 512) padding = input_data.get("padding", "max_length") truncation = input_data.get("truncation", True) # 토큰화 inputs = tokenizer( text, return_tensors="pt", padding=padding, truncation=truncation, max_length=max_length ) # 입력 텐서를 디바이스로 이동 inputs = {name: tensor.to(device) for name, tensor in inputs.items()} # 모델 추론 with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits probabilities = torch.softmax(logits, dim=1) # 이진 분류 모델인 경우 (클래스 수가 2인 경우) if logits.shape[1] == 2: positive_prob = probabilities[0, 1].item() negative_prob = probabilities[0, 0].item() prediction = 1 if positive_prob > 0.5 else 0 result = { "prediction": prediction, "positive_probability": positive_prob, "negative_probability": negative_prob } # 레이블 매핑이 있는 경우 레이블 추가 if label_map: pred_label = str(prediction) if pred_label in label_map: result["label"] = label_map[pred_label] # 다중 클래스 모델인 경우 else: predictions = torch.argmax(probabilities, dim=1).cpu().numpy().tolist() probabilities = probabilities.cpu().numpy().tolist()[0] result = { "prediction": predictions[0], "probabilities": probabilities, } # 레이블 매핑이 있는 경우 레이블 추가 if label_map: pred_label = str(predictions[0]) if pred_label in label_map: result["label"] = label_map[pred_label] # 모든 레이블에 대한 확률 매핑 추가 result["label_probabilities"] = { label_map.get(str(idx), str(idx)): prob for idx, prob in enumerate(probabilities) } return result # 출력 데이터 처리 함수 def output_fn(prediction, response_content_type): """ SageMaker가 예측 결과를 응답 형식으로 변환하기 위해 호출하는 함수 Args: prediction: predict_fn에서 반환한 예측 결과 response_content_type (str): 원하는 응답 콘텐츠 타입 Returns: str: 직렬화된 예측 결과 """ if response_content_type == "application/json": return json.dumps(prediction, ensure_ascii=False) else: raise ValueError(f"지원되지 않는 콘텐츠 타입: {response_content_type}")