gabrielaltay committed on
Commit
c6e2641
1 Parent(s): df92eb1

add claude

Browse files
Files changed (2) hide show
  1. app.py +41 -13
  2. requirements.txt +21 -0
app.py CHANGED
@@ -17,6 +17,7 @@ from langchain_community.callbacks import get_openai_callback
17
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
18
  from langchain_community.vectorstores.utils import DistanceStrategy
19
  from langchain_openai import ChatOpenAI
 
20
  from langchain_pinecone import PineconeVectorStore
21
  from pinecone import Pinecone
22
  import streamlit as st
@@ -46,7 +47,12 @@ OPENAI_CHAT_MODELS = [
46
  "gpt-3.5-turbo-0125",
47
  "gpt-4-0125-preview",
48
  ]
49
-
 
 
 
 
 
50
 
51
  PREAMBLE = "You are an expert analyst. Use the following excerpts from US congressional legislation to respond to the user's query."
52
  PROMPT_TEMPLATES = {
@@ -354,10 +360,13 @@ def render_sidebar():
354
  st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
355
 
356
  with st.expander("Generative Config"):
357
- st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
358
  st.slider(
359
  "temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
360
  )
 
 
 
361
  st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")
362
 
363
  with st.expander("Retrieval Config"):
@@ -424,9 +433,12 @@ def render_query_tab():
424
  .assign(answer=prompt | llm | StrOutputParser())
425
  )
426
 
427
- with get_openai_callback() as cb:
 
 
 
 
428
  SS["out"] = rag_chain.invoke(SS["query"])
429
- SS["cb"] = cb
430
 
431
  if "out" in SS:
432
 
@@ -438,9 +450,11 @@ def render_query_tab():
438
  with st.container(border=True):
439
  st.write("Response")
440
  st.info(out_display)
441
- with st.container(border=True):
442
- st.write("API Usage")
443
- st.warning(SS["cb"])
 
 
444
 
445
  with st.container(border=True):
446
  doc_grps = group_docs(SS["out"]["docs"])
@@ -486,12 +500,26 @@ with st.sidebar:
486
  render_sidebar()
487
 
488
 
489
- llm = ChatOpenAI(
490
- model_name=SS["model_name"],
491
- temperature=SS["temperature"],
492
- openai_api_key=st.secrets["openai_api_key"],
493
- model_kwargs={"top_p": SS["top_p"], "seed": SEED},
494
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  vectorstore = load_pinecone_vectorstore()
496
  format_docs = DOC_FORMATTERS[SS["prompt_version"]]
497
 
 
17
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
18
  from langchain_community.vectorstores.utils import DistanceStrategy
19
  from langchain_openai import ChatOpenAI
20
+ from langchain_anthropic import ChatAnthropic
21
  from langchain_pinecone import PineconeVectorStore
22
  from pinecone import Pinecone
23
  import streamlit as st
 
47
  "gpt-3.5-turbo-0125",
48
  "gpt-4-0125-preview",
49
  ]
50
+ ANTHROPIC_CHAT_MODELS = [
51
+ "claude-3-opus-20240229",
52
+ "claude-3-sonnet-20240229",
53
+ # "claude-3-haiku-20240229",
54
+ ]
55
+ CHAT_MODELS = OPENAI_CHAT_MODELS + ANTHROPIC_CHAT_MODELS
56
 
57
  PREAMBLE = "You are an expert analyst. Use the following excerpts from US congressional legislation to respond to the user's query."
58
  PROMPT_TEMPLATES = {
 
360
  st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
361
 
362
  with st.expander("Generative Config"):
363
+ st.selectbox(label="model name", options=CHAT_MODELS, key="model_name")
364
  st.slider(
365
  "temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
366
  )
367
+ st.slider(
368
+ "max_output_tokens", min_value=512, max_value=1024, key="max_output_tokens"
369
+ )
370
  st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")
371
 
372
  with st.expander("Retrieval Config"):
 
433
  .assign(answer=prompt | llm | StrOutputParser())
434
  )
