# Import necessary libraries import gradio as gr import pixeltable as pxt import os import getpass from pixeltable.functions.video import extract_audio from pixeltable.functions import openai as pxop import openai # Set up Pixeltable database and table db_directory = "video_db" table_name = "video_table" # Define constants MAX_VIDEO_SIZE_MB = 35 GPT_MODEL = "gpt-4o-mini-2024-07-18" MAX_TOKENS = 500 WHISPER_MODEL = "whisper-1" # Set OpenAI API key if "OPENAI_API_KEY" not in os.environ: os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:") # Clean up existing database and table if they exist pxt.drop_dir("video_db", force=True) if table_name in pxt.list_tables(): pxt.drop_table("video_db.video_table") # Create or use existing directory and table if db_directory not in pxt.list_dirs(): pxt.create_dir(db_directory) else: print(f"Directory {db_directory} already exists. Using the existing directory.") if table_name not in pxt.list_tables(): t = pxt.create_table( f"{db_directory}.{table_name}", { "video": pxt.VideoType(), "video_filename": pxt.StringType(), "sm_type": pxt.StringType(), "sm_post": pxt.StringType(), }, ) else: t = pxt.load_table(f"{db_directory}.{table_name}") print(f"Table {table_name} already exists. Using the existing table.") # Function to generate social media post using OpenAI GPT-4 API def generate_social_media_post(transcript_text, social_media_type): response = openai.chat.completions.create( model=GPT_MODEL, messages=[ { "role": "system", "content": f"You are an expert in creating social media content for {social_media_type}.", }, { "role": "user", "content": f"Generate an effective and casual social media post based on this video transcript below. Make it a viral and suitable post for {social_media_type}. Transcript:\n{transcript_text}.", }, ], max_tokens=MAX_TOKENS, ) return response.choices[0].message.content # Function to process the uploaded video and generate the post def process_and_generate_post(video_file, social_media_type): if video_file: try: # Check video file size video_size = os.path.getsize(video_file) / (1024 * 1024) # Convert to MB if video_size > MAX_VIDEO_SIZE_MB: return f"The video file is larger than {MAX_VIDEO_SIZE_MB} MB. Please upload a smaller file." video_filename = os.path.basename(video_file) tr_audio_gen_flag = True sm_gen_flag = True # Check if video already exists in the table video_df = t.where(t.video_filename == video_filename).tail(1) if t.select().where(t.video_filename == video_filename).count() >= 1: tr_audio_gen_flag = False # Check if video and social media type combination exists video_type_df = t.where( (t.video_filename == video_filename) & (t.sm_type == social_media_type) ).tail(1) if video_type_df: sm_gen_flag = False # Insert video into PixelTable if it doesn't exist or if it's a new social media type if ( (t.count() < 1) or not ( t.select().where(t.video_filename == video_filename).count() >= 1 ) or (video_df and not video_type_df) ): t.insert( [ { "video": video_file, "video_filename": video_filename, "sm_type": social_media_type, "sm_post": "", } ] ) # Extract audio and transcribe if needed if tr_audio_gen_flag: if not t.get_column(name="audio"): t["audio"] = extract_audio(t.video, format="mp3") else: t.audio = extract_audio(t.video, format="mp3") print("########### processing transcription #############") if not t.get_column(name="transcription"): t["transcription"] = pxop.transcriptions( t.audio, model=WHISPER_MODEL ) else: t.transcription = pxop.transcriptions(t.audio, model=WHISPER_MODEL) # Get the current video data filtered_df = t.where( (t.video_filename == video_filename) & (t.sm_type == social_media_type) ).tail(1) if len(filtered_df) == 0: return "No matching video found in the table. Please ensure the video is uploaded correctly and try again." cur_video_df = filtered_df[0] plain_text = cur_video_df["transcription"]["text"] # Generate or retrieve social media post if ( t.select() .where( (t.video_filename == video_filename) & (t.sm_type == social_media_type) & (t.sm_post != "") ) .count() >= 1 ): print("retrieving existing social media post") social_media_post = ( t.select(t.sm_post) .where( (t.sm_type == social_media_type) & (t.video_filename == video_filename) ) .collect()["sm_post"] ) else: print("generating new social media post") social_media_post = generate_social_media_post( plain_text, social_media_type ) if sm_gen_flag: cur_video_df.update({"sm_post": social_media_post}) return cur_video_df["sm_post"] except Exception as e: return f"An error occurred: {e}" else: return "Please upload a video file." # Gradio Interface def gradio_interface(): with gr.Blocks(theme=gr.themes.Glass()) as demo: # Set up the UI components gr.Markdown( """