import os
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

import streamlit as st
import torch
import torch.nn.functional as F
import re
import requests
from embedding_processor import SentenceTransformerRetriever, process_data
import pickle
import logging
import sys
from llama_cpp import Llama
from tqdm import tqdm

# llama.cpp runtime settings (set before the model is created)
os.environ['LLAMA_CPP_THREADS'] = '4'
os.environ['LLAMA_CPP_BATCH_SIZE'] = '512'
os.environ['LLAMA_CPP_MODEL_PATH'] = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
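# Note: SentenceTransformerRetriever comes from the local embedding_processor
# module, which is not shown in this file. Based purely on how it is used
# below (not on the module itself), this app assumes it exposes at least:
#   - encode(texts: list[str]) -> torch.Tensor       # embeds queries/documents
#   - store_embeddings(embeddings: torch.Tensor)     # loads precomputed embeddings
#   - doc_embeddings                                  # tensor attribute used for cosine similarity
# process_data is imported but not called in this file.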
# Set page config first
st.set_page_config(
    page_title="The Sport Chatbot",
    page_icon="🏆",
    layout="wide"
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
# Model loader, defined at module level after the imports.
# Cached so the GGUF file is loaded once per Streamlit session rather than on
# every RAGPipeline construction.
@st.cache_resource(show_spinner=False)
def get_llama_model():
    model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    if not os.path.exists(model_path):
        st.info("Downloading model... This may take a while.")
        direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
        download_file_with_progress(direct_url, model_path)

    llm_config = {
        "model_path": model_path,
        "n_ctx": 2048,
        "n_threads": 4,
        "n_batch": 512,
        "n_gpu_layers": 0,
        "verbose": False,
        "use_mlock": True
    }
    return Llama(**llm_config)
def download_file_with_progress(url: str, filename: str):
    """Download a file with a progress bar using requests."""
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))

    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as progress_bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            progress_bar.update(size)
def load_from_drive(file_id: str):
    """Load a pickle file directly from Google Drive."""
    try:
        url = f"https://drive.google.com/uc?id={file_id}&export=download"
        session = requests.Session()
        response = session.get(url, stream=True)

        # Handle Google Drive's large-file confirmation cookie
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                url = f"{url}&confirm={value}"
                response = session.get(url, stream=True)
                break

        content = response.content
        print(f"Successfully downloaded {len(content)} bytes")
        return pickle.loads(content)
    except Exception as e:
        print(f"Detailed error: {str(e)}")
        st.error(f"Error loading file from Drive: {str(e)}")
        return None
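# The pickle fetched by load_from_drive is assumed (inferred from how
# initialize_rag_pipeline consumes it below, not from the file itself) to be a
# dict of the form:
#
#     cache_data = {
#         "documents": [...],    # list[str] of text chunks
#         "embeddings": tensor,  # torch.Tensor of shape (num_docs, embedding_dim)
#     }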
def load_llama_model():
    """Load the Llama model, downloading and verifying the file if needed."""
    try:
        model_path = "mistral-7b-v0.1.Q4_K_M.gguf"

        if not os.path.exists(model_path):
            st.info("Downloading model... This may take a while.")
            direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
            download_file_with_progress(direct_url, model_path)

        if not os.path.exists(model_path):
            raise FileNotFoundError("Model file not found after download attempt")
        if os.path.getsize(model_path) < 1000000:  # Less than 1MB
            raise ValueError("Model file is too small, likely corrupted")

        llm_config = {
            "model_path": model_path,
            "n_ctx": 2048,
            "n_threads": 4,
            "n_batch": 512,
            "n_gpu_layers": 0,
            "verbose": True  # Enable verbose mode for debugging
        }

        logging.info("Initializing Llama model...")
        model = Llama(**llm_config)

        # Smoke-test the model with a tiny completion
        logging.info("Testing model...")
        test_response = model("Test", max_tokens=10)
        if not test_response:
            raise RuntimeError("Model test failed")

        logging.info("Model loaded and tested successfully")
        st.success("Model loaded successfully!")
        return model

    except Exception as e:
        logging.error(f"Error loading model: {str(e)}")
        logging.error("Full error details: ", exc_info=True)
        raise
def check_environment():
    """Check if the environment is properly set up."""
    try:
        import torch
        import sentence_transformers
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        st.stop()
        return False
class RAGPipeline:
    def __init__(self, data_folder: str, k: int = 5):
        self.data_folder = data_folder
        self.k = k
        self.retriever = SentenceTransformerRetriever()
        self.documents = []
        self.device = torch.device("cpu")
        # Path used by initialize_model(); kept in sync with get_llama_model()
        self.model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
        # Use the cached model directly
        self.llm = get_llama_model()

    def preprocess_query(self, query: str) -> str:
        """Clean and prepare the query."""
        query = query.lower().strip()
        query = re.sub(r'\s+', ' ', query)
        return query
    ### Added on Nov 2, 2024
    def postprocess_response(self, response: str) -> str:
        """Clean up the generated response."""
        try:
            # Remove datetime patterns and other unwanted content
            response = re.sub(r'\d{4}-\d{2}-\d{2}(?:T|\s)\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?', '', response)
            response = re.sub(r'User \d+:.*?(?=User \d+:|$)', '', response)
            response = re.sub(r'\d{2}:\d{2}(?::\d{2})?(?:\s?(?:AM|PM))?', '', response)
            response = re.sub(r'\d{1,2}[-/]\d{1,2}[-/]\d{2,4}', '', response)
            response = re.sub(r'(?m)^User \d+:', '', response)

            # Clean up spacing but preserve intentional paragraph breaks:
            # collapse runs of blank lines into a single paragraph break
            response = re.sub(r'\n\s*\n\s*\n+', '\n\n', response)
            # Replace multiple spaces with a single space
            response = re.sub(r' +', ' ', response)

            # Clean up beginning/end
            response = response.strip()
            return response
        except Exception as e:
            logging.error(f"Error in postprocess_response: {str(e)}")
            return response
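    # Illustrative example (hypothetical input, not from the dataset): given
    #   "2024-10-01 19:05:33  The team   won the  title."
    # postprocess_response would return roughly
    #   "The team won the title."
    # with the timestamp stripped and whitespace collapsed.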
    def process_query(self, query: str, placeholder) -> str:
        try:
            # Verify this is the current query being processed
            if hasattr(st.session_state, 'current_query') and query != st.session_state.current_query:
                logging.warning(f"Skipping outdated query: {query}")
                return ""

            query = self.preprocess_query(query)
            status = placeholder.empty()
            status.write("Finding relevant information...")

            # Embed the query and rank documents by cosine similarity
            query_embedding = self.retriever.encode([query])
            similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
            scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))
            relevant_docs = [self.documents[idx] for idx in indices.tolist()]

            # Clean the top documents before building the prompt
            cleaned_docs = []
            for doc in relevant_docs[:3]:
                cleaned_text = self.postprocess_response(doc)
                if cleaned_text:
                    cleaned_docs.append(cleaned_text)

            status.write("Generating response...")
            prompt = f"""Context information is below:
{' '.join(cleaned_docs)}
Given the context above, please answer the following question:
{query}
Guidelines for your response:
- Structure your response in clear, logical paragraphs
- Start a new paragraph for each new main point or aspect
- If listing multiple items, use separate paragraphs
- Keep each paragraph focused on a single topic or point
- Use natural paragraph breaks where the content shifts focus
- Maintain clear transitions between paragraphs
- If providing statistics or achievements, group them logically
- If describing different aspects (e.g., career, playing style, achievements), use separate paragraphs
- Keep paragraphs concise but complete
- Exclude any dates, timestamps, or user comments
- Focus on factual sports information
- If you cannot answer based on the context, say so politely
Format your response with proper paragraph breaks where appropriate.
Answer:"""

            response_placeholder = placeholder.empty()
            try:
                response_text = self.query_model(prompt)
                if response_text:
                    # Clean up the response while preserving paragraph structure
                    final_response = self.postprocess_response(response_text)
                    # Convert the cleaned response to markdown with proper paragraph spacing
                    markdown_response = final_response.replace('\n\n', '\n\n \n\n')
                    response_placeholder.markdown(markdown_response)
                    return final_response
                else:
                    message = "No relevant answer found. Please try rephrasing your question."
                    response_placeholder.warning(message)
                    return message
            except Exception as e:
                logging.error(f"Generation error: {str(e)}")
                message = "Had some trouble generating the response. Please try again."
                response_placeholder.warning(message)
                return message

        except Exception as e:
            logging.error(f"Process error: {str(e)}")
            message = "Something went wrong. Please try again with a different question."
            placeholder.warning(message)
            return message
    def query_model(self, prompt: str) -> str:
        """Query the local Llama model."""
        try:
            if self.llm is None:
                raise RuntimeError("Model not initialized")

            # Log the prompt for debugging
            logging.info("Sending prompt to model...")

            # Generate a response with explicit sampling parameters
            response = self.llm(
                prompt,
                max_tokens=512,         # Maximum length of the response
                temperature=0.7,        # Slightly increased for more dynamic responses
                top_p=0.95,             # Nucleus sampling parameter
                top_k=50,               # Top-k sampling parameter
                echo=False,             # Don't include the prompt in the response
                stop=["Question:", "Context:", "Guidelines:"],  # Stop tokens
                repeat_penalty=1.1,     # Penalize repetition
                presence_penalty=0.5,   # Encourage topic diversity
                frequency_penalty=0.5   # Discourage word repetition
            )

            # Log the raw response for debugging
            logging.info(f"Raw model response: {response}")

            if response and isinstance(response, dict) and 'choices' in response and response['choices']:
                generated_text = response['choices'][0].get('text', '').strip()
                if generated_text:
                    logging.info(f"Generated text: {generated_text[:100]}...")  # Log first 100 chars
                    return generated_text
                else:
                    logging.warning("Model returned empty response")
                    raise ValueError("Empty response from model")
            else:
                logging.warning(f"Unexpected response format: {response}")
                raise ValueError("Invalid response format from model")

        except Exception as e:
            logging.error(f"Error in query_model: {str(e)}")
            logging.error("Full error details: ", exc_info=True)
            raise
    def initialize_model(self):
        """Initialize the model with proper error handling and verification."""
        try:
            if not os.path.exists(self.model_path):
                st.info("Downloading model... This may take a while.")
                direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
                download_file_with_progress(direct_url, self.model_path)

            # Verify the file exists and has content
            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model file {self.model_path} not found after download attempts")
            if os.path.getsize(self.model_path) < 1000000:  # Less than 1MB
                os.remove(self.model_path)
                raise ValueError("Downloaded model file is too small, likely corrupted")

            # Updated model configuration
            llm_config = {
                "model_path": self.model_path,
                "n_ctx": 4096,              # Increased context window
                "n_threads": 4,
                "n_batch": 512,
                "n_gpu_layers": 0,
                "verbose": True,            # Enable verbose mode for debugging
                "use_mlock": False,         # Disable memory locking
                "last_n_tokens_size": 64,   # Token window size for repeat penalty
                "seed": -1                  # -1 selects a random seed
            }

            logging.info("Initializing Llama model...")
            self.llm = Llama(**llm_config)

            # Test the model
            test_response = self.llm(
                "Test response",
                max_tokens=10,
                temperature=0.7,
                echo=False
            )
            if not test_response or 'choices' not in test_response:
                raise RuntimeError("Model initialization test failed")

            logging.info("Model initialized and tested successfully")
            return self.llm

        except Exception as e:
            logging.error(f"Error initializing model: {str(e)}")
            raise
def initialize_rag_pipeline():
    """Initialize the RAG pipeline once."""
    try:
        data_folder = "ESPN_data"
        if not os.path.exists(data_folder):
            os.makedirs(data_folder, exist_ok=True)

        # Load embeddings first
        drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
        with st.spinner("Loading data..."):
            cache_data = load_from_drive(drive_file_id)
            if cache_data is None:
                st.error("Failed to load embeddings from Google Drive")
                st.stop()

        # Initialize pipeline
        rag = RAGPipeline(data_folder)

        # Store the precomputed documents and embeddings
        rag.documents = cache_data['documents']
        rag.retriever.store_embeddings(cache_data['embeddings'])

        return rag
    except Exception as e:
        logging.error(f"Pipeline initialization error: {str(e)}")
        st.error(f"Failed to initialize the system: {str(e)}")
        raise
def main():
    try:
        # First, check whether the model already exists
        model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
        if not os.path.exists(model_path):
            st.warning("⚠️ First-time setup: The model will be downloaded. This takes a few minutes but only happens once.")

        # Environment check
        if not check_environment():
            return

        # Initialize session state variables
        if 'current_query' not in st.session_state:
            st.session_state.current_query = None
        if 'processing' not in st.session_state:
            st.session_state.processing = False

        # Improved CSS styling
        st.markdown("""
            <style>
            /* Container styling */
            .block-container {
                padding-top: 2rem;
                padding-bottom: 2rem;
            }
            /* Text input styling */
            .stTextInput > div > div > input {
                width: 100%;
            }
            /* Button styling */
            .stButton > button {
                width: 200px;
                margin: 0 auto;
                display: block;
                background-color: #FF4B4B;
                color: white;
                border-radius: 5px;
                padding: 0.5rem 1rem;
            }
            /* Title styling */
            .main-title {
                text-align: center;
                padding: 1rem 0;
                font-size: 3rem;
                color: #1F1F1F;
            }
            .sub-title {
                text-align: center;
                padding: 0.5rem 0;
                font-size: 1.5rem;
                color: #4F4F4F;
            }
            /* Description styling */
            .description {
                text-align: center;
                color: #666666;
                padding: 0.5rem 0;
                font-size: 1.1rem;
                line-height: 1.6;
                margin-bottom: 1rem;
            }
            /* Answer container styling */
            .stMarkdown {
                max-width: 100%;
            }
            /* Streamlit default overrides */
            .st-emotion-cache-16idsys p {
                font-size: 1.1rem;
                line-height: 1.6;
            }
            /* Container for main content */
            .main-content {
                max-width: 1200px;
                margin: 0 auto;
                padding: 0 1rem;
            }
            </style>
        """, unsafe_allow_html=True)

        # Header section
        st.markdown("<h1 class='main-title'>The Sport Chatbot</h1>", unsafe_allow_html=True)
        st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
        st.markdown("""
            <p class='description'>
            Hey there! I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
            With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
            </p>
            <p class='description'>
            Got any general questions? Feel free to ask; I'll do my best to provide answers based on the information I've been trained on!
            </p>
        """, unsafe_allow_html=True)

        # Initialize the pipeline
        if 'rag' not in st.session_state:
            try:
                with st.spinner("Loading resources..."):
                    st.session_state.rag = initialize_rag_pipeline()
                    logging.info("Pipeline initialized successfully")
            except Exception as e:
                logging.error(f"Pipeline initialization error: {str(e)}")
                st.error("Failed to initialize the system. Please check the logs.")
                st.stop()
                return

        # Create columns for layout
        col1, col2, col3 = st.columns([1, 6, 1])

        with col2:
            # Query input with unique key
            query = st.text_input(
                "What would you like to know about sports?",
                key="sports_query"
            )

            # Centered button with unique key
            if st.button("Get Answer", key="answer_button"):
                if query:
                    # Clear any previous response
                    if 'response_placeholder' in st.session_state:
                        st.session_state.response_placeholder.empty()

                    response_placeholder = st.empty()
                    st.session_state.response_placeholder = response_placeholder

                    try:
                        # Update current query and processing state
                        st.session_state.current_query = query
                        st.session_state.processing = True

                        # Log query processing start
                        logging.info(f"Processing query: {query}")

                        with st.spinner("Processing your question..."):
                            # Process query and get response
                            response = st.session_state.rag.process_query(query, response_placeholder)
                            # Log successful response
                            logging.info(f"Generated response: {response}")

                        # Reset processing state
                        st.session_state.processing = False

                    except Exception as e:
                        # Log error details
                        logging.error(f"Query processing error: {str(e)}")
                        logging.error("Full error details: ", exc_info=True)
                        response_placeholder.warning("Unable to process your question. Please try again.")
                        st.session_state.processing = False
                else:
                    st.warning("Please enter a question!")

        # Footer
        st.markdown("<br><br>", unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("""
            <p style='text-align: center; color: #666666; padding: 1rem 0;'>
            Powered by ESPN Data & Mistral AI
            </p>
        """, unsafe_allow_html=True)

    except Exception as e:
        logging.error(f"Application error: {str(e)}")
        logging.error("Full error details: ", exc_info=True)
        st.error("An unexpected error occurred. Please check the logs and try again.")
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    try:
        main()
    except Exception as e:
        logging.error(f"Fatal error: {str(e)}")
        logging.error("Full error details: ", exc_info=True)
        st.error("A fatal error occurred. Please check the logs and try again.")