|
print("importing runpod") |
|
import runpod |
|
print("importing requests") |
|
import requests |
|
print("importing generate_wav") |
|
from voice_generation import generate_wav |
|
print("importing boto3") |
|
import boto3 |
|
print("importing os") |
|
import os |
|
print("importing uuid") |
|
import uuid |
|
print("importing pydub") |
|
from pydub import AudioSegment |
|
import time |
|
import subprocess |
|
|
|
# Startup progress marker written to stdout.
print("setting up environment variables")

# AWS credentials are read from the environment; each is None when unset.
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')

# Pre-loaded voice models: model id -> path of its weights file.
models = {
    voice_id: f"weights/{voice_id}.pth"
    for voice_id in (
        'kanye',
        'rose-bp',
        'jungkook',
        'iu',
        'drake',
        'ariana-grande',
    )
}

print('run handler')
|
|
|
|
|
def split_audio():
    """Run the `/deezer-split.py` script with the spleeter environment's Python.

    Blocks until the child process exits. The exit code is discarded and
    nothing is returned.
    """
    interpreter = "deezer-spleeter-env/bin/python"
    script = "/deezer-split.py"
    subprocess.call([interpreter, script])
|
|
|
|
|
def combine_audio(voice_path, instrumental_path):
    """Mix a vocal track over an instrumental track and write `combined.mp3`.

    The shorter track is padded with trailing silence so both have the same
    length, then the vocals are overlaid on the instrumental and the mix is
    exported as MP3 to `combined.mp3` in the current working directory.

    Args:
        voice_path: Path to the vocal audio file.
        instrumental_path: Path to the instrumental audio file.
    """
    # No explicit `format`: the handler passes a WAV vocal track, and forcing
    # "mp3" would decode it with the wrong demuxer. Leaving it unset lets each
    # file's actual container be detected.
    instrumental = AudioSegment.from_file(instrumental_path)
    vocals = AudioSegment.from_file(voice_path)

    # Pad whichever track is shorter; the padding is zero-length for the
    # longer one.
    target_length = max(len(instrumental), len(vocals))
    instrumental += AudioSegment.silent(duration=target_length - len(instrumental))
    vocals += AudioSegment.silent(duration=target_length - len(vocals))

    instrumental.overlay(vocals).export("combined.mp3", format="mp3")
|
|
|
|
|
def upload_file_to_s3(local_file_path, s3_file_path):
    """Upload a local file to the `voice-gen-audios` S3 bucket.

    Args:
        local_file_path: Path of the file to upload.
        s3_file_path: Object key to store the file under.

    Returns:
        `{"url": ...}` with the object's eu-north-1 URL on success, or
        `{"error": ...}` when boto3 raises `S3UploadFailedError`.
    """
    bucket_name = 'voice-gen-audios'
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )
    try:
        s3.upload_file(local_file_path, bucket_name, s3_file_path)
    except boto3.exceptions.S3UploadFailedError as e:
        # Log the cause: the returned message alone does not say why the
        # upload failed.
        print(f"s3 upload failed: {e}")
        return {"error": f"failed to upload file {local_file_path} to s3 as {s3_file_path}"}
    return {"url": f"https://{bucket_name}.s3.eu-north-1.amazonaws.com/{s3_file_path}"}
|
|
|
|
|
def clean_up_files(remove_voice_model=False):
    """Delete the temporary audio files a job leaves in the working directory.

    Args:
        remove_voice_model: When true, also delete `voice_model.pth`.

    Returns:
        `{"success": ...}` when every file was removed, otherwise
        `{"error": ...}` naming the file(s) that did not exist.
    """
    files = [
        "song.mp3",
        "accompaniment.mp3",
        "vocals.mp3",
        "output_vocal.wav",
        "combined.mp3",
    ]
    if remove_voice_model:
        files.append("voice_model.pth")

    # Attempt every file even when one is missing: returning at the first
    # missing file would leave all the remaining ones on disk.
    missing = []
    for path in files:
        try:
            os.remove(path)
        except FileNotFoundError:
            missing.append(path)

    if missing:
        return {"error": f"failed to remove file {', '.join(missing)}"}
    return {"success": "files removed successfully"}
|
|
|
|
|
def get_voice_model(event):
    """Resolve the voice model requested by a job to a local weights file.

    Reads `voice_model_id` and `voice_model_url` from `event["input"]`. A
    model id is looked up in the pre-loaded `models` table and takes
    precedence over a URL; otherwise the URL is downloaded to
    `voice_model.pth` in the current working directory.

    Args:
        event: Job event; `event["input"]` must be a dict.

    Returns:
        `{"model_path": ...}` on success, or `{"error": ...}` when neither
        field is given, the id is unknown, or the download fails.
    """
    job_input = event["input"]
    voice_model_id = job_input.get("voice_model_id", "")
    voice_model_url = job_input.get("voice_model_url", "")

    if not voice_model_url and not voice_model_id:
        return {"error": "voice_model_url or voice_model_id is required"}

    if voice_model_id:
        if voice_model_id not in models:
            return {"error": "model not found in pre-loaded models"}
        return {"model_path": models[voice_model_id]}

    print("downloading voice_model")
    try:
        # The timeout keeps an unresponsive server from hanging the worker.
        voice_model_response = requests.get(voice_model_url, timeout=60)
    except requests.RequestException as e:
        # Report network failures the same way as a bad status code instead
        # of letting the exception escape to the caller.
        return {"error": f"failed to download voice_model, error: {e}"}

    if voice_model_response.status_code != 200:
        return {"error": f"failed to download voice_model, error: {voice_model_response.text}"}

    with open("voice_model.pth", "wb") as f:
        f.write(voice_model_response.content)

    return {"model_path": "voice_model.pth"}
