VidTune / engine /video_descriptor.py
animikhaich's picture
Incomplete Update
d50bd1e
raw
history blame
5.72 kB
import os
from warnings import simplefilter

# Silence Python warnings globally and quiet native TensorFlow-style logging
# (TF_CPP_MIN_LOG_LEVEL) before the third-party imports below are executed.
simplefilter("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import json
import time

import google.generativeai as genai

try:
    # Prefer the project's own `logger` module when it is importable; fall
    # back to the stdlib module otherwise. Only import failures trigger the
    # fallback so that genuine errors (or Ctrl-C) are not silently swallowed.
    from logger import logging
except ImportError:
    import logging
# Few-shot examples of the single-sentence music prompts the model should
# produce; interpolated verbatim into ``gemni_instructions`` below.
music_prompt_examples = "\n".join(
    [
        "",
        "'A dynamic blend of hip-hop and orchestral elements, with sweeping strings and brass, evoking the vibrant energy of the city',",
        "'Smooth jazz, with a saxophone solo, piano chords, and snare full drums',",
        "'90s rock song with electric guitar and heavy drums'.",
        "",
    ]
)

# JSON shape the model is asked to answer with; generated with json.dumps so
# the example embedded in the prompt is guaranteed to be valid JSON.
json_schema = "\n{}\n".format(
    json.dumps({"Content Description": "string", "Music Prompt": "string"})
)

# Second-pass instructions: turn a free-form scene description into a JSON
# object containing a content description and a music-generation prompt.
gemni_instructions = "\n".join(
    [
        "",
        (
            "You are a music supervisor who analyzes the content and tone of images "
            "and videos to describe music that fits well with the mood, evokes "
            "emotions, and enhances the narrative of the visuals. Given an image or "
            "video, describe the scene and generate a prompt suitable for music "
            "generation models. Use keywords related to genre, instruments, mood, "
            "context, and setting to craft a concise single-sentence prompt, like:"
        ),
        music_prompt_examples,
        "You must return your response using this JSON schema: " + json_schema,
        "",
    ]
)
class DescribeVideo:
    """Describe a video with Gemini and derive a music-generation prompt from it.

    ``describe_video`` makes two Gemini calls: the first asks the model to
    explain the uploaded video, the second rewrites that explanation into a
    JSON object following ``json_schema`` (via ``gemni_instructions``).
    """

    def __init__(self, model="flash"):
        """Configure the Gemini client.

        Args:
            model: Short model alias, ``"flash"`` or ``"pro"``.

        Raises:
            ValueError: If ``model`` is unknown or no usable API key is found.
            OSError: If the credentials file cannot be opened.
        """
        self.model = self.get_model_name(model)
        api_key = self.load_api_key()
        # Must be initialised before get_safety_settings(), which reads it.
        self.is_safety_set = False
        self.safety_settings = self.get_safety_settings()
        genai.configure(api_key=api_key)
        self.mllm_model = genai.GenerativeModel(self.model)
        logging.info(f"Initialized DescribeVideo with model: {self.model}")

    def describe_video(self, video_path):
        """Upload ``video_path`` and return the model's parsed JSON description.

        Args:
            video_path: Path to a local video file.

        Returns:
            The object parsed from the second response. The prompt requests the
            keys "Content Description" and "Music Prompt", but they are not
            validated here.

        Raises:
            ValueError: If the uploaded file ends in the ``FAILED`` state, or if
                the second response is not valid JSON (``json.JSONDecodeError``
                subclasses ``ValueError``).
        """
        video_file = genai.upload_file(video_path)
        logging.info(f"Uploaded video: {video_path}")

        # The upload may still be PROCESSING; poll until it leaves that state.
        # NOTE(review): there is no upper bound on this wait.
        while video_file.state.name == "PROCESSING":
            time.sleep(0.25)
            video_file = genai.get_file(video_file.name)

        if video_file.state.name == "FAILED":
            logging.error(f"Failed to upload video: {video_file.state.name}")
            raise ValueError(f"Failed to upload video: {video_file.state.name}")

        # First pass: free-form explanation of the video content.
        response = self.mllm_model.generate_content(
            [video_file, "Explain what is happening in this video"],
            request_options={"timeout": 600},
            safety_settings=self.safety_settings,
        )
        logging.info(
            f"Generated content for video: {video_path} with response: {response.text}"
        )

        # Second pass: condense the explanation into the requested JSON schema.
        cleaned_response = self.mllm_model.generate_content(
            [
                response.text,
                gemni_instructions,
            ],
            safety_settings=self.safety_settings,
        )
        logging.info(f"Generated : {video_path} with response: {cleaned_response.text}")
        return self._parse_json_response(cleaned_response.text)

    def __call__(self, video_path):
        """Shorthand for :meth:`describe_video`."""
        return self.describe_video(video_path)

    def reset_safety_settings(self):
        """Discard custom safety settings and restore the defaults."""
        logging.info("Resetting safety settings")
        self.is_safety_set = False
        self.safety_settings = self.get_safety_settings()

    def set_safety_settings(self, safety_settings):
        """Validate and install custom safety settings.

        Args:
            safety_settings: Dict mapping harm categories to block thresholds.
                Keys may be ``genai.types.HarmCategory`` members or their names
                (e.g. ``"HARM_CATEGORY_HARASSMENT"``); values may be
                ``genai.types.HarmBlockThreshold`` members or their names
                (e.g. ``"BLOCK_NONE"``).

        Raises:
            ValueError: If the argument is not a dict or contains an unknown
                category or threshold. Existing settings are left unchanged.
        """
        # Validate everything before touching instance state so a rejected
        # input cannot leave invalid settings installed.
        if not isinstance(safety_settings, dict):
            raise ValueError("Safety settings must be a dictionary")
        categories = genai.types.HarmCategory
        thresholds = genai.types.HarmBlockThreshold
        for harm_category, harm_block_threshold in safety_settings.items():
            if not self._is_enum_value(harm_category, categories):
                raise ValueError(f"Invalid harm category: {harm_category}")
            if not self._is_enum_value(harm_block_threshold, thresholds):
                raise ValueError(
                    f"Invalid harm block threshold: {harm_block_threshold}"
                )
        logging.info(f"Set safety settings: {safety_settings}")
        self.safety_settings = safety_settings
        self.is_safety_set = True

    def get_safety_settings(self):
        """Return the active safety settings.

        If custom settings were installed via :meth:`set_safety_settings`, they
        are returned; otherwise a dict setting ``BLOCK_NONE`` for the hate
        speech, harassment, sexually explicit and dangerous content categories.
        """
        if self.is_safety_set:
            return self.safety_settings
        return {
            genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: genai.types.HarmBlockThreshold.BLOCK_NONE,
            genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
            genai.types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: genai.types.HarmBlockThreshold.BLOCK_NONE,
            genai.types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
        }

    @staticmethod
    def _is_enum_value(value, enum_cls):
        """Return True if ``value`` is a member of ``enum_cls`` or one of its names."""
        if isinstance(value, enum_cls):
            return True
        # The str check avoids a TypeError for unhashable inputs.
        return isinstance(value, str) and value in enum_cls.__members__

    @staticmethod
    def _parse_json_response(text):
        """Parse JSON from model output, tolerating a surrounding Markdown fence.

        ``str.strip(chars)`` treats its argument as a character set, not a
        prefix, so it cannot reliably remove a ```` ```json ```` fence (it stops
        at leading spaces, ``\\r`` or an upper-case tag). The fence is removed
        explicitly instead.

        Raises:
            json.JSONDecodeError: If the remaining text is not valid JSON.
        """
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
            # Drop an optional language tag such as ``json`` / ``JSON``.
            if cleaned[:4].lower() == "json":
                cleaned = cleaned[4:]
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
        return json.loads(cleaned)

    @staticmethod
    def load_api_key(path="./creds.json"):
        """Read the Google API key from a JSON credentials file.

        Args:
            path: Path to a JSON file containing a ``"google_api_key"`` entry.

        Returns:
            The API key string.

        Raises:
            OSError: If ``path`` cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            ValueError: If the key is missing, not a string, or blank.
        """
        with open(path, encoding="utf-8") as f:
            creds = json.load(f)
        # A non-object JSON document simply has no key (instead of AttributeError).
        api_key = creds.get("google_api_key") if isinstance(creds, dict) else None
        if not isinstance(api_key, str) or not api_key.strip():
            logging.error(f"Google API key not found in {path}")
            raise ValueError(f"Gemini API key not found in {path}")
        return api_key

    @staticmethod
    def get_model_name(model):
        """Map a short alias (``"flash"`` / ``"pro"``) to a Gemini model id.

        Raises:
            ValueError: If ``model`` is not a known alias.
        """
        models = {
            "flash": "models/gemini-1.5-flash-latest",
            "pro": "models/gemini-1.5-pro-latest",
        }
        if model not in models:
            message = f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
            logging.error(message)
            raise ValueError(message)
        logging.info(f"Selected model: {models[model]}")
        return models[model]
if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the video path; without it the
    # previously hard-coded sample path is used.
    default_video_path = "videos/3A49B385FD4A8FE2E3AEEF43C140D9AF_video_dashinit.mp4"
    video_path = sys.argv[1] if len(sys.argv) > 1 else default_video_path
    dv = DescribeVideo(model="flash")
    video_description = dv.describe_video(video_path)
    print(video_description)