import os

# Ensure SAMBANOVA_BASE_URL is in the environment for litellm.
# This should be set before dynamic_cheatsheet.language_model is imported in case it
# relies on it at import time, but it's generally used at runtime when making the
# API call. Setting it here early in app.py is a safeguard.
SAMBANOVA_DEFINED_BASE_URL = "https://api.sambanova.ai/v1"
if "SAMBANOVA_BASE_URL" not in os.environ:
    os.environ["SAMBANOVA_BASE_URL"] = SAMBANOVA_DEFINED_BASE_URL
    print(f"SAMBANOVA_BASE_URL environment variable set to: {SAMBANOVA_DEFINED_BASE_URL}")
elif os.environ["SAMBANOVA_BASE_URL"] != SAMBANOVA_DEFINED_BASE_URL:
    print(f"Warning: SAMBANOVA_BASE_URL environment variable is already set to {os.environ['SAMBANOVA_BASE_URL']}, but app expects {SAMBANOVA_DEFINED_BASE_URL}. Using the existing one.")

import gradio as gr
import sys

# Add the project root to the Python path to allow importing dynamic_cheatsheet.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))

from dynamic_cheatsheet.language_model import LanguageModel

# --- Configuration ---
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
# SAMBANOVA_BASE_URL is now set from SAMBANOVA_DEFINED_BASE_URL to env var if not present
# SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B"

GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"

GENERATOR_PROMPT = ""
CURATOR_PROMPT = ""
try:
    with open(GENERATOR_PROMPT_PATH, "r") as f:
        GENERATOR_PROMPT = f.read()
    with open(CURATOR_PROMPT_PATH, "r") as f:
        CURATOR_PROMPT = f.read()
except FileNotFoundError:
    print(f"Error: Prompt files not found at {GENERATOR_PROMPT_PATH} or {CURATOR_PROMPT_PATH}. Please ensure they exist.")
    # Minimal fallback templates so the demo can still run without the prompt files.
    GENERATOR_PROMPT = "You are a helpful assistant. Given a question and a cheatsheet, provide an answer. Cheatsheet: [[CHEATSHEET]] Question: [[QUESTION]] FINAL ANSWER: "
    CURATOR_PROMPT = "You are a helpful assistant. Given a question, a model answer, and a previous cheatsheet, update the cheatsheet. Previous Cheatsheet: [[PREVIOUS_CHEATSHEET]] Question: [[QUESTION]] Model Answer: [[MODEL_ANSWER]] NEW CHEATSHEET: "

# --- Global variable for cheatsheet ---
# Most recently generated cheatsheet; "(empty)" is the sentinel used by the
# DynamicCheatsheet approach for "no cheatsheet yet".
current_cheatsheet_cache = "(empty)"


def initialize_model(model_name_input):
    """Create a LanguageModel for the given SambaNova model name.

    Raises:
        gr.Error: if the SAMBANOVA_API_KEY environment variable is not set.
    """
    if not SAMBANOVA_API_KEY:
        raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
    # LanguageModel will be modified to handle samba/ prefix using env vars for
    # API key/base URL via litellm.
    model = LanguageModel(model_name=model_name_input)
    return model