435
 
436
+ if SS["model_name"] in OPENAI_CHAT_MODELS:
437
+ with get_openai_callback() as cb:
438
+ SS["out"] = rag_chain.invoke(SS["query"])
439
+ SS["cb"] = cb
440
+ else:
441
  SS["out"] = rag_chain.invoke(SS["query"])
 
442
 
443
  if "out" in SS:
444
 
 
450
  with st.container(border=True):
451
  st.write("Response")
452
  st.info(out_display)
453
+
454
+ if SS["model_name"] in OPENAI_CHAT_MODELS:
455
+ with st.container(border=True):
456
+ st.write("API Usage")
457
+ st.warning(SS["cb"])
458
 
459
  with st.container(border=True):
460
  doc_grps = group_docs(SS["out"]["docs"])
 
500
  render_sidebar()
501
 
502
 
503
+ if SS["model_name"] in OPENAI_CHAT_MODELS:
504
+ llm = ChatOpenAI(
505
+ model_name=SS["model_name"],
506
+ temperature=SS["temperature"],
507
+ openai_api_key=st.secrets["openai_api_key"],
508
+ model_kwargs={"top_p": SS["top_p"], "seed": SEED},
509
+ max_tokens=SS["max_output_tokens"],
510
+ )
511
+ elif SS["model_name"] in ANTHROPIC_CHAT_MODELS:
512
+ llm = ChatAnthropic(
513
+ model_name=SS["model_name"],
514
+ temperature=SS["temperature"],
515
+ anthropic_api_key=st.secrets["anthropic_api_key"],
516
+ top_p=SS["top_p"],
517
+ max_tokens_to_sample=SS["max_output_tokens"],
518
+ )
519
+ else:
520
+ raise ValueError()
521
+
522
+
523
  vectorstore = load_pinecone_vectorstore()
524
  format_docs = DOC_FORMATTERS[SS["prompt_version"]]
525
 
requirements.txt CHANGED
@@ -2,17 +2,22 @@ aiohttp==3.9.3
2
  aiosignal==1.3.1
3
  altair==5.2.0
4
  annotated-types==0.6.0
 
5
  anyio==4.3.0
 
6
  async-timeout==4.0.3
7
  attrs==23.2.0
 
8
  blinker==1.7.0
9
  cachetools==5.3.2
10
  certifi==2024.2.2
11
  charset-normalizer==3.3.2
12
  click==8.1.7
13
  dataclasses-json==0.6.4
 
14
  distro==1.9.0
15
  exceptiongroup==1.2.0
 
16
  filelock==3.13.1
17
  frozenlist==1.4.1
18
  fsspec==2024.2.0
@@ -24,12 +29,15 @@ httpx==0.27.0
24
  huggingface-hub==0.20.3
25
  idna==3.6
26
  importlib-metadata==7.0.1
 
 
27
  Jinja2==3.1.3
28
  joblib==1.3.2
29
  jsonpatch==1.33
30
  jsonpointer==2.4
31
  jsonschema==4.21.1
32
  jsonschema-specifications==2023.12.1
 
33
  langchain-community==0.0.24
34
  langchain-core==0.1.26
35
  langchain-openai==0.0.7
@@ -38,6 +46,7 @@ langsmith==0.1.7
38
  markdown-it-py==3.0.0
39
  MarkupSafe==2.1.5
40
  marshmallow==3.20.2
 
41
  mdurl==0.1.2
42
  mpmath==1.3.0
43
  multidict==6.0.5
@@ -48,9 +57,16 @@ openai==1.12.0
48
  orjson==3.9.15
49
  packaging==23.2
50
  pandas==2.2.1
 
 
 
51
  pillow==10.2.0
52
  pinecone-client==3.1.0
 
 
53
  protobuf==4.25.3
 
 
54
  pyarrow==15.0.0
55
  pydantic==2.6.2
56
  pydantic_core==2.16.3
@@ -72,6 +88,7 @@ six==1.16.0
72
  smmap==5.0.1
