Tuana commited on
Commit
5afd6f9
β€’
1 Parent(s): f8c619e

option to switch models

Browse files
Files changed (3) hide show
  1. app.py +21 -4
  2. utils/haystack.py +14 -6
  3. utils/ui.py +27 -11
app.py CHANGED
@@ -15,10 +15,27 @@ set_initial_state()
15
  sidebar()
16
 
17
  st.write("# Get the summaries of latest top Hacker News posts 🧑")
 
 
18
 
19
- if st.session_state.get("HF_TGI_TOKEN"):
20
- pipeline = start_haystack(st.session_state.get("HF_TGI_TOKEN"))
21
- st.session_state["api_key_configured"] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  search_bar, button = st.columns(2)
23
  # Search bar
24
  with search_bar:
@@ -29,7 +46,7 @@ if st.session_state.get("HF_TGI_TOKEN"):
29
  st.write("")
30
  run_pressed = st.button("Get summaries")
31
  else:
32
- st.write("Please provide your Hugging Face Token to start using the application")
33
  st.write("If you are using a smaller screen, open the sidebar from the top left to provide your token πŸ™Œ")
34
 
35
  if st.session_state.get("api_key_configured"):
 
15
  sidebar()
16
 
17
  st.write("# Get the summaries of latest top Hacker News posts 🧑")
18
+ if st.session_state.get("model") == None:
19
+ mistral, openai = st.columns(2)
20
 
21
+ with mistral:
22
+ mistral_pressed = st.button("Mistral")
23
+ if mistral_pressed:
24
+ st.session_state["model"] = "Mistral"
25
+ with openai:
26
+ openai_pressed = st.button("OpenAI")
27
+ if openai_pressed:
28
+ st.session_state["model"] = "GPT-4"
29
+
30
+ if st.session_state.get("model") and (st.session_state.get("HF_TGI_TOKEN") or st.session_state.get("OPENAI_API_KEY")):
31
+ if st.session_state.get("HF_TGI_TOKEN"):
32
+ pipeline = start_haystack(st.session_state.get("HF_TGI_TOKEN"), st.session_state.get("model"))
33
+ st.session_state["api_key_configured"] = True
34
+
35
+ elif st.session_state.get("OPENAI_API_KEY"):
36
+ pipeline = start_haystack(st.session_state.get("OPENAI_API_KEY"), st.session_state.get("model"))
37
+ st.session_state["api_key_configured"] = True
38
+
39
  search_bar, button = st.columns(2)
40
  # Search bar
41
  with search_bar:
 
46
  st.write("")
47
  run_pressed = st.button("Get summaries")
48
  else:
49
+ st.write("Please provide your Hugging Face or OpenAI key to start using the application")
50
  st.write("If you are using a smaller screen, open the sidebar from the top left to provide your token πŸ™Œ")
51
 
52
  if st.session_state.get("api_key_configured"):
utils/haystack.py CHANGED
@@ -1,10 +1,10 @@
1
  import streamlit as st
2
  from haystack import Pipeline
3
  from haystack.components.builders.prompt_builder import PromptBuilder
4
- from haystack.components.generators import HuggingFaceTGIGenerator
5
  from .hackernews_fetcher import HackernewsFetcher
6
 
7
- def start_haystack(hf_token):
8
  prompt_template = """
9
  You will be provided one or more top HakcerNews posts, followed by their URL.
10
  For the posts you have, provide a short summary followed by the URL that the post can be found at.
@@ -18,7 +18,10 @@ Summaries:
18
  """
19
 
20
  prompt_builder = PromptBuilder(template=prompt_template)
