# task-caching-v1 / app.py
# Uploaded by ChangranHuuu — "Update app.py" (commit ad30f0a, verified)
import os

# Export SAMBANOVA_BASE_URL before litellm (via dynamic_cheatsheet.language_model)
# is imported. The value is normally consumed at request time, but setting it
# this early in app.py acts as a safeguard against import-time reads.
SAMBANOVA_DEFINED_BASE_URL = "https://api.sambanova.ai/v1"
_preexisting_base_url = os.environ.get("SAMBANOVA_BASE_URL")
if _preexisting_base_url is None:
    os.environ["SAMBANOVA_BASE_URL"] = SAMBANOVA_DEFINED_BASE_URL
    print(f"SAMBANOVA_BASE_URL environment variable set to: {SAMBANOVA_DEFINED_BASE_URL}")
elif _preexisting_base_url != SAMBANOVA_DEFINED_BASE_URL:
    # Respect an operator-provided override, but make the mismatch visible.
    print(f"Warning: SAMBANOVA_BASE_URL environment variable is already set to {os.environ['SAMBANOVA_BASE_URL']}, but app expects {SAMBANOVA_DEFINED_BASE_URL}. Using the existing one.")
import gradio as gr
import sys
# Add the project root to the Python path to allow importing dynamic_cheatsheet
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))
from dynamic_cheatsheet.language_model import LanguageModel
# --- Configuration ---
SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
# SAMBANOVA_BASE_URL is exported to the environment earlier in this file.
# SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B"
GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"

GENERATOR_PROMPT = ""
CURATOR_PROMPT = ""
try:
    # Read prompts with an explicit encoding so they decode identically on
    # every platform (default locale encodings differ, e.g. on Windows).
    with open(GENERATOR_PROMPT_PATH, "r", encoding="utf-8") as f:
        GENERATOR_PROMPT = f.read()
    with open(CURATOR_PROMPT_PATH, "r", encoding="utf-8") as f:
        CURATOR_PROMPT = f.read()
except FileNotFoundError:
    # Fall back to minimal built-in prompts that use the same [[PLACEHOLDER]]
    # markers, so the app keeps working without the prompt files.
    print(f"Error: Prompt files not found at {GENERATOR_PROMPT_PATH} or {CURATOR_PROMPT_PATH}. Please ensure they exist.")
    GENERATOR_PROMPT = "You are a helpful assistant. Given a question and a cheatsheet, provide an answer. Cheatsheet: [[CHEATSHEET]] Question: [[QUESTION]] FINAL ANSWER: <answer></answer>"
    CURATOR_PROMPT = "You are a helpful assistant. Given a question, a model answer, and a previous cheatsheet, update the cheatsheet. Previous Cheatsheet: [[PREVIOUS_CHEATSHEET]] Question: [[QUESTION]] Model Answer: [[MODEL_ANSWER]] NEW CHEATSHEET: <cheatsheet></cheatsheet>"

# --- Global variable for cheatsheet ---
# "(empty)" is the sentinel the Dynamic Cheatsheet approach treats as
# "no cheatsheet accumulated yet".
current_cheatsheet_cache = "(empty)"
def initialize_model(model_name_input):
    """Construct a LanguageModel for the given SambaNova model name.

    Raises gr.Error when no API key is available, since every subsequent
    call to the SambaNova endpoint would fail anyway.
    """
    if not SAMBANOVA_API_KEY:
        raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
    # LanguageModel handles the samba/ prefix via litellm, which picks up the
    # API key and base URL from environment variables at call time.
    return LanguageModel(model_name=model_name_input)
