# audio_palette/utils/audio_palette.py
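# Glue layer for the Audio Palette demo: turns one or more input images into
# generated audio (and, for multiple images, a stitched slideshow video).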
import re
import typing
from datetime import datetime, timezone, timedelta

import PIL
from PIL import Image
from moviepy.editor import *
from gradio import Error

from lib import *

datetime_format = "%d/%m/%Y %H:%M:%S"
ist_offset = timedelta(hours=5, minutes=30)
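
# Timestamps for log lines, rendered in IST (UTC+05:30).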
def now():
    utc_time = datetime.now(timezone.utc)
    ist_time = utc_time.astimezone(timezone(ist_offset))
    return ist_time.strftime(datetime_format)
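
# Orchestrates the full pipeline: pace prediction, image captioning,
# sentiment analysis, prompt construction, and audio generation through a
# remote ngrok endpoint, plus video stitching for multi-image inputs.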
class AudioPalette:
    def __init__(self, pace_model_weights_path, resnet50_tf_model_weights_path, height, width, channels):
        self.pace_model = PaceModel(height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path)
        self.image_captioning = ImageCaptioning()
        self.audio_generation = AudioGeneration()
        self.sentiment_analyser = SentimentAnalyser()
        self.pace_map = {
            "Fast": "high",
            "Medium": "medium",
            "Slow": "low"
        }
        # Raw string, with the dot in "ngrok-free.app" escaped so it does not
        # match arbitrary characters.
        self.ngrok_url_pattern = re.compile(
            r"(https://[a-z0-9\-]+\.ngrok\.io/)|(https://[a-z0-9\-]+\.ngrok-free\.app/)"
        )
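
    # Build the text prompt for the audio generator from the caption, the
    # predicted pace, and an optional instrument. `first` selects the opening
    # wording; later clips get a variant that asks for a smooth transition.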
    def prompt_construction(self, caption: str, pace: str, sentiment: typing.Union[str, None], instrument: typing.Union[str, None], first: bool = True):
        instrument = instrument if instrument is not None else ""
        if first:
            prompt = f"A {instrument} soundtrack for {caption} with {self.pace_map[pace]} beats per minute. High Quality."
        else:
            prompt = f"A {instrument} soundtrack for {caption} with {self.pace_map[pace]} beats per minute. High Quality. Transitions smoothly from the previous audio while sounding different."

        # if sentiment:
        #     prompt += f" As a {sentiment} music."

        return prompt
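
    # Run the full pipeline for a single image and return
    # [prompt, pace, caption, audio_file_path].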
    def generate_single(self, input_image: PIL.Image.Image, instrument: typing.Union[str, None], ngrok_endpoint: typing.Union[str, None]):
        # Guard against None before running the regex search.
        if not ngrok_endpoint or not self.ngrok_url_pattern.search(ngrok_endpoint):
            print(f"[{now()}] Invalid ngrok endpoint - {ngrok_endpoint}")
            raise Error(f"Invalid ngrok endpoint - {ngrok_endpoint}")
        print(f"[{now()}] {ngrok_endpoint}")

        pace = self.pace_model.predict(input_image)
        print(f"[{now()}]", pace)
        print(f"[{now()}] Pace Prediction Done")

        try:
            generated_text = self.image_captioning.query(input_image)[0].get("generated_text")
        except Exception as e:
            print(f"[{now()}] image captioning error")
            raise Error(repr(e))
        print(f"[{now()}]", generated_text)
        print(f"[{now()}] Captioning Done")

        sentiment = self.sentiment_analyser.sentiment(generated_text)
        print(f"[{now()}] Sentiment Analysis Done")

        prompt = self.prompt_construction(generated_text, pace, sentiment, instrument)
        print(f"[{now()}] Generated Prompt:", prompt)

        try:
            audio_file = self.audio_generation.generate(prompt, ngrok_endpoint)
        except Exception as e:
            print(f"[{now()}] {e}")
            raise Error(repr(e))
        print(f"[{now()}]", audio_file)
        print(f"[{now()}] Audio Generation Done")

        outputs = [prompt, pace, generated_text, audio_file]
        return outputs
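
    # Render a slideshow video: each image is held for 5 seconds and the
    # audio clips are concatenated underneath.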
    def stitch_images(self, file_paths: typing.List[str], audio_paths: typing.List[str]):
        clips = [ImageClip(m).set_duration(5) for m in file_paths]
        audio_clips = [AudioFileClip(a) for a in audio_paths]

        concat_audio = concatenate_audioclips(audio_clips)
        new_audio = CompositeAudioClip([concat_audio])

        concat_clip = concatenate_videoclips(clips, method="compose")
        concat_clip.audio = new_audio

        file_name = "generated_video.mp4"
        concat_clip.write_videofile(file_name, fps=24)
        return file_name
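
    # Run the pipeline over several images, generate audio from the combined
    # prompts, and stitch everything into a single video file.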
    def generate_multiple(self, file_paths: typing.List[str], instrument: typing.Union[str, None], ngrok_endpoint: typing.Union[str, None]):
        # Guard against None before running the regex search.
        if not ngrok_endpoint or not self.ngrok_url_pattern.search(ngrok_endpoint):
            print(f"[{now()}] Invalid ngrok endpoint - {ngrok_endpoint}")
            raise Error(f"Invalid ngrok endpoint - {ngrok_endpoint}")
        print(f"[{now()}] {ngrok_endpoint}")

        images = [Image.open(image_path) for image_path in file_paths]

        pace = []
        generated_text = []
        sentiments = []
        prompts = []

        # Extracting the pace for all the images
        for image in images:
            pace_prediction = self.pace_model.predict(image)
            pace.append(pace_prediction)
        print(f"[{now()}]", pace)
        print(f"[{now()}] Pace Prediction Done")

        # Generating the caption for all the images
        try:
            for image in images:
                caption = self.image_captioning.query(image)[0].get("generated_text")
                generated_text.append(caption)
        except Exception as e:
            print(f"[{now()}] image captioning error")
            raise Error(repr(e))
        print(f"[{now()}]", generated_text)
        print(f"[{now()}] Captioning Done")

        # Extracting the sentiments from the generated captions
        for text in generated_text:
            sentiment = self.sentiment_analyser.sentiment(text)
            sentiments.append(sentiment)
        print(f"[{now()}] Sentiment Analysis Done:", sentiments)

        # One prompt per image; only the first uses the opening-clip wording.
        first = True
        for generated_caption, senti, pace_pred in zip(generated_text, sentiments, pace):
            prompts.append(self.prompt_construction(generated_caption, pace_pred, senti, instrument, first))
            first = False
        print(f"[{now()}] Generated Prompts:", prompts)

        try:
            audio_file = self.audio_generation.generate(prompts, ngrok_endpoint)
        except Exception as e:
            print(f"[{now()}] {e}")
            raise Error(repr(e))
        print(f"[{now()}]", audio_file)
        print(f"[{now()}] Audio Generation Done")

        video_file = self.stitch_images(file_paths, [audio_file])
        return video_file
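
# Example usage (a minimal sketch; the weight paths, image size, and endpoint
# below are illustrative assumptions, not values from this repository):
#
#   palette = AudioPalette(
#       pace_model_weights_path="pace_model.h5",            # hypothetical path
#       resnet50_tf_model_weights_path="resnet50_tf.h5",    # hypothetical path
#       height=224, width=224, channels=3,                  # assumed input shape
#   )
#   prompt, pace, caption, audio_path = palette.generate_single(
#       Image.open("photo.jpg"),
#       instrument="piano",
#       ngrok_endpoint="https://example-1234.ngrok-free.app/",
#   )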