| """ |
| SageMaker Multi-Model Endpoint inference script for GLiNER2. |
| |
| This script handles model loading and inference for the GLiNER2 Multi-Model Endpoint. |
| Models are loaded dynamically based on the TargetModel header in the request. |
| |
| Key differences from single-model inference: |
| - model_fn() receives the full path to the model directory (including model name) |
| - Models are cached automatically by SageMaker MME |
| - Multiple models can be loaded in memory simultaneously |
| - LRU eviction when memory is full |
| """ |
|
|
| import json |
| import os |
| import sys |
| import subprocess |
|
|
|
|
def _ensure_gliner2_installed():
    """
    Ensure the ``gliner2`` package is importable, installing it if missing.

    This is a workaround for SageMaker MME where requirements.txt
    might not be installed automatically.

    Returns:
        bool: True if gliner2 is available (already present or installed
        successfully), False if the pip install failed.
    """
    try:
        import gliner2

        # Fix: not every build exposes __version__; without the getattr
        # fallback an AttributeError here would escape the ImportError
        # handler and crash the module-level call to this function.
        version = getattr(gliner2, "__version__", "unknown")
        print(f"[MME] gliner2 version {version} already installed")
        return True
    except ImportError:
        print("[MME] gliner2 not found, installing...")
        try:
            # Pin gliner2 and constrain transformers to a compatible range;
            # --no-cache-dir keeps the container image's disk usage down.
            subprocess.check_call(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "--quiet",
                    "--no-cache-dir",
                    "gliner2==1.0.1",
                    "transformers>=4.30.0,<4.46.0",
                ]
            )
            print("[MME] ✓ gliner2 installed successfully")
            return True
        except subprocess.CalledProcessError as e:
            # Best-effort: report and let the caller decide how to fail.
            print(f"[MME] ERROR: Failed to install gliner2: {e}")
            return False
|
|
|
|
| |
# Run the dependency check at import time so inference code below can rely
# on gliner2 being present. NOTE(review): a failure here is only logged;
# model_fn() retries the install before giving up.
_ensure_gliner2_installed()


import torch


# Make the parent directory importable — presumably so sibling modules
# packaged alongside this script in the model archive can be imported;
# TODO confirm against the actual archive layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
class DummyModel:
    """Placeholder returned for the MME container's own model slot.

    Every inference entry point fails loudly, directing the caller to
    address a real model via the TargetModel header instead.
    """

    # Single source of truth for the refusal message.
    _ERROR_MESSAGE = "Container model invoked directly. Use TargetModel header."

    def _refuse(self):
        # Shared failure path for all entry points below.
        raise ValueError(self._ERROR_MESSAGE)

    def __call__(self, *args, **kwargs):
        self._refuse()

    def extract_entities(self, *args, **kwargs):
        self._refuse()

    def classify_text(self, *args, **kwargs):
        self._refuse()

    def extract_json(self, *args, **kwargs):
        self._refuse()
|
|
|
|
def model_fn(model_dir):
    """
    Load the GLiNER2 model from the model directory.

    For Multi-Model Endpoints, SageMaker passes the full path to the specific
    model being loaded, e.g., /opt/ml/models/<model_name>/

    Args:
        model_dir: The directory where model artifacts are extracted.

    Returns:
        The loaded GLiNER2 model, or a DummyModel when the directory contains
        the ``mme_container.txt`` sentinel file.

    Raises:
        ImportError: If gliner2 cannot be imported and installing it fails.
    """
    print(f"[MME] Loading model from: {model_dir}")
    # Directory listing is diagnostics only — never fail the load over it.
    try:
        print(f"[MME] Contents: {os.listdir(model_dir)}")
    except Exception as e:
        print(f"[MME] Could not list directory contents: {e}")

    # Import gliner2 lazily; if the import-time install at module scope
    # failed or was skipped, retry the install once here. GLiNER2 is left
    # as None when everything fails, and checked after the container-model
    # shortcut below (a container slot never needs the real package).
    try:
        from gliner2 import GLiNER2
    except ImportError as e:
        print(f"[MME] ERROR: gliner2 import failed: {e}")
        print("[MME] Attempting to install gliner2...")
        if _ensure_gliner2_installed():
            from gliner2 import GLiNER2
        else:
            GLiNER2 = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[MME] Using device: {device}")

    # Log GPU facts for CloudWatch debugging of capacity/eviction issues.
    if torch.cuda.is_available():
        print(f"[MME] GPU: {torch.cuda.get_device_name(0)}")
        print(f"[MME] CUDA version: {torch.version.cuda}")
        mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"[MME] GPU memory: {mem_gb:.2f} GB")

    # Optional HuggingFace token for gated/private model downloads.
    hf_token = os.environ.get("HF_TOKEN")

    # Sentinel file marks the container's placeholder model slot — return a
    # DummyModel so direct invocations fail with a clear message.
    if os.path.exists(os.path.join(model_dir, "mme_container.txt")):
        print("[MME] Container model detected - returning dummy model")
        return DummyModel()

    if GLiNER2 is None:
        raise ImportError("gliner2 package required but not found")

    # Three loading strategies, checked in order:
    #   1. config.json present  -> weights were packed into the archive.
    #   2. download_at_runtime.txt sentinel -> fetch from HuggingFace Hub.
    #   3. anything else -> treat as empty and fetch from the Hub as well.
    # NOTE(review): the sentinel-file conventions appear to be set by the
    # packaging/deployment side — confirm against the archive builder.
    if os.path.exists(os.path.join(model_dir, "config.json")):
        print("[MME] Loading model from extracted artifacts...")
        model = GLiNER2.from_pretrained(model_dir, token=hf_token)
    elif os.path.exists(os.path.join(model_dir, "download_at_runtime.txt")):
        print("[MME] Model not in archive, downloading from HuggingFace...")
        model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        print(f"[MME] Downloading model: {model_name}")
        model = GLiNER2.from_pretrained(model_name, token=hf_token)
    else:
        model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        print(f"[MME] Model directory empty, downloading: {model_name}")
        model = GLiNER2.from_pretrained(model_name, token=hf_token)

    print(f"[MME] Moving model to {device}...")
    model = model.to(device)

    # fp16 on GPU halves memory per model, letting MME keep more models
    # resident before LRU eviction kicks in.
    if torch.cuda.is_available():
        print("[MME] Converting to fp16...")
        model = model.half()

    if torch.cuda.is_available():
        # TF32 matmuls trade a little precision for throughput on Ampere+.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.cuda.empty_cache()
        # Cap this process at 85% of GPU memory — leaves headroom for
        # allocator fragmentation and other resident MME models.
        torch.cuda.set_per_process_memory_fraction(0.85)
        print("[MME] GPU memory optimizations enabled")

    print(f"[MME] ✓ Model loaded successfully on {device}")
    return model
