|
import os |
|
import json |
|
import boto3 |
|
import torch |
|
import argparse |
|
import time |
|
from omegaconf import OmegaConf |
|
|
|
from inference import inference_process |
|
|
|
def download_from_s3(s3_path, local_path):
    """Download a single object from S3 to a local file.

    Args:
        s3_path: Full S3 URI of the form ``s3://bucket/key``.
        local_path: Destination path on the local filesystem.
    """
    s3 = boto3.client('s3')
    # "s3://bucket/key/with/slashes" -> ("bucket", "key/with/slashes")
    bucket, key = s3_path.replace("s3://", "").split("/", 1)
    # download_file does not create missing directories; the fixed
    # /opt/ml/input/data destination used by input_fn may not exist yet.
    parent = os.path.dirname(local_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    s3.download_file(bucket, key, local_path)
|
|
|
def upload_to_s3(local_path, s3_path):
    """Upload a local file to S3.

    Args:
        local_path: Path of the file on the local filesystem.
        s3_path: Destination S3 URI of the form ``s3://bucket/key``.
    """
    client = boto3.client('s3')
    without_scheme = s3_path.replace("s3://", "")
    bucket, key = without_scheme.split("/", 1)
    client.upload_file(local_path, bucket, key)
|
|
|
def model_fn(model_dir):
    """SageMaker model-loading hook.

    No model object is materialized here — the directory path is simply
    passed through to the later hooks. NOTE(review): the weights are
    presumably loaded downstream (e.g. inside ``inference_process``);
    confirm against inference.py.

    Args:
        model_dir: Directory where SageMaker extracted the model artifacts.

    Returns:
        The unchanged model directory path.
    """
    return model_dir
|
|
|
def input_fn(request_body, content_type='application/json'):
    """Deserialize an inference request.

    Expects a JSON body with keys:
        source_image: local path or ``s3://`` URI of the source image.
        driving_audio: local path or ``s3://`` URI of the driving audio.
        config (optional): dict of extra options for inference_process.
        output (optional): ``s3://`` URI where the result should be placed.

    S3 inputs are first downloaded to fixed local staging paths.

    Returns:
        Tuple ``(args, s3_output)``: an argparse.Namespace for
        ``inference_process`` and the optional S3 output URI.

    Raises:
        ValueError: if ``content_type`` is not ``application/json``.
        KeyError: if a required field is missing from the body.
    """
    # Guard clause first so the happy path is unindented.
    if content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {content_type}")

    input_data = json.loads(request_body)

    source_image_path = input_data['source_image']
    driving_audio_path = input_data['driving_audio']

    # Fixed staging locations for remote inputs.
    local_source_image = "/opt/ml/input/data/source_image.jpg"
    local_driving_audio = "/opt/ml/input/data/driving_audio.wav"

    if source_image_path.startswith("s3://"):
        download_from_s3(source_image_path, local_source_image)
        input_data['source_image'] = local_source_image
    if driving_audio_path.startswith("s3://"):
        download_from_s3(driving_audio_path, local_driving_audio)
        input_data['driving_audio'] = local_driving_audio

    args = argparse.Namespace(**input_data.get('config', {}))
    # Bug fix: the (possibly localized) media paths were written back into
    # input_data but never reached the Namespace handed to inference_process,
    # so the downloads were effectively discarded. Attach them explicitly.
    # NOTE(review): attribute names assumed to match inference.py's CLI
    # arguments (--source_image / --driving_audio) — confirm.
    args.source_image = input_data['source_image']
    args.driving_audio = input_data['driving_audio']

    s3_output = input_data.get('output', None)

    return args, s3_output
|
|
|
def predict_fn(input_data, model):
    """Run the generation pipeline on a deserialized request.

    Args:
        input_data: ``(args, s3_output)`` tuple produced by ``input_fn``.
        model: Value returned by ``model_fn`` (not referenced here).

    Returns:
        Tuple of (local path of the rendered video, optional S3 URI).
    """
    args, s3_destination = input_data

    inference_process(args)

    # NOTE(review): output location is hard-coded; assumes inference_process
    # writes its result to .cache/output.mp4 — confirm against inference.py.
    return '.cache/output.mp4', s3_destination
|
|
|
def output_fn(prediction, content_type='application/json'):
    """Serialize the inference result.

    Waits for the rendered video to appear on disk, uploads it to the
    requested S3 destination (if one was given), and returns a JSON
    status payload.

    Args:
        prediction: ``(local_output, s3_output)`` tuple from ``predict_fn``.
        content_type: Desired response content type (JSON is always produced).

    Returns:
        JSON string with the completion status and the S3 output URI.

    Raises:
        TimeoutError: if the output file does not appear within the deadline.
    """
    local_output, s3_output = prediction

    # Bug fix: the original loop spun forever if the render failed and the
    # file never appeared, hanging the endpoint. Bound the wait instead.
    deadline = time.time() + 600  # generous upper bound for rendering
    while not os.path.exists(local_output):
        if time.time() > deadline:
            raise TimeoutError(f"Output file never appeared: {local_output}")
        time.sleep(1)

    # Bug fix: upload_to_s3 was defined in this module but never invoked
    # anywhere, so the rendered video never reached the requested
    # destination reported back to the caller.
    if s3_output:
        upload_to_s3(local_output, s3_output)

    return json.dumps({'status': 'completed', 's3_output': s3_output})
|
|