73
  sniffio==1.3.0
74
  SQLAlchemy==2.0.27
 
75
  streamlit==1.31.1
76
  sympy==1.12
77
  tenacity==8.2.3
@@ -79,10 +96,12 @@ threadpoolctl==3.3.0
79
  tiktoken==0.6.0
80
  tokenizers==0.15.2
81
  toml==0.10.2
 
82
  toolz==0.12.1
83
  torch==2.2.1
84
  tornado==6.4
85
  tqdm==4.66.2
 
86
  transformers==4.38.1
87
  typing-inspect==0.9.0
88
  typing_extensions==4.9.0
@@ -90,5 +109,7 @@ tzdata==2024.1
90
  tzlocal==5.2
91
  urllib3==2.2.1
92
  validators==0.22.0
 
 
93
  yarl==1.9.4
94
  zipp==3.17.0
 
2
  aiosignal==1.3.1
3
  altair==5.2.0
4
  annotated-types==0.6.0
5
+ anthropic==0.18.1
6
  anyio==4.3.0
7
+ asttokens==2.4.1
8
  async-timeout==4.0.3
9
  attrs==23.2.0
10
+ black==24.2.0
11
  blinker==1.7.0
12
  cachetools==5.3.2
13
  certifi==2024.2.2
14
  charset-normalizer==3.3.2
15
  click==8.1.7
16
  dataclasses-json==0.6.4
17
+ decorator==5.1.1
18
  distro==1.9.0
19
  exceptiongroup==1.2.0
20
+ executing==2.0.1
21
  filelock==3.13.1
22
  frozenlist==1.4.1
23
  fsspec==2024.2.0
 
29
  huggingface-hub==0.20.3
30
  idna==3.6
31
  importlib-metadata==7.0.1
32
+ ipython==8.22.1
33
+ jedi==0.19.1
34
  Jinja2==3.1.3
35
  joblib==1.3.2
36
  jsonpatch==1.33
37
  jsonpointer==2.4
38
  jsonschema==4.21.1
39
  jsonschema-specifications==2023.12.1
40
+ langchain-anthropic==0.1.1
41
  langchain-community==0.0.24
42
  langchain-core==0.1.26
43
  langchain-openai==0.0.7
 
46
  markdown-it-py==3.0.0
47
  MarkupSafe==2.1.5
48
  marshmallow==3.20.2
49
+ matplotlib-inline==0.1.6
50
  mdurl==0.1.2
51
  mpmath==1.3.0
52
  multidict==6.0.5
 
57
  orjson==3.9.15
58
  packaging==23.2
59
  pandas==2.2.1
60
+ parso==0.8.3
61
+ pathspec==0.12.1
62
+ pexpect==4.9.0
63
  pillow==10.2.0
64
  pinecone-client==3.1.0
65
+ platformdirs==4.2.0
66
+ prompt-toolkit==3.0.43
67
  protobuf==4.25.3
68
+ ptyprocess==0.7.0
69
+ pure-eval==0.2.2
70
  pyarrow==15.0.0
71
  pydantic==2.6.2
72
  pydantic_core==2.16.3
 
88
  smmap==5.0.1
89
  sniffio==1.3.0
90
  SQLAlchemy==2.0.27
91
+ stack-data==0.6.3
92
  streamlit==1.31.1
93
  sympy==1.12
94
  tenacity==8.2.3
 
96
  tiktoken==0.6.0
97
  tokenizers==0.15.2
98
  toml==0.10.2
99
+ tomli==2.0.1
100
  toolz==0.12.1
101
  torch==2.2.1
102
  tornado==6.4
103
  tqdm==4.66.2
104
+ traitlets==5.14.1
105
  transformers==4.38.1
106
  typing-inspect==0.9.0
107
  typing_extensions==4.9.0
 
109
  tzlocal==5.2
110
  urllib3==2.2.1
111
  validators==0.22.0
112
+ watchdog==4.0.0
113
+ wcwidth==0.2.13
114
  yarl==1.9.4
115
  zipp==3.17.0