|
|
|
|
def input_fn(request_body, request_content_type):
    """Deserialize the raw request body into a Python dictionary.

    Args:
        request_body: Raw payload of the invocation request.
        request_content_type: MIME type declared by the client.

    Returns:
        The JSON-decoded payload.

    Raises:
        ValueError: If the content type is not application/json.
    """
    # Guard clause: only JSON is accepted.
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)
|
|
|
|
def predict_fn(input_data, model):
    """
    Run prediction on the input data using the loaded model.

    Args:
        input_data: Dictionary containing:
            - task: One of 'extract_entities', 'classify_text', or
              'extract_json' (default: 'extract_entities')
            - text: Text to process (string) or list of texts (batch)
            - schema: Schema for extraction (format depends on task)
            - threshold: Optional confidence threshold (default: 0.5)
        model: The loaded GLiNER2 model.

    Returns:
        Task-specific results (single result, or a list of results for
        batch input).

    Raises:
        ValueError: If 'text' or 'schema' is missing, 'text' is an empty
            list, or 'task' is not one of the supported values.
    """
    # Release cached GPU blocks up front so memory freed by MME model
    # eviction is actually reusable for this request.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    text = input_data.get("text")
    task = input_data.get("task", "extract_entities")
    schema = input_data.get("schema")
    threshold = input_data.get("threshold", 0.5)

    # Fix: check the empty-list case before the generic falsy check.
    # Previously `if not text` caught [] first and misreported an empty
    # batch as a missing 'text' field, leaving the dedicated check dead.
    if isinstance(text, list) and not text:
        raise ValueError("'text' list cannot be empty")
    if not text:
        raise ValueError("'text' field is required")
    if not schema:
        raise ValueError("'schema' field is required")

    is_batch = isinstance(text, list)

    # inference_mode disables autograd bookkeeping for a cheaper forward.
    with torch.inference_mode():
        if task == "extract_entities":
            if is_batch:
                # Prefer a native batch API when the installed gliner2
                # version provides one; fall back to a per-text loop.
                if hasattr(model, "batch_extract_entities"):
                    return model.batch_extract_entities(
                        text, schema, threshold=threshold
                    )
                if hasattr(model, "batch_predict_entities"):
                    return model.batch_predict_entities(
                        text, schema, threshold=threshold
                    )
                return [
                    model.extract_entities(t, schema, threshold=threshold)
                    for t in text
                ]
            return model.extract_entities(text, schema, threshold=threshold)

        if task == "classify_text":
            if is_batch:
                if hasattr(model, "batch_classify_text"):
                    return model.batch_classify_text(
                        text, schema, threshold=threshold
                    )
                return [
                    model.classify_text(t, schema, threshold=threshold)
                    for t in text
                ]
            return model.classify_text(text, schema, threshold=threshold)

        if task == "extract_json":
            if is_batch:
                if hasattr(model, "batch_extract_json"):
                    return model.batch_extract_json(
                        text, schema, threshold=threshold
                    )
                return [
                    model.extract_json(t, schema, threshold=threshold)
                    for t in text
                ]
            return model.extract_json(text, schema, threshold=threshold)

        raise ValueError(
            f"Unsupported task: {task}. "
            "Must be one of: extract_entities, classify_text, extract_json"
        )
|
|
|
|
def output_fn(prediction, response_content_type):
    """Serialize the prediction result for the HTTP response.

    Args:
        prediction: JSON-serializable prediction result.
        response_content_type: MIME type requested by the client.

    Returns:
        The JSON-encoded prediction string.

    Raises:
        ValueError: If the requested content type is not application/json.
    """
    # Guard clause: only JSON responses are supported.
    if response_content_type != "application/json":
        raise ValueError(f"Unsupported response content type: {response_content_type}")
    return json.dumps(prediction)
|
|