"""Gradio demo of Task Caching (Dynamic Cheatsheet) over the SambaNova API via litellm.

Builds a cumulative "cheatsheet" from user-supplied training examples, then
answers queries twice (with / without the cheatsheet) so the effect can be compared.
"""

import os

# Ensure SAMBANOVA_BASE_URL is in the environment for litellm.
# It is used at runtime when the API call is made, but setting it this early
# (before dynamic_cheatsheet.language_model is imported) is a safeguard in case
# that module reads it at import time.
SAMBANOVA_DEFINED_BASE_URL = "https://api.sambanova.ai/v1"
if "SAMBANOVA_BASE_URL" not in os.environ:
    os.environ["SAMBANOVA_BASE_URL"] = SAMBANOVA_DEFINED_BASE_URL
    print(f"SAMBANOVA_BASE_URL environment variable set to: {SAMBANOVA_DEFINED_BASE_URL}")
elif os.environ["SAMBANOVA_BASE_URL"] != SAMBANOVA_DEFINED_BASE_URL:
    print(f"Warning: SAMBANOVA_BASE_URL environment variable is already set to {os.environ['SAMBANOVA_BASE_URL']}, but app expects {SAMBANOVA_DEFINED_BASE_URL}. Using the existing one.")

import gradio as gr
import sys

# Add the project root to the Python path to allow importing dynamic_cheatsheet
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))

from dynamic_cheatsheet.language_model import LanguageModel

# --- Configuration ---
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
# SAMBANOVA_BASE_URL is now set from SAMBANOVA_DEFINED_BASE_URL to env var if not present
# SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B"

GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"

GENERATOR_PROMPT = ""
CURATOR_PROMPT = ""
try:
    with open(GENERATOR_PROMPT_PATH, "r") as f:
        GENERATOR_PROMPT = f.read()
    with open(CURATOR_PROMPT_PATH, "r") as f:
        CURATOR_PROMPT = f.read()
except FileNotFoundError:
    # Fall back to minimal inline prompts so the demo still runs without the prompt files.
    print(f"Error: Prompt files not found at {GENERATOR_PROMPT_PATH} or {CURATOR_PROMPT_PATH}. Please ensure they exist.")
    GENERATOR_PROMPT = "You are a helpful assistant. Given a question and a cheatsheet, provide an answer. Cheatsheet: [[CHEATSHEET]] Question: [[QUESTION]] FINAL ANSWER: "
    CURATOR_PROMPT = "You are a helpful assistant. Given a question, a model answer, and a previous cheatsheet, update the cheatsheet. Previous Cheatsheet: [[PREVIOUS_CHEATSHEET]] Question: [[QUESTION]] Model Answer: [[MODEL_ANSWER]] NEW CHEATSHEET: "

# --- Global variable for cheatsheet ---
# Holds the most recently generated cheatsheet; "(empty)" is the sentinel for "none".
current_cheatsheet_cache = "(empty)"


def initialize_model(model_name_input):
    """Construct a LanguageModel for the given model name.

    Raises gr.Error if SAMBANOVA_API_KEY is not set in the environment, since
    litellm needs it to reach the SambaNova API.
    """
    if not SAMBANOVA_API_KEY:
        raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
    # LanguageModel will be modified to handle samba/ prefix using env vars for API key/base URL via litellm
    model = LanguageModel(
        model_name=model_name_input
    )
    return model