def generate_cheatsheet_func(training_data_text, model_name_input, progress=gr.Progress(track_tqdm=True)):
    """Build a cumulative cheatsheet from newline-separated training examples.

    Each non-empty line of `training_data_text` is fed through the
    DynamicCheatsheet_Cumulative approach, carrying the cheatsheet forward.
    Updates the module-level cache and returns the final cheatsheet text.
    """
    global current_cheatsheet_cache
    if not training_data_text.strip():
        current_cheatsheet_cache = "(empty)"
        return "Training data is empty. Cheatsheet reset to (empty)."

    print('generate_cheatsheet_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)
    training_examples = [ex.strip() for ex in training_data_text.split("\n") if ex.strip()]
    cheatsheet_content = "(empty)"

    progress(0, desc="Initializing Cheatsheet Generation")
    for i, example_input in enumerate(progress.tqdm(training_examples, desc="Generating Cheatsheet")):
        print(f"Processing training example {i+1}/{len(training_examples)}: {example_input[:50]}...")
        try:
            results_dict = model.advanced_generate(
                approach_name="DynamicCheatsheet_Cumulative",
                input_txt=example_input,
                cheatsheet=cheatsheet_content,
                generator_template=GENERATOR_PROMPT,
                cheatsheet_template=CURATOR_PROMPT,
                temperature=0.1,
                max_tokens=2048
            )
            # Keep the previous cheatsheet if the result dict lacks the key.
            cheatsheet_content = results_dict.get("final_cheatsheet", cheatsheet_content)
        except Exception as e:
            print(f"Error processing example '{example_input[:50]}...': {e}")
            # Continue with the current cheatsheet, and show error in UI.
            gr.Warning(f"Error on example '{example_input[:30]}...': {e}. Skipping this example.")

    current_cheatsheet_cache = cheatsheet_content
    return current_cheatsheet_cache


def get_answers_func(user_query, model_name_input):
    """Answer `user_query` twice: with the cached cheatsheet and without it.

    Returns:
        A (answer_with_cheatsheet, answer_without_cheatsheet) tuple of strings;
        on failure the corresponding entry carries the error message.
    """
    global current_cheatsheet_cache
    if not user_query.strip():
        return "Query is empty.", "Query is empty."

    print('get_answers_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)
    answer_with_cheatsheet = "Error retrieving answer."
    answer_without_cheatsheet = "Error retrieving answer."

    # Inference WITH cheatsheet
    try:
        print(f"Querying WITH cheatsheet ({current_cheatsheet_cache[:50]}...)")
        results_with_cheatsheet = model.advanced_generate(
            approach_name="DynamicCheatsheet_Cumulative",
            input_txt=user_query,
            cheatsheet=current_cheatsheet_cache,
            generator_template=GENERATOR_PROMPT,
            cheatsheet_template=CURATOR_PROMPT,
            temperature=0.1,
            max_tokens=2048
        )
        answer_with_cheatsheet = results_with_cheatsheet.get("final_answer", "Error: Could not extract answer.")
    except Exception as e:
        print(f"Error (with cheatsheet): {e}")
        answer_with_cheatsheet = f"Error during inference with cheatsheet: {e}"

    # Inference WITHOUT cheatsheet
    try:
        print(f"Querying WITHOUT cheatsheet...")
        results_without_cheatsheet = model.advanced_generate(
            approach_name="DynamicCheatsheet_Cumulative",
            input_txt=user_query,
            cheatsheet="(empty)",
            generator_template=GENERATOR_PROMPT,
            cheatsheet_template=CURATOR_PROMPT,
            temperature=0.1,
            max_tokens=2048
        )
        answer_without_cheatsheet = results_without_cheatsheet.get("final_answer", "Error: Could not extract answer.")
    except Exception as e:
        print(f"Error (without cheatsheet): {e}")
        answer_without_cheatsheet = f"Error during inference without cheatsheet: {e}"

    return answer_with_cheatsheet, answer_without_cheatsheet


# --- Gradio Interface ---
with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
    gr.Markdown("# Task Caching Demo")
    gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")

    training_data_example = '''
{"benchmark": "FinanceBench", "question": "What was 3M's net sales in 2018?", "answer": "3M's net sales in 2018 were $32,765 million.", "evidence_text": "Net sales of $32,765 million, up 3.5 percent from $31,657 million in 2017."}
{"benchmark": "FinanceBench", "question": "What was the total amount of cash and cash equivalents for 3M at the end of 2018?", "answer": "At the end of 2018, 3M had $3,567 million in cash and cash equivalents.", "evidence_text": "Cash and cash equivalents totaled $3,567 million at December 31, 2018 compared to $3,291 million at December 31, 2017."}
{"benchmark": "FinanceBench", "question": "What is 3M's strategy regarding acquisitions and divestitures?", "answer": "3M's strategy includes pursuing acquisitions and divestitures to strengthen its portfolio and create shareholder value. This involves identifying and acquiring businesses that align with its strategic priorities and divesting non-core assets.", "evidence_text": "Our strategy includes the active management of our portfolio through acquisitions and divestitures. We seek to acquire businesses that align with our strategic priorities and offer attractive returns, and we may divest businesses that are no longer core to our strategy or do not meet our performance expectations."}
'''

    with gr.Tabs():
        model_name_input = gr.Textbox(
            label="SambaNova Model Name",
            # value="sambanova/Meta-Llama-3.1-8B-Instruct",  # Default value
            value="sambanova/DeepSeek-R1",  # Default value
            info="Enter the SambaNova model name (e.g., sambanova/DeepSeek-R1-Distill-Llama-70B). Ensure the 'sambanova/' prefix if required by litellm configuration."
        )
        # BUGFIX: this component was previously assigned to SAMBANOVA_API_KEY,
        # shadowing the module-level API-key string with an always-truthy
        # gr.Textbox — which silently disabled the key checks in
        # initialize_model() and in the __main__ guard. The component is not
        # wired to any handler, so renaming it is safe.
        api_key_input = gr.Textbox(
            label="SambaNova API Key",
            value="",  # Default value
            info="Please Enter your SambaNova API Key, otherwise by default will use Changran's key, but RPM is low"
        )

    with gr.Tabs():
        with gr.TabItem("1. Task Caching (Generate Task-Specific Cheatsheet from Training Data)"):
            gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
            training_data_input = gr.Textbox(lines=10, label="Training Data", value=training_data_example)
            generate_cheatsheet_button = gr.Button("Generate Cheatsheet (Task Caching)", variant="primary")
            cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)
            generate_cheatsheet_button.click(
                fn=generate_cheatsheet_func,
                inputs=[training_data_input, model_name_input],
                outputs=cheatsheet_output,
                show_progress="full"
            )

        with gr.TabItem("2. Test Inference"):
            gr.Markdown("Enter your query below. The model will attempt to answer it twice: once using the generated cheatsheet (if any), and once without it.")
            query_input = gr.Textbox(lines=3, label="Your Query", value="e.g., What is the solution to 5 6 6 8 in the Game of 24?")
            get_answers_button = gr.Button("Get Answers", variant="primary")
            with gr.Row():
                answer_with_cheatsheet_output = gr.Textbox(label="Answer WITH Task Caching", lines=10, interactive=False, show_label=True)
                answer_without_cheatsheet_output = gr.Textbox(label="Answer WITHOUT Task Caching", lines=10, interactive=False, show_label=True)
            get_answers_button.click(
                fn=get_answers_func,
                inputs=[query_input, model_name_input],
                outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
            )

    gr.Markdown("**Important:** Ensure `SAMBANOVA_API_KEY` is set as a secret in your Hugging Face Space or as an environment variable if running locally. `SAMBANOVA_BASE_URL` is set to `https://api.sambanova.ai/v1` by default if not found in environment.")

if __name__ == "__main__":
    # With the shadowing bug fixed, this check now really reflects the env var.
    if not SAMBANOVA_API_KEY:
        print("Warning: SAMBANOVA_API_KEY is not set. The application will likely fail to contact the SambaNova API.")
        print("Please set the SAMBANOVA_API_KEY environment variable.")
    demo.launch()