|
|
|
|
|
def handler(event):
    """Handle one job: re-voice a song with the requested voice model.

    Reads from `event["input"]`:
        user_id: Echoed back in the result (default "not provided").
        voice_model_id / voice_model_url: Resolved by `get_voice_model`.
        song_url: URL of the song to convert (required).

    Returns:
        On success, a dict with `combined_url`, `output_voice_url`,
        `user_id`, `time_taken_splitting` and `time_taken_generation`.
        On any failure, `{"error": message}`.
    """
    print(event)
    file_id = str(uuid.uuid4())
    job_input = event["input"]
    user_id = job_input.get("user_id", "not provided")

    if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
        return {"error": "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are missing from environment variables"}

    voice_model = get_voice_model(event)
    if "error" in voice_model:
        return {"error": voice_model["error"]}

    # Only a downloaded model is a temporary file; pre-loaded weights stay.
    need_to_remove_voice_model = voice_model.get("model_path") == "voice_model.pth"

    try:
        result = _convert_song(job_input.get("song_url", ""), voice_model, file_id, user_id)
    finally:
        # Clean up on every exit path so files left by a failed job cannot
        # satisfy the next job's os.path.exists checks.
        cleanup_result = clean_up_files(need_to_remove_voice_model)

    if "error" in result:
        # After a failed job some temp files never existed, so a cleanup
        # error here is expected and not reported.
        return result

    if cleanup_error := cleanup_result.get("error"):
        return {"error": cleanup_error}

    return result


def _download_song(song_url):
    """Download `song_url` to `song.mp3`; return an error dict on failure, else None."""
    try:
        # The timeout keeps an unresponsive server from hanging the worker.
        song_file = requests.get(song_url, timeout=60)
    except requests.RequestException as e:
        print(f"song download failed: {e}")
        return {"error": "failed to download song_file"}

    if song_file.status_code != 200:
        return {"error": "failed to download song_file"}

    with open("song.mp3", "wb") as f:
        f.write(song_file.content)
    return None


def _convert_song(song_url, voice_model, file_id, user_id):
    """Download, split, re-voice, remix and upload a song; return the result or an error dict."""
    if not song_url:
        return {"error": "song_url is required"}

    download_error = _download_song(song_url)
    if download_error:
        return download_error

    splitting_start = time.time()
    split_audio()
    time_taken_splitting = time.time() - splitting_start
    print(f"splitting took {time_taken_splitting} seconds")

    if not (os.path.exists("accompaniment.mp3") and os.path.exists("vocals.mp3")):
        return {"error": "failed to split song"}

    song_instruments = upload_file_to_s3("accompaniment.mp3", f"{file_id}-split-accompaniment.mp3")
    song_vocals = upload_file_to_s3("vocals.mp3", f"{file_id}-split-vocals.mp3")
    for upload in (song_instruments, song_vocals):
        if "error" in upload:
            return {"error": upload["error"]}

    generation_start = time.time()
    generation = generate_wav(
        audio_file='vocals.mp3',
        method='pm',
        index_rate=0.6,
        output_file='output_vocal.wav',
        model_path=voice_model.get("model_path"),
    )
    time_taken_generation = time.time() - generation_start
    print(f"generation took {time_taken_generation} seconds")

    # Assumes generate_wav returns a dict-like result that may carry an
    # "error" key -- TODO confirm against voice_generation.
    if "error" in generation:
        return {"error": generation.get("error")}

    combine_audio("output_vocal.wav", "accompaniment.mp3")
    if not os.path.exists("combined.mp3"):
        return {"error": "failed to combine audio"}

    combined = upload_file_to_s3("combined.mp3", f"{file_id}.mp3")
    # NOTE(review): this key ends in "-generated-voical.mp3" although the
    # uploaded file is a WAV; left unchanged because consumers may rely on
    # the existing key pattern.
    output_voice = upload_file_to_s3("output_vocal.wav", f"{file_id}-generated-voical.mp3")
    for upload in (combined, output_voice):
        if upload_error := upload.get("error"):
            return {"error": upload_error}

    return {
        "combined_url": combined.get("url"),
        "output_voice_url": output_voice.get("url"),
        "user_id": user_id,
        "time_taken_splitting": time_taken_splitting,
        "time_taken_generation": time_taken_generation,
    }
|
|
|
|
|
# Start the RunPod serverless worker with `handler` as the job handler.
serverless_config = {"handler": handler}
runpod.serverless.start(serverless_config)
|
|