gabrielaltay committed on
Commit
8a92b0a
1 Parent(s): ee3f9c2

nvidia nim update

Browse files
Files changed (2) hide show
  1. app.py +30 -7
  2. requirements.txt +9 -8
app.py CHANGED
@@ -28,6 +28,7 @@ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
28
  from langchain_community.vectorstores.utils import DistanceStrategy
29
  from langchain_openai import ChatOpenAI
30
  from langchain_anthropic import ChatAnthropic
 
31
  from langchain_pinecone import PineconeVectorStore
32
  from pinecone import Pinecone
33
  import streamlit as st
@@ -57,7 +58,6 @@ CONGRESS_GOV_TYPE_MAP = {
57
  }
58
  OPENAI_CHAT_MODELS = [
59
  "gpt-3.5-turbo-0125",
60
- # "gpt-4-0125-preview",
61
  "gpt-4o",
62
  ]
63
  ANTHROPIC_CHAT_MODELS = [
@@ -65,7 +65,20 @@ ANTHROPIC_CHAT_MODELS = [
65
  "claude-3-sonnet-20240229",
66
  "claude-3-haiku-20240307",
67
  ]
68
- CHAT_MODELS = OPENAI_CHAT_MODELS + ANTHROPIC_CHAT_MODELS
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
  def get_sponsor_url(bioguide_id: str) -> str:
@@ -278,9 +291,10 @@ def render_sidebar():
278
  st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
279
 
280
  with st.expander("Generative Config"):
281
- st.selectbox(label="model name", options=CHAT_MODELS, key="model_name")
 
282
  st.slider(
283
- "temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
284
  )
285
  st.slider(
286
  "max_output_tokens", min_value=512, max_value=1024, key="max_output_tokens"
@@ -315,7 +329,7 @@ def render_query_rag_tab():
315
 
316
  render_example_queries()
317
 
318
- QUERY_TEMPLATE = """Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
319
 
320
  ---
321
 
@@ -329,7 +343,6 @@ Query: {query}"""
329
 
330
  prompt = ChatPromptTemplate.from_messages(
331
  [
332
- ("system", "You are an expert legislative analyst."),
333
  ("human", QUERY_TEMPLATE),
334
  ]
335
  )
@@ -398,7 +411,8 @@ def render_query_agent_tab():
398
 
399
  from langchain_community.tools import WikipediaQueryRun
400
  from langchain_community.utilities import WikipediaAPIWrapper
401
- from langchain.agents import load_tools
 
402
  from langchain.agents import create_react_agent
403
  from langchain import hub
404
 
@@ -497,6 +511,15 @@ elif SS["model_name"] in ANTHROPIC_CHAT_MODELS:
497
  top_p=SS["top_p"],
498
  max_tokens_to_sample=SS["max_output_tokens"],
499
  )
 
 
 
 
 
 
 
 
 
500
  else:
501
  raise ValueError()
502
 
 
28
  from langchain_community.vectorstores.utils import DistanceStrategy
29
  from langchain_openai import ChatOpenAI
30
  from langchain_anthropic import ChatAnthropic
31
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA
32
  from langchain_pinecone import PineconeVectorStore
33
  from pinecone import Pinecone
34
  import streamlit as st
 
58
  }
59
  OPENAI_CHAT_MODELS = [
60
  "gpt-3.5-turbo-0125",
 
61
  "gpt-4o",
62
  ]
63
  ANTHROPIC_CHAT_MODELS = [
 
65
  "claude-3-sonnet-20240229",
66
  "claude-3-haiku-20240307",
67
  ]
68
+ NVIDIA_NIM_CHAT_MODELS = [
69
+ "microsoft/phi-3-mini-128k-instruct",
70
+ "google/gemma-7b",
71
+ "meta/llama3-8b-instruct",
72
+ "meta/llama3-70b-instruct",
73
+ "mistralai/mixtral-8x22b-instruct-v0.1",
74
+ ]
75
+ CHAT_MODELS = OPENAI_CHAT_MODELS + ANTHROPIC_CHAT_MODELS + NVIDIA_NIM_CHAT_MODELS
76
+
77
+ PROVIDER_MODELS = {
78
+ "OpenAI": OPENAI_CHAT_MODELS,
79
+ "Anthropic": ANTHROPIC_CHAT_MODELS,
80
+ "Nvidia NIM": NVIDIA_NIM_CHAT_MODELS,
81
+ }
82
 
