ChangranHuuu commited on
Commit
05ea985
·
verified ·
1 Parent(s): c095b91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -21,8 +21,8 @@ from dynamic_cheatsheet.language_model import LanguageModel
21
  # --- Configuration ---
22
  SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
23
  # SAMBANOVA_BASE_URL is now set from SAMBANOVA_DEFINED_BASE_URL to env var if not present
24
- # SAMBANOVA_MODEL_NAME = "samba/DeepSeek-R1-Distill-Llama-70B" # Using litellm convention for SambaNova
25
- SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B"
26
 
27
  GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
28
  CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"
@@ -43,22 +43,22 @@ except FileNotFoundError:
43
  # --- Global variable for cheatsheet ---
44
  current_cheatsheet_cache = "(empty)"
45
 
46
- def initialize_model():
47
  if not SAMBANOVA_API_KEY:
48
  raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
49
  # LanguageModel will be modified to handle samba/ prefix using env vars for API key/base URL via litellm
50
  model = LanguageModel(
51
- model_name=SAMBANOVA_MODEL_NAME
52
  )
53
  return model
54
 
55
- def generate_cheatsheet_func(training_data_text, progress=gr.Progress(track_tqdm=True)):
56
  global current_cheatsheet_cache
57
  if not training_data_text.strip():
58
  current_cheatsheet_cache = "(empty)"
59
  return "Training data is empty. Cheatsheet reset to (empty)."
60
 
61
- model = initialize_model()
62
 
63
  training_examples = [ex.strip() for ex in training_data_text.split("\n") if ex.strip()]
64
 
@@ -86,12 +86,12 @@ def generate_cheatsheet_func(training_data_text, progress=gr.Progress(track_tqdm
86
  current_cheatsheet_cache = cheatsheet_content
87
  return current_cheatsheet_cache
88
 
89
- def get_answers_func(user_query):
90
  global current_cheatsheet_cache
91
  if not user_query.strip():
92
  return "Query is empty.", "Query is empty."
93
 
94
- model = initialize_model()
95
  answer_with_cheatsheet = "Error retrieving answer."
96
  answer_without_cheatsheet = "Error retrieving answer."
97
 
@@ -136,6 +136,13 @@ with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Soft()) as demo:
136
  gr.Markdown("# Task Caching Demo")
137
  gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")
138
 
 
 
 
 
 
 
 
139
  with gr.Tabs():
140
  with gr.TabItem("1. Generate Cheatsheet (Task Caching)"):
141
  gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
@@ -144,7 +151,7 @@ with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Soft()) as demo:
144
  cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)
145
  generate_cheatsheet_button.click(
146
  fn=generate_cheatsheet_func,
147
- inputs=training_data_input,
148
  outputs=cheatsheet_output,
149
  show_progress="full"
150
  )
@@ -160,7 +167,7 @@ with gr.Blocks(title="Task Caching Demo", theme=gr.themes.Soft()) as demo:
160
 
161
  get_answers_button.click(
162
  fn=get_answers_func,
163
- inputs=query_input,
164
  outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
165
  )
166
 
 
21
  # --- Configuration ---
22
  SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
23
  # SAMBANOVA_BASE_URL is now set from SAMBANOVA_DEFINED_BASE_URL to env var if not present
24
+
25
+ # SAMBANOVA_MODEL_NAME = "sambanova/DeepSeek-R1-Distill-Llama-70B"
26
 
27
  GENERATOR_PROMPT_PATH = "prompts/generator_prompt.txt"
28
  CURATOR_PROMPT_PATH = "prompts/curator_prompt_for_dc_cumulative.txt"
 
43
  # --- Global variable for cheatsheet ---
44
  current_cheatsheet_cache = "(empty)"
45
 
46
+ def initialize_model(model_name_input):
47
  if not SAMBANOVA_API_KEY:
48
  raise gr.Error("SAMBANOVA_API_KEY environment variable not set. Please set it in your Hugging Face Space secrets or local environment.")
49
  # LanguageModel will be modified to handle samba/ prefix using env vars for API key/base URL via litellm
50
  model = LanguageModel(
51
+ model_name=model_name_input
52
  )
53
  return model
54
 
55
+ def generate_cheatsheet_func(training_data_text, model_name_input, progress=gr.Progress(track_tqdm=True)):
56
  global current_cheatsheet_cache
57
  if not training_data_text.strip():
58
  current_cheatsheet_cache = "(empty)"
59
  return "Training data is empty. Cheatsheet reset to (empty)."
60
 
61
+ model = initialize_model(model_name_input)
62
 
63
  training_examples = [ex.strip() for ex in training_data_text.split("\n") if ex.strip()]
64
 
 
86
  current_cheatsheet_cache = cheatsheet_content
87
  return current_cheatsheet_cache
88
 
89
+ def get_answers_func(user_query, model_name_input):
90
  global current_cheatsheet_cache
91
  if not user_query.strip():
92
  return "Query is empty.", "Query is empty."
93
 
94
+ model = initialize_model(model_name_input)
95
  answer_with_cheatsheet = "Error retrieving answer."
96
  answer_without_cheatsheet = "Error retrieving answer."
97
 
 
136
  gr.Markdown("# Task Caching Demo")
137
  gr.Markdown("Demonstrates the effect of using a dynamically generated cheatsheet (Task Caching) on model inference. Uses SambaNova API via `litellm`.")
138
 
139
+ model_name_input = gr.Textbox(
140
+ label="SambaNova Model Name",
141
+ value="sambanova/DeepSeek-R1-Distill-Llama-70B", # Default value
142
+ info="Enter the SambaNova model name (e.g., samba/DeepSeek-R1-Distill-Llama-70B). Ensure the 'samba/' prefix if required by litellm configuration."
143
+ )
144
+ # END OF ADDED PART
145
+
146
  with gr.Tabs():
147
  with gr.TabItem("1. Generate Cheatsheet (Task Caching)"):
148
  gr.Markdown("Paste your training data below, one example per line. This data will be used to build a cumulative cheatsheet. The process may take some time depending on the number of examples.")
 
151
  cheatsheet_output = gr.Textbox(label="Generated Cheatsheet", lines=15, interactive=False, show_label=True)
152
  generate_cheatsheet_button.click(
153
  fn=generate_cheatsheet_func,
154
+ inputs=[training_data_input, model_name_input],
155
  outputs=cheatsheet_output,
156
  show_progress="full"
157
  )
 
167
 
168
  get_answers_button.click(
169
  fn=get_answers_func,
170
+ inputs=[query_input, model_name_input]
171
  outputs=[answer_with_cheatsheet_output, answer_without_cheatsheet_output]
172
  )
173