def generate_cheatsheet_func(training_data_text, model_name_input, progress=gr.Progress(track_tqdm=True)):
    """Build a cumulative cheatsheet from newline-separated training examples.

    Feeds each non-blank line through the DynamicCheatsheet_Cumulative
    approach, carrying the cheatsheet forward between examples. Updates the
    module-level cache and returns the final cheatsheet text.
    """
    global current_cheatsheet_cache
    if not training_data_text.strip():
        # Nothing to learn from: reset the cache to the no-cheatsheet sentinel.
        current_cheatsheet_cache = "(empty)"
        return "Training data is empty. Cheatsheet reset to (empty)."
    print('generate_cheatsheet_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)
    training_examples = [line.strip() for line in training_data_text.split("\n") if line.strip()]
    cheatsheet_content = "(empty)"
    progress(0, desc="Initializing Cheatsheet Generation")
    for i, example_input in enumerate(progress.tqdm(training_examples, desc="Generating Cheatsheet")):
        print(f"Processing training example {i+1}/{len(training_examples)}: {example_input[:50]}...")
        try:
            outcome = model.advanced_generate(
                approach_name="DynamicCheatsheet_Cumulative",
                input_txt=example_input,
                cheatsheet=cheatsheet_content,
                generator_template=GENERATOR_PROMPT,
                cheatsheet_template=CURATOR_PROMPT,
                temperature=0.1,
                max_tokens=2048,
            )
            # Keep the previous cheatsheet if the result lacks a new one.
            cheatsheet_content = outcome.get("final_cheatsheet", cheatsheet_content)
        except Exception as e:
            # Best-effort: keep the cheatsheet built so far, surface the
            # failure in the UI, and continue with the next example.
            print(f"Error processing example '{example_input[:50]}...': {e}")
            gr.Warning(f"Error on example '{example_input[:30]}...': {e}. Skipping this example.")
    current_cheatsheet_cache = cheatsheet_content
    return current_cheatsheet_cache
def get_answers_func(user_query, model_name_input):
    """Answer *user_query* twice: with the cached cheatsheet and without it.

    Returns a (with_cheatsheet, without_cheatsheet) pair of strings. Each
    inference is best-effort: a failure becomes an error message in the
    corresponding slot instead of raising.

    Note: the module-level cheatsheet cache is only read here, so no
    ``global`` declaration is needed (the original had a redundant one).
    """
    if not user_query.strip():
        return "Query is empty.", "Query is empty."
    print('get_answers_func model_name_input', model_name_input)
    model = initialize_model(model_name_input)

    def _run_inference(cheatsheet, label):
        # Single DynamicCheatsheet_Cumulative call; errors are converted to
        # user-visible strings so one failed leg doesn't kill the other.
        try:
            results = model.advanced_generate(
                approach_name="DynamicCheatsheet_Cumulative",
                input_txt=user_query,
                cheatsheet=cheatsheet,
                generator_template=GENERATOR_PROMPT,
                cheatsheet_template=CURATOR_PROMPT,
                temperature=0.1,
                max_tokens=2048,
            )
            return results.get("final_answer", "Error: Could not extract answer.")
        except Exception as e:
            print(f"Error ({label} cheatsheet): {e}")
            return f"Error during inference {label} cheatsheet: {e}"

    # Inference WITH cheatsheet
    print(f"Querying WITH cheatsheet ({current_cheatsheet_cache[:50]}...)")
    answer_with_cheatsheet = _run_inference(current_cheatsheet_cache, "with")
    # Inference WITHOUT cheatsheet — "(empty)" is the no-cheatsheet sentinel
    print(f"Querying WITHOUT cheatsheet...")
    answer_without_cheatsheet = _run_inference("(empty)", "without")
    return answer_with_cheatsheet, answer_without_cheatsheet
# --- Gradio Interface ---
with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
    gr.Markdown("# Task Caching Demo")
    gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")
    training_data_example = '''
Solve for 24: 1 2 3 4
Solve for 24: 3 4 5 6
Solve for 24: 4 5 6 7
'''
    with gr.Tabs():
        model_name_input = gr.Textbox(
            label="SambaNova Model Name",
            value="sambanova/Meta-Llama-3.1-8B-Instruct",  # Default value
            info="Enter the SambaNova model name (e.g., sambanova/DeepSeek-R1-Distill-Llama-70B). Ensure the 'sambanova/' prefix if required by litellm configuration."
        )
        # BUGFIX: this component was previously assigned to the module-level
        # name SAMBANOVA_API_KEY, clobbering the API-key string read from the
        # environment at import time. That made the missing-key checks in
        # initialize_model() and the __main__ guard always see a truthy
        # Textbox object. Renamed to api_key_input to preserve the key string.
        # NOTE(review): this field is not passed to any event handler, so a
        # user-entered key is currently ignored — TODO confirm intended wiring.
        api_key_input = gr.Textbox(
            label="SambaNova API Key",
            value="",  # Default value
            info="Please Enter your SambaNova API Key, otherwise by default will use Changran's key, but RPM is low"
        )
    with gr.Tabs():
        with gr.TabItem("1. Task Caching (Generate Task-Specific Cheatsheet from Training Data)"):
            gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
            training_data_input = gr.Textbox(lines=10, label="Training Data", value=training_data_example)
            generate_cheatsheet_button = gr.Button("Generate Cheatsheet (Task Caching)", variant="primary")
            cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)
            generate_cheatsheet_button.click(
                fn=generate_cheatsheet_func,
                inputs=[training_data_input, model_name_input],
                outputs=cheatsheet_output,
                show_progress="full"
            )
        with gr.TabItem("2. Test Inference"):
            gr.Markdown("Enter your query below. The model will attempt to answer it twice: once using the generated cheatsheet (if any), and once without it.")
            query_input = gr.Textbox(lines=3, label="Your Query", value="e.g., What is the solution to 5 6 6 8 in the Game of 24?")
            get_answers_button = gr.Button("Get Answers", variant="primary")
            with gr.Row():
                answer_with_cheatsheet_output = gr.Textbox(label="Answer WITH Task Caching", lines=10, interactive=False, show_label=True)
                answer_without_cheatsheet_output = gr.Textbox(label="Answer WITHOUT Task Caching", lines=10, interactive=False, show_label=True)
            get_answers_button.click(
                fn=get_answers_func,
                inputs=[query_input, model_name_input],
                outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
            )
    gr.Markdown("**Important:** Ensure `SAMBANOVA_API_KEY` is set as a secret in your Hugging Face Space or as an environment variable if running locally. `SAMBANOVA_BASE_URL` is set to `https://api.sambanova.ai/v1` by default if not found in environment.")
if __name__ == "__main__":
    # Launch the UI even without a key so the page is reachable, but warn
    # loudly up front because every SambaNova call will fail until it is set.
    if not SAMBANOVA_API_KEY:
        print("Warning: SAMBANOVA_API_KEY is not set. The application will likely fail to contact the SambaNova API.")
        print("Please set the SAMBANOVA_API_KEY environment variable.")
    demo.launch()