83
 
84
  def get_sponsor_url(bioguide_id: str) -> str:
 
291
  st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
292
 
293
  with st.expander("Generative Config"):
294
+ st.selectbox(label="provider", options=PROVIDER_MODELS.keys(), key="provider")
295
+ st.selectbox(label="model name", options=PROVIDER_MODELS[SS["provider"]], key="model_name")
296
  st.slider(
297
+ "temperature", min_value=0.0, max_value=2.0, value=0.01, key="temperature"
298
  )
299
  st.slider(
300
  "max_output_tokens", min_value=512, max_value=1024, key="max_output_tokens"
 
329
 
330
  render_example_queries()
331
 
332
+ QUERY_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
333
 
334
  ---
335
 
 
343
 
344
  prompt = ChatPromptTemplate.from_messages(
345
  [
 
346
  ("human", QUERY_TEMPLATE),
347
  ]
348
  )
 
411
 
412
  from langchain_community.tools import WikipediaQueryRun
413
  from langchain_community.utilities import WikipediaAPIWrapper
414
+ # from langchain.agents import load_tools
415
+ from langchain_community.agent_toolkits.load_tools import load_tools
416
  from langchain.agents import create_react_agent
417
  from langchain import hub
418
 
 
511
  top_p=SS["top_p"],
512
  max_tokens_to_sample=SS["max_output_tokens"],
513
  )
514
+ elif SS["model_name"] in NVIDIA_NIM_CHAT_MODELS:
515
+ llm = ChatNVIDIA(
516
+ model=SS["model_name"],
517
+ temperature=SS["temperature"],
518
+ max_tokens=SS["max_output_tokens"],
519
+ top_p=SS["top_p"],
520
+ seed=SEED,
521
+ nvidia_api_key=st.secrets["nvidia_api_key"],
522
+ )
523
  else:
524
  raise ValueError()
525
 
requirements.txt CHANGED
@@ -41,15 +41,16 @@ jsonpatch==1.33
41
  jsonpointer==2.4
42
  jsonschema==4.21.1
43
  jsonschema-specifications==2023.12.1
44
- langchain==0.1.13
45
  langchain-anthropic==0.1.1
46
- langchain-community==0.0.29
47
- langchain-core==0.1.36
48
- langchain-openai==0.0.7
 
49
  langchain-pinecone==0.0.3
50
- langchain-text-splitters==0.0.1
51
  langchainhub==0.1.15
52
- langsmith==0.1.38
53
  markdown-it-py==3.0.0
54
  MarkupSafe==2.1.5
55
  marshmallow==3.20.2
@@ -60,7 +61,7 @@ multidict==6.0.5
60
  mypy-extensions==1.0.0
61
  networkx==3.2.1
62
  numpy==1.26.4
63
- openai==1.12.0
64
  orjson==3.10.0
65
  packaging==23.2
66
  pandas==2.2.1
@@ -102,7 +103,7 @@ streamlit==1.31.1
102
  sympy==1.12
103
  tenacity==8.2.3
104
  threadpoolctl==3.3.0
105
- tiktoken==0.6.0
106
  tokenizers==0.15.2
107
  toml==0.10.2
108
  tomli==2.0.1
 
41
  jsonpointer==2.4
42
  jsonschema==4.21.1
43
  jsonschema-specifications==2023.12.1
44
+ langchain==0.2.5
45
  langchain-anthropic==0.1.1
46
+ langchain-community==0.2.5
47
+ langchain-core==0.2.7
48
+ langchain-nvidia-ai-endpoints==0.1.2
49
+ langchain-openai==0.1.8
50
  langchain-pinecone==0.0.3
51
+ langchain-text-splitters==0.2.1
52
  langchainhub==0.1.15
53
+ langsmith==0.1.77
54
  markdown-it-py==3.0.0
55
  MarkupSafe==2.1.5
56
  marshmallow==3.20.2
 
61
  mypy-extensions==1.0.0
62
  networkx==3.2.1
63
  numpy==1.26.4
64
+ openai==1.34.0
65
  orjson==3.10.0
66
  packaging==23.2
67
  pandas==2.2.1
 
103
  sympy==1.12
104
  tenacity==8.2.3
105
  threadpoolctl==3.3.0
106
+ tiktoken==0.7.0
107
  tokenizers==0.15.2
108
  toml==0.10.2
109
  tomli==2.0.1