def generate_cheatsheet_func(training_data_text, model_name_input, progress=gr.Progress(track_tqdm=True)):
    """Build a cumulative cheatsheet from newline-separated training examples.

    Each example is fed through the DynamicCheatsheet_Cumulative approach; the
    cheatsheet produced by one example seeds the next. Updates the module-level
    current_cheatsheet_cache and returns the final cheatsheet text.
    """
    global current_cheatsheet_cache
    if not training_data_text.strip():
        current_cheatsheet_cache = "(empty)"
        return "Training data is empty. Cheatsheet reset to (empty)."

    print('generate_cheatsheet_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)
    training_examples = [ex.strip() for ex in training_data_text.split("\n") if ex.strip()]
    cheatsheet_content = "(empty)"

    progress(0, desc="Initializing Cheatsheet Generation")
    for i, example_input in enumerate(progress.tqdm(training_examples, desc="Generating Cheatsheet")):
        print(f"Processing training example {i+1}/{len(training_examples)}: {example_input[:50]}...")
        try:
            results_dict = model.advanced_generate(
                approach_name="DynamicCheatsheet_Cumulative",
                input_txt=example_input,
                cheatsheet=cheatsheet_content,
                generator_template=GENERATOR_PROMPT,
                cheatsheet_template=CURATOR_PROMPT,
                temperature=0.1,
                max_tokens=2048
            )
            cheatsheet_content = results_dict.get("final_cheatsheet", cheatsheet_content)
        except Exception as e:
            # Best-effort: keep the cheatsheet built so far, surface the error in the UI,
            # and continue with the remaining examples.
            print(f"Error processing example '{example_input[:50]}...': {e}")
            gr.Warning(f"Error on example '{example_input[:30]}...': {e}. Skipping this example.")

    current_cheatsheet_cache = cheatsheet_content
    return current_cheatsheet_cache


def get_answers_func(user_query, model_name_input):
    """Answer a query twice: with the cached cheatsheet and with an empty one.

    Returns a (answer_with_cheatsheet, answer_without_cheatsheet) pair; errors
    from either inference are caught and reported as the answer text so one
    failure does not hide the other result.
    """
    global current_cheatsheet_cache
    if not user_query.strip():
        return "Query is empty.", "Query is empty."

    print('get_answers_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)
    answer_with_cheatsheet = "Error retrieving answer."
    answer_without_cheatsheet = "Error retrieving answer."

    # Inference WITH cheatsheet
    try:
        print(f"Querying WITH cheatsheet ({current_cheatsheet_cache[:50]}...)")
        results_with_cheatsheet = model.advanced_generate(
            approach_name="DynamicCheatsheet_Cumulative",
            input_txt=user_query,
            cheatsheet=current_cheatsheet_cache,
            generator_template=GENERATOR_PROMPT,
            cheatsheet_template=CURATOR_PROMPT,
            temperature=0.1,
            max_tokens=2048
        )
        answer_with_cheatsheet = results_with_cheatsheet.get("final_answer", "Error: Could not extract answer.")
    except Exception as e:
        print(f"Error (with cheatsheet): {e}")
        answer_with_cheatsheet = f"Error during inference with cheatsheet: {e}"

    # Inference WITHOUT cheatsheet
    try:
        print(f"Querying WITHOUT cheatsheet...")
        results_without_cheatsheet = model.advanced_generate(
            approach_name="DynamicCheatsheet_Cumulative",
            input_txt=user_query,
            cheatsheet="(empty)",
            generator_template=GENERATOR_PROMPT,
            cheatsheet_template=CURATOR_PROMPT,
            temperature=0.1,
            max_tokens=2048
        )
        answer_without_cheatsheet = results_without_cheatsheet.get("final_answer", "Error: Could not extract answer.")
    except Exception as e:
        print(f"Error (without cheatsheet): {e}")
        answer_without_cheatsheet = f"Error during inference without cheatsheet: {e}"

    return answer_with_cheatsheet, answer_without_cheatsheet


# --- Gradio Interface ---
with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
    gr.Markdown("# Task Caching Demo")
    gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")

    # One-example-per-line; generate_cheatsheet_func splits the textbox value on "\n".
    training_data_example = '''Solve for 24: 1 2 3 4
Solve for 24: 3 4 5 6
Solve for 24: 4 5 6 7'''

    with gr.Tabs():
        model_name_input = gr.Textbox(
            label="SambaNova Model Name",
            value="sambanova/Meta-Llama-3.1-8B-Instruct",  # Default value
            info="Enter the SambaNova model name (e.g., sambanova/DeepSeek-R1-Distill-Llama-70B). Ensure the 'sambanova/' prefix if required by litellm configuration."
        )
        # BUGFIX: this component previously rebound the module-level SAMBANOVA_API_KEY
        # (the key string read from the environment) to a gr.Textbox object. Since a
        # Textbox is always truthy, the `if not SAMBANOVA_API_KEY` checks in
        # initialize_model() and the __main__ guard could never fire. The component is
        # not wired into any event handler, so renaming it restores the intended checks
        # without changing UI behavior.
        api_key_input = gr.Textbox(
            label="SambaNova API Key",
            value="",  # Default value
            info="Please Enter your SambaNova API Key, otherwise by default will use Changran's key, but RPM is low"
        )

    with gr.Tabs():
        with gr.TabItem("1. Task Caching (Generate Task-Specific Cheatsheet from Training Data)"):
            gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
            training_data_input = gr.Textbox(lines=10, label="Training Data", value=training_data_example)
            generate_cheatsheet_button = gr.Button("Generate Cheatsheet (Task Caching)", variant="primary")
            cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)

            generate_cheatsheet_button.click(
                fn=generate_cheatsheet_func,
                inputs=[training_data_input, model_name_input],
                outputs=cheatsheet_output,
                show_progress="full"
            )

        with gr.TabItem("2. Test Inference"):
            gr.Markdown("Enter your query below. The model will attempt to answer it twice: once using the generated cheatsheet (if any), and once without it.")
            query_input = gr.Textbox(lines=3, label="Your Query", value="e.g., What is the solution to 5 6 6 8 in the Game of 24?")
            get_answers_button = gr.Button("Get Answers", variant="primary")
            with gr.Row():
                answer_with_cheatsheet_output = gr.Textbox(label="Answer WITH Task Caching", lines=10, interactive=False, show_label=True)
                answer_without_cheatsheet_output = gr.Textbox(label="Answer WITHOUT Task Caching", lines=10, interactive=False, show_label=True)

            get_answers_button.click(
                fn=get_answers_func,
                inputs=[query_input, model_name_input],
                outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
            )

    gr.Markdown("**Important:** Ensure `SAMBANOVA_API_KEY` is set as a secret in your Hugging Face Space or as an environment variable if running locally. `SAMBANOVA_BASE_URL` is set to `https://api.sambanova.ai/v1` by default if not found in environment.")

if __name__ == "__main__":
    if not SAMBANOVA_API_KEY:
        print("Warning: SAMBANOVA_API_KEY is not set. The application will likely fail to contact the SambaNova API.")
        print("Please set the SAMBANOVA_API_KEY environment variable.")
    demo.launch()