import os import openai import streamlit as st from streamlit_chat import message from config import output_path_video, output_path_transcription from keyword_retriever.keyword_retreiver import MediaRetriever from logger import logger from resource_loader.uploaded_media_loader import UploadedMediaLoader from resource_loader.youtube_loader import YouTubeLoader from summarization_service.summarizer import TranscriptSummary from utils import check_file_exists, download_video, transcribe_video, load_transcription st.set_page_config(page_title="Summary", layout="wide") # Initialize chat history chat_history = [] # Initialize variables for LLM options and chosen LLM llm_options = [] chosen_LLM = "default" def generate_response(prompt_input): answer = transcript_summary.query_summary(prompt_input) return answer @st.cache_resource() def factory_transcript(media_id, model, llm_provider): ts = TranscriptSummary(doc_id=media_id, model=model, llm_provider=llm_provider) logger.info("TranscriptSummary initialized") return ts @st.cache_resource() def factory_media(media_id, top_k): retriever = MediaRetriever(media_id=media_id, similarity_top_k=top_k) logger.info("video_retriever initialized") return retriever with st.sidebar: # Sidebar st.title("Controls") # Create a sidebar for the YouTube URL, search bar, and settings youtube_url = st.text_input("Enter YouTube URL:") uploaded_file = st.file_uploader("Or upload a video...", type=['mp4', 'mov', 'avi', 'flv', 'mkv', 'mp3', 'wav', 'aac', 'ogg']) if uploaded_file is not None: file_extension = uploaded_file.name.split('.')[-1] if file_extension in ['mp4', 'mov', 'avi', 'flv', 'mkv']: media_type = 'video' elif file_extension in ['mp3', 'wav', 'aac', 'ogg']: media_type = 'audio' else: media_type = 'unknown' media_loader = UploadedMediaLoader(uploaded_file, uploaded_file.name, media_type=media_type) elif youtube_url: media_loader = YouTubeLoader(youtube_url, output_path_video) similarity_top_k = st.number_input("Maximum Number of Results to Display", min_value=1, max_value=100, value=10) # Selecting the provider chosen_provider = st.selectbox("Choose Provider", ["OpenAI", "Replicate", "Default"]) # Based on provider, display relevant LLMs if chosen_provider == "OpenAI": llm_options = ["gpt-3.5-turbo-0301", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k-0314"] elif chosen_provider == "Replicate": llm_options = ["mistralai/mistral-7b-v0.1:3e8a0fb6d7812ce30701ba597e5080689bef8a013e5c6a724fafb108cc2426a0", "mistralai/mistral-7b-instruct-v0.1:83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70", "joehoover/zephyr-7b-alpha:14ec63365a1141134c41b652fe798633f48b1fd28b356725c4d8842a0ac151ee", "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", "meta/llama-2-7b-chat:8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e", ] else: llm_options = ["default"] # Allow users to type a custom LLM or choose from list chosen_LLM = st.selectbox("Type or Choose Language Model", llm_options) api_key = st.text_input("OpenAI API Key", type="password") if api_key and chosen_provider == "OpenAI": logger.info("OpenAI API KEY") try: openai.api_key = api_key except: st.sidebar.write("Incorrect API key provided") elif api_key and chosen_provider == "Replicate": logger.info("Replicate API KEY") os.environ['REPLICATE_API_TOKEN'] = api_key else: chosen_LLM = "default" chosen_provider = "Default" if youtube_url or uploaded_file: video_file_path = os.path.join(output_path_video, f"{media_loader.media_id}.mp3") transcription_file_path = os.path.join(output_path_transcription, f"{media_loader.media_id}.json") if not check_file_exists(video_file_path): download_video(media_loader) else: logger.info(f"Video already downloaded: {video_file_path}") if not check_file_exists(transcription_file_path): transcribe_video(media_loader, output_path_video, output_path_transcription) else: logger.info(f"Transcription already exists: {transcription_file_path}") video_retriever = factory_media(media_loader.media_id, top_k=int(similarity_top_k)) transcript_summary = factory_transcript(media_loader.media_id, model=chosen_LLM, llm_provider=chosen_provider) docs = load_transcription(media_loader, output_path_transcription) col2, col3 = st.columns([3, 1]) # Main Content - Middle Section video_slot = col2.empty() with col2: if isinstance(media_loader, UploadedMediaLoader): video_slot.video(uploaded_file) elif isinstance(media_loader, YouTubeLoader): video_slot.video(youtube_url) st.title("Summary") # Display summary here st.write(transcript_summary.get_document_summary()) # Initialize session_state for chat history if it doesn't exist if 'chat_history' not in st.session_state: st.session_state.chat_history = [] # Main Content - Bottom Section for Chat st.title("Ask me") with col3: user_input = st.text_input("Search:") if user_input: if isinstance(media_loader, UploadedMediaLoader): video_slot.video(uploaded_file) elif isinstance(media_loader, YouTubeLoader): video_slot.video(youtube_url) raw_results = video_retriever.search(user_input) for i, result in enumerate(raw_results): text_content = result.node.text start_time = int(result.node.metadata['start']) full_youtube_url = f"{youtube_url}&t={start_time}s" if st.button(text_content, key=f"button_{i}"): st.session_state.current_video = full_youtube_url if isinstance(media_loader, UploadedMediaLoader): video_slot.video(uploaded_file, start_time=start_time) elif isinstance(media_loader, YouTubeLoader): video_slot.video(youtube_url, start_time=start_time) with col2: chat_placeholder = st.empty() def on_btn_click(): del st.session_state.past[:] del st.session_state.generated[:] def on_input_change(): user_input = st.session_state.user_input st.session_state.past.append(user_input) # Generate response only for the latest input latest_response = generate_response(st.session_state['past'][-1]) st.session_state.generated.append(latest_response) st.session_state.user_input = "" # This will empty the "User Input:" text box if 'generated' not in st.session_state: st.session_state['generated'] = [] if 'past' not in st.session_state: st.session_state['past'] = [] with chat_placeholder.container(): for i in range(len(st.session_state['generated'])): message(st.session_state['past'][i], is_user=True, key=f"{i}_user") # Displaying generated message message(st.session_state['generated'][i], key=f"{i}", allow_html=True, is_table=False) st.button("Clear message", on_click=on_btn_click) with st.container(): st.text_input("User Input:", on_change=on_input_change, key="user_input")