# hallo/scripts/sagemaker.py
import argparse
import json
import os
import time

import boto3

# inference.py must live alongside this script (or be on PYTHONPATH).
from inference import inference_process

def download_from_s3(s3_path, local_path):
    """Download s3://bucket/key to a local file."""
    s3 = boto3.client('s3')
    bucket, key = s3_path.replace("s3://", "").split("/", 1)
    s3.download_file(bucket, key, local_path)


def upload_to_s3(local_path, s3_path):
    """Upload a local file to s3://bucket/key."""
    s3 = boto3.client('s3')
    bucket, key = s3_path.replace("s3://", "").split("/", 1)
    s3.upload_file(local_path, bucket, key)

def model_fn(model_dir):
    """SageMaker model-loading hook.

    The pipeline loads its own weights inside inference_process(), so this
    hook only needs to pass the model directory through.
    """
    return model_dir

def input_fn(request_body, content_type='application/json'):
    """Deserialize the request and stage any S3 inputs locally."""
    if content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {content_type}")

    input_data = json.loads(request_body)

    source_image_path = input_data['source_image']
    driving_audio_path = input_data['driving_audio']
    local_source_image = "/opt/ml/input/data/source_image.jpg"
    local_driving_audio = "/opt/ml/input/data/driving_audio.wav"
    os.makedirs("/opt/ml/input/data", exist_ok=True)

    # Download the inputs when S3 URIs are given; plain paths are assumed
    # to already exist inside the container.
    if source_image_path.startswith("s3://"):
        download_from_s3(source_image_path, local_source_image)
        input_data['source_image'] = local_source_image
    if driving_audio_path.startswith("s3://"):
        download_from_s3(driving_audio_path, local_driving_audio)
        input_data['driving_audio'] = local_driving_audio

    # Merge the resolved media paths into the config before building the
    # namespace; otherwise the downloaded files never reach inference_process().
    config = dict(input_data.get('config', {}))
    config['source_image'] = input_data['source_image']
    config['driving_audio'] = input_data['driving_audio']
    args = argparse.Namespace(**config)

    s3_output = input_data.get('output', None)
    return args, s3_output
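

# For reference, a request body handled by input_fn might look like the
# JSON below. The bucket names, file names, and config keys are
# illustrative placeholders, not values taken from the Hallo repository:
#
# {
#   "source_image": "s3://my-bucket/inputs/face.jpg",
#   "driving_audio": "s3://my-bucket/inputs/speech.wav",
#   "config": {"config": "configs/inference/default.yaml"},
#   "output": "s3://my-bucket/outputs/result.mp4"
# }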

def predict_fn(input_data, model):
    """Run the inference pipeline on the prepared arguments."""
    args, s3_output = input_data
    inference_process(args)
    # The hardcoded path assumes inference_process() writes its result to
    # .cache/output.mp4; output_fn picks it up from there.
    return '.cache/output.mp4', s3_output

def output_fn(prediction, content_type='application/json'):
    """Wait for the rendered video, upload it to S3, and return a JSON status."""
    local_output, s3_output = prediction
    # Poll until inference_process() has written the output file, giving up
    # after a bounded wait instead of spinning forever.
    deadline = time.time() + 600  # 10-minute cap; adjust for longer renders
    while not os.path.exists(local_output):
        if time.time() > deadline:
            raise TimeoutError(f"Output file {local_output} was not created in time")
        time.sleep(1)
    # Upload the rendered video when an S3 destination was supplied.
    if s3_output:
        upload_to_s3(local_output, s3_output)
    return json.dumps({'status': 'completed', 's3_output': s3_output})
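

# Minimal local smoke test of the full handler chain. This is a sketch:
# it assumes the placeholder media files below exist and that the model
# weights and config expected by inference_process() are available locally.
if __name__ == "__main__":
    sample_request = json.dumps({
        "source_image": "/tmp/source_image.jpg",    # placeholder input image
        "driving_audio": "/tmp/driving_audio.wav",  # placeholder input audio
        "config": {},                               # extra inference args, if any
        "output": None,                             # no S3 upload for a local run
    })
    model = model_fn("/opt/ml/model")
    prediction = predict_fn(input_fn(sample_request), model)
    print(output_fn(prediction))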