Ashmi Banerjee committed
Commit d5b3118
1 Parent(s): c8193d0

updated configs

Files changed (3)
  1. app.py +16 -4
  2. models/gemini.py +17 -5
  3. models/gemma.py +5 -2
app.py CHANGED

@@ -9,12 +9,17 @@ def clear():
     return None, None, None
 
 
-def generate_text(query_text, model_name: Optional[str] = "google/gemma-2b-it"):
+def generate_text(query_text, model_name: Optional[str] = "google/gemma-2b-it", tokens: Optional[int] = 1024,
+                  temp: Optional[float] = 0.49):
     combined_information = get_context(query_text)
+    gen_config = {
+        "temperature": temp,
+        "max_output_tokens": tokens,
+    }
     if model_name is None or model_name == "google/gemma-2b-it":
-        return gemma_predict(combined_information, model_name)
+        return gemma_predict(combined_information, model_name, gen_config)
     if model_name == "gemini-1.0-pro":
-        return get_gemini_response(combined_information, model_name, None)
+        return get_gemini_response(combined_information, model_name, gen_config)
     return "Sorry, something went wrong! Please try again."
 
 
@@ -42,6 +47,13 @@ with gr.Blocks() as demo:
         )
         output = gr.Textbox(label="Generated Results", lines=4)
 
+        with gr.Accordion("Settings", open=False):
+            max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
+                                       interactive=True,
+                                       visible=True, info="The maximum number of output tokens")
+            temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
+                                    interactive=True,
+                                    visible=True, info="The value used to module the logits distribution")
         with gr.Group():
             with gr.Row():
                 submit_btn = gr.Button("Submit", variant="primary")
@@ -54,7 +66,7 @@ with gr.Blocks() as demo:
     gr.Markdown("## Examples")
     gr.Examples(
         examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
-        cache_examples=True,
+        # cache_examples=True,
    )
 
 if __name__ == "__main__":
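
Note: the hunks above do not show the event wiring, but the new sliders presumably feed the extra tokens and temp parameters of generate_text. A minimal sketch of how that hookup could look, assuming the component names query, model, max_new_tokens, temperature, submit_btn, and output from the diff (the click call itself is not part of this commit):

# Sketch only: pass the slider values through to generate_text;
# inputs order matches the signature (query_text, model_name, tokens, temp).
submit_btn.click(
    fn=generate_text,
    inputs=[query, model, max_new_tokens, temperature],
    outputs=[output],
)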
models/gemini.py CHANGED

@@ -7,18 +7,30 @@ from dotenv import load_dotenv
 
 sys.path.append("../")
 from setup.vertex_ai_setup import initialize_vertexai_params
-from vertexai.generative_models import GenerativeModel
+from vertexai import generative_models
 
 load_dotenv()
 VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECT"]
 
+DEFAULT_GEN_CONFIG = {
+    "temperature": 0.49,
+    "max_output_tokens": 1024,
+}
+
+DEFAULT_SAFETY_SETTINGS = {
+    generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+    generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+}
 
-def get_gemini_response(prompt_text, model, parameters: Optional = None, location: Optional[str] = "us-central1") -> str:
-    initialize_vertexai_params()
 
-    if model is None or parameters is None:
+def get_gemini_response(prompt_text, model, generation_config: Optional[dict] = None,
+                        safety_settings: Optional[dict] = None) -> str:
+    initialize_vertexai_params()
+    if model is None:
         model = "gemini-1.0-pro"
-    model = GenerativeModel(model)
+    model = generative_models.GenerativeModel(model,
+                                              generation_config=DEFAULT_GEN_CONFIG if generation_config is None else generation_config,
+                                              safety_settings=DEFAULT_SAFETY_SETTINGS if safety_settings is None else safety_settings)
 
     model_response = model.generate_content(prompt_text)
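
With this change, DEFAULT_GEN_CONFIG and DEFAULT_SAFETY_SETTINGS act as fallbacks whenever the caller passes None. A minimal usage sketch against the signature shown above (the prompt string and the config values are illustrative, not from the repository):

# Sketch only: override the generation config, keep the default safety settings.
custom_config = {"temperature": 0.2, "max_output_tokens": 512}  # illustrative values
answer = get_gemini_response(
    "Summarise the retrieved context in two sentences.",  # hypothetical prompt
    "gemini-1.0-pro",
    generation_config=custom_config,
    safety_settings=None,  # None falls back to DEFAULT_SAFETY_SETTINGS
)
print(answer)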
 
models/gemma.py CHANGED

@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 
 from dotenv import load_dotenv
 from huggingface_hub import InferenceClient
@@ -6,10 +7,12 @@ from huggingface_hub import InferenceClient
 load_dotenv()
 
 
-def gemma_predict(combined_information, model_name):
+def gemma_predict(combined_information, model_name, config: Optional[dict]):
     HF_token = os.environ["HF_TOKEN"]
    client = InferenceClient(model_name, token=HF_token)
-    stream = client.text_generation(prompt=combined_information, details=True, stream=True, max_new_tokens=2048,
+    stream = client.text_generation(prompt=combined_information, details=True, stream=True,
+                                    max_new_tokens=config["max_output_tokens"],
+                                    temperature=config["temperature"],
                                     return_full_text=False)
     output = ""
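
The diff cuts off right after output = "". With stream=True and details=True, huggingface_hub's text_generation yields token-level chunks, so the rest of gemma_predict presumably accumulates them along these lines (a sketch of the unchanged remainder, not shown in this commit):

# Sketch only: collect streamed tokens into the final string.
for response in stream:
    if response.token.special:
        continue  # skip BOS/EOS and other special tokens
    output += response.token.text
return output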