21
- llm = HuggingFaceTGIGenerator("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
 
 
 
22
  fetcher = HackernewsFetcher()
23
 
24
  pipe = Pipeline()
@@ -34,9 +37,14 @@ Summaries:
34
  @st.cache_data(show_spinner=True)
35
  def query(top_k, _pipeline):
36
  try:
37
- replies = _pipeline.run(data={"hackernews_fetcher": {"top_k": top_k},
38
- "llm": {"generation_kwargs": {"max_new_tokens": 600}}
39
- })
 
 
 
 
 
40
 
41
  result = replies['llm']['replies']
42
  except Exception as e:
 
1
  import streamlit as st
2
  from haystack import Pipeline
3
  from haystack.components.builders.prompt_builder import PromptBuilder
4
+ from haystack.components.generators import HuggingFaceTGIGenerator, OpenAIGenerator
5
  from .hackernews_fetcher import HackernewsFetcher
6
 
7
+ def start_haystack(key, model):
8
  prompt_template = """
9
  You will be provided one or more top HakcerNews posts, followed by their URL.
10
  For the posts you have, provide a short summary followed by the URL that the post can be found at.
 
18
  """
19
 
20
  prompt_builder = PromptBuilder(template=prompt_template)
21
+ if model == "Mistral":
22
+ llm = HuggingFaceTGIGenerator("mistralai/Mixtral-8x7B-Instruct-v0.1", token=key)
23
+ elif model == "GPT-4":
24
+ llm = OpenAIGenerator(api_key=key, model="gpt-4")
25
  fetcher = HackernewsFetcher()
26
 
27
  pipe = Pipeline()
 
37
  @st.cache_data(show_spinner=True)
38
  def query(top_k, _pipeline):
39
  try:
40
+ run_args = {"hackernews_fetcher": {"top_k": top_k}}
41
+
42
+ if st.session_state.get("model") == "Mistral":
43
+ run_args = {"hackernews_fetcher": {"top_k": top_k},
44
+ "llm": {"generation_kwargs": {"max_new_tokens": 600}}
45
+ }
46
+
47
+ replies = _pipeline.run(data=run_args)
48
 
49
  result = replies['llm']['replies']
50
  except Exception as e:
utils/ui.py CHANGED
@@ -9,14 +9,18 @@ def set_initial_state():
9
  set_state_if_absent("top_k", "How many of the top posts would you like a summary for?")
10
  set_state_if_absent("result", None)
11
  set_state_if_absent("haystack_started", False)
 
12
 
13
  def reset_results(*args):
14
  st.session_state.result = None
15
  st.session_state.top_k = None
16
 
17
- def set_openai_api_key(api_key: str):
18
  st.session_state["HF_TGI_TOKEN"] = api_key
19
 
 
 
 
20
  def sidebar():
21
  with st.sidebar:
22
  # image = Image.open('logo/haystack-logo-colored.png')
@@ -33,16 +37,28 @@ def sidebar():
33
  "3. Enjoy πŸ€—\n"
34
  )
35
 
36
- api_key_input = st.text_input(
37
- "Hugging Face Token",
38
- type="password",
39
- placeholder="Paste your Hugging Face TGI Token",
40
- help="You can get your API key from https://platform.openai.com/account/api-keys.",
41
- value=st.session_state.get("HF_TGI_TOKEN", ""),
42
- )
43
-
44
- if api_key_input:
45
- set_openai_api_key(api_key_input)
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  st.markdown("---")
48
  st.markdown(
 
9
  set_state_if_absent("top_k", "How many of the top posts would you like a summary for?")
10
  set_state_if_absent("result", None)
11
  set_state_if_absent("haystack_started", False)
12
+ set_state_if_absent("model", None)
13
 
14
  def reset_results(*args):
15
  st.session_state.result = None
16
  st.session_state.top_k = None
17
 
18
+ def set_hf_token(api_key: str):
19
  st.session_state["HF_TGI_TOKEN"] = api_key
20
 
21
+ def set_openai_key(api_key: str):
22
+ st.session_state["OPENAI_API_KEY"] = api_key
23
+
24
  def sidebar():
25
  with st.sidebar:
26
  # image = Image.open('logo/haystack-logo-colored.png')
 
37
  "3. Enjoy πŸ€—\n"
38
  )
39
 
40
+ if st.session_state.model == "Mistral":
41
+ api_key_input = st.text_input(
42
+ "Hugging Face Token",
43
+ type="password",
44
+ placeholder="Paste your Hugging Face TGI Token",
45
+ help="You can get your API key from https://platform.openai.com/account/api-keys.",
46
+ value=st.session_state.get("HF_TGI_TOKEN", ""),
47
+ )
48
+ if api_key_input:
49
+ set_hf_token(api_key_input)
50
+
51
+ elif st.session_state.model == "GPT-4":
52
+ api_key_input = st.text_input(
53
+ "OpenAI API Key",
54
+ type="password",
55
+ placeholder="Paste your OpenAI API Key",
56
+ help="You can get your API key from https://platform.openai.com/account/api-keys.",
57
+ value=st.session_state.get("OPENAI_API_KEY", ""),
58
+ )
59
+ if api_key_input:
60
+ set_openai_key(api_key_input)
61
+
62
 
63
  st.markdown("---")
64
  st.markdown(