Spaces:
Running
Running
gabrielaltay
commited on
Commit
•
c6e2641
1
Parent(s):
df92eb1
add claude
Browse files- app.py +41 -13
- requirements.txt +21 -0
app.py
CHANGED
@@ -17,6 +17,7 @@ from langchain_community.callbacks import get_openai_callback
|
|
17 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
18 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
19 |
from langchain_openai import ChatOpenAI
|
|
|
20 |
from langchain_pinecone import PineconeVectorStore
|
21 |
from pinecone import Pinecone
|
22 |
import streamlit as st
|
@@ -46,7 +47,12 @@ OPENAI_CHAT_MODELS = [
|
|
46 |
"gpt-3.5-turbo-0125",
|
47 |
"gpt-4-0125-preview",
|
48 |
]
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
PREAMBLE = "You are an expert analyst. Use the following excerpts from US congressional legislation to respond to the user's query."
|
52 |
PROMPT_TEMPLATES = {
|
@@ -354,10 +360,13 @@ def render_sidebar():
|
|
354 |
st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
|
355 |
|
356 |
with st.expander("Generative Config"):
|
357 |
-
st.selectbox(label="model name", options=
|
358 |
st.slider(
|
359 |
"temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
|
360 |
)
|
|
|
|
|
|
|
361 |
st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")
|
362 |
|
363 |
with st.expander("Retrieval Config"):
|
@@ -424,9 +433,12 @@ def render_query_tab():
|
|
424 |
.assign(answer=prompt | llm | StrOutputParser())
|
425 |
)
|
426 |
|
427 |
-
|
|
|
|
|
|
|
|
|
428 |
SS["out"] = rag_chain.invoke(SS["query"])
|
429 |
-
SS["cb"] = cb
|
430 |
|
431 |
if "out" in SS:
|
432 |
|
@@ -438,9 +450,11 @@ def render_query_tab():
|
|
438 |
with st.container(border=True):
|
439 |
st.write("Response")
|
440 |
st.info(out_display)
|
441 |
-
|
442 |
-
|
443 |
-
st.
|
|
|
|
|
444 |
|
445 |
with st.container(border=True):
|
446 |
doc_grps = group_docs(SS["out"]["docs"])
|
@@ -486,12 +500,26 @@ with st.sidebar:
|
|
486 |
render_sidebar()
|
487 |
|
488 |
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
vectorstore = load_pinecone_vectorstore()
|
496 |
format_docs = DOC_FORMATTERS[SS["prompt_version"]]
|
497 |
|
|
|
17 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
18 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
19 |
from langchain_openai import ChatOpenAI
|
20 |
+
from langchain_anthropic import ChatAnthropic
|
21 |
from langchain_pinecone import PineconeVectorStore
|
22 |
from pinecone import Pinecone
|
23 |
import streamlit as st
|
|
|
47 |
"gpt-3.5-turbo-0125",
|
48 |
"gpt-4-0125-preview",
|
49 |
]
|
50 |
+
ANTHROPIC_CHAT_MODELS = [
|
51 |
+
"claude-3-opus-20240229",
|
52 |
+
"claude-3-sonnet-20240229",
|
53 |
+
# "claude-3-haiku-20240229",
|
54 |
+
]
|
55 |
+
CHAT_MODELS = OPENAI_CHAT_MODELS + ANTHROPIC_CHAT_MODELS
|
56 |
|
57 |
PREAMBLE = "You are an expert analyst. Use the following excerpts from US congressional legislation to respond to the user's query."
|
58 |
PROMPT_TEMPLATES = {
|
|
|
360 |
st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
|
361 |
|
362 |
with st.expander("Generative Config"):
|
363 |
+
st.selectbox(label="model name", options=CHAT_MODELS, key="model_name")
|
364 |
st.slider(
|
365 |
"temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
|
366 |
)
|
367 |
+
st.slider(
|
368 |
+
"max_output_tokens", min_value=512, max_value=1024, key="max_output_tokens"
|
369 |
+
)
|
370 |
st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")
|
371 |
|
372 |
with st.expander("Retrieval Config"):
|
|
|
433 |
.assign(answer=prompt | llm | StrOutputParser())
|
434 |
)
|
435 |
|
436 |
+
if SS["model_name"] in OPENAI_CHAT_MODELS:
|
437 |
+
with get_openai_callback() as cb:
|
438 |
+
SS["out"] = rag_chain.invoke(SS["query"])
|
439 |
+
SS["cb"] = cb
|
440 |
+
else:
|
441 |
SS["out"] = rag_chain.invoke(SS["query"])
|
|
|
442 |
|
443 |
if "out" in SS:
|
444 |
|
|
|
450 |
with st.container(border=True):
|
451 |
st.write("Response")
|
452 |
st.info(out_display)
|
453 |
+
|
454 |
+
if SS["model_name"] in OPENAI_CHAT_MODELS:
|
455 |
+
with st.container(border=True):
|
456 |
+
st.write("API Usage")
|
457 |
+
st.warning(SS["cb"])
|
458 |
|
459 |
with st.container(border=True):
|
460 |
doc_grps = group_docs(SS["out"]["docs"])
|
|
|
500 |
render_sidebar()
|
501 |
|
502 |
|
503 |
+
if SS["model_name"] in OPENAI_CHAT_MODELS:
|
504 |
+
llm = ChatOpenAI(
|
505 |
+
model_name=SS["model_name"],
|
506 |
+
temperature=SS["temperature"],
|
507 |
+
openai_api_key=st.secrets["openai_api_key"],
|
508 |
+
model_kwargs={"top_p": SS["top_p"], "seed": SEED},
|
509 |
+
max_tokens=SS["max_output_tokens"],
|
510 |
+
)
|
511 |
+
elif SS["model_name"] in ANTHROPIC_CHAT_MODELS:
|
512 |
+
llm = ChatAnthropic(
|
513 |
+
model_name=SS["model_name"],
|
514 |
+
temperature=SS["temperature"],
|
515 |
+
anthropic_api_key=st.secrets["anthropic_api_key"],
|
516 |
+
top_p=SS["top_p"],
|
517 |
+
max_tokens_to_sample=SS["max_output_tokens"],
|
518 |
+
)
|
519 |
+
else:
|
520 |
+
raise ValueError()
|
521 |
+
|
522 |
+
|
523 |
vectorstore = load_pinecone_vectorstore()
|
524 |
format_docs = DOC_FORMATTERS[SS["prompt_version"]]
|
525 |
|
requirements.txt
CHANGED
@@ -2,17 +2,22 @@ aiohttp==3.9.3
|
|
2 |
aiosignal==1.3.1
|
3 |
altair==5.2.0
|
4 |
annotated-types==0.6.0
|
|
|
5 |
anyio==4.3.0
|
|
|
6 |
async-timeout==4.0.3
|
7 |
attrs==23.2.0
|
|
|
8 |
blinker==1.7.0
|
9 |
cachetools==5.3.2
|
10 |
certifi==2024.2.2
|
11 |
charset-normalizer==3.3.2
|
12 |
click==8.1.7
|
13 |
dataclasses-json==0.6.4
|
|
|
14 |
distro==1.9.0
|
15 |
exceptiongroup==1.2.0
|
|
|
16 |
filelock==3.13.1
|
17 |
frozenlist==1.4.1
|
18 |
fsspec==2024.2.0
|
@@ -24,12 +29,15 @@ httpx==0.27.0
|
|
24 |
huggingface-hub==0.20.3
|
25 |
idna==3.6
|
26 |
importlib-metadata==7.0.1
|
|
|
|
|
27 |
Jinja2==3.1.3
|
28 |
joblib==1.3.2
|
29 |
jsonpatch==1.33
|
30 |
jsonpointer==2.4
|
31 |
jsonschema==4.21.1
|
32 |
jsonschema-specifications==2023.12.1
|
|
|
33 |
langchain-community==0.0.24
|
34 |
langchain-core==0.1.26
|
35 |
langchain-openai==0.0.7
|
@@ -38,6 +46,7 @@ langsmith==0.1.7
|
|
38 |
markdown-it-py==3.0.0
|
39 |
MarkupSafe==2.1.5
|
40 |
marshmallow==3.20.2
|
|
|
41 |
mdurl==0.1.2
|
42 |
mpmath==1.3.0
|
43 |
multidict==6.0.5
|
@@ -48,9 +57,16 @@ openai==1.12.0
|
|
48 |
orjson==3.9.15
|
49 |
packaging==23.2
|
50 |
pandas==2.2.1
|
|
|
|
|
|
|
51 |
pillow==10.2.0
|
52 |
pinecone-client==3.1.0
|
|
|
|
|
53 |
protobuf==4.25.3
|
|
|
|
|
54 |
pyarrow==15.0.0
|
55 |
pydantic==2.6.2
|
56 |
pydantic_core==2.16.3
|
@@ -72,6 +88,7 @@ six==1.16.0
|
|
72 |
smmap==5.0.1
|
73 |
sniffio==1.3.0
|
74 |
SQLAlchemy==2.0.27
|
|
|
75 |
streamlit==1.31.1
|
76 |
sympy==1.12
|
77 |
tenacity==8.2.3
|
@@ -79,10 +96,12 @@ threadpoolctl==3.3.0
|
|
79 |
tiktoken==0.6.0
|
80 |
tokenizers==0.15.2
|
81 |
toml==0.10.2
|
|
|
82 |
toolz==0.12.1
|
83 |
torch==2.2.1
|
84 |
tornado==6.4
|
85 |
tqdm==4.66.2
|
|
|
86 |
transformers==4.38.1
|
87 |
typing-inspect==0.9.0
|
88 |
typing_extensions==4.9.0
|
@@ -90,5 +109,7 @@ tzdata==2024.1
|
|
90 |
tzlocal==5.2
|
91 |
urllib3==2.2.1
|
92 |
validators==0.22.0
|
|
|
|
|
93 |
yarl==1.9.4
|
94 |
zipp==3.17.0
|
|
|
2 |
aiosignal==1.3.1
|
3 |
altair==5.2.0
|
4 |
annotated-types==0.6.0
|
5 |
+
anthropic==0.18.1
|
6 |
anyio==4.3.0
|
7 |
+
asttokens==2.4.1
|
8 |
async-timeout==4.0.3
|
9 |
attrs==23.2.0
|
10 |
+
black==24.2.0
|
11 |
blinker==1.7.0
|
12 |
cachetools==5.3.2
|
13 |
certifi==2024.2.2
|
14 |
charset-normalizer==3.3.2
|
15 |
click==8.1.7
|
16 |
dataclasses-json==0.6.4
|
17 |
+
decorator==5.1.1
|
18 |
distro==1.9.0
|
19 |
exceptiongroup==1.2.0
|
20 |
+
executing==2.0.1
|
21 |
filelock==3.13.1
|
22 |
frozenlist==1.4.1
|
23 |
fsspec==2024.2.0
|
|
|
29 |
huggingface-hub==0.20.3
|
30 |
idna==3.6
|
31 |
importlib-metadata==7.0.1
|
32 |
+
ipython==8.22.1
|
33 |
+
jedi==0.19.1
|
34 |
Jinja2==3.1.3
|
35 |
joblib==1.3.2
|
36 |
jsonpatch==1.33
|
37 |
jsonpointer==2.4
|
38 |
jsonschema==4.21.1
|
39 |
jsonschema-specifications==2023.12.1
|
40 |
+
langchain-anthropic==0.1.1
|
41 |
langchain-community==0.0.24
|
42 |
langchain-core==0.1.26
|
43 |
langchain-openai==0.0.7
|
|
|
46 |
markdown-it-py==3.0.0
|
47 |
MarkupSafe==2.1.5
|
48 |
marshmallow==3.20.2
|
49 |
+
matplotlib-inline==0.1.6
|
50 |
mdurl==0.1.2
|
51 |
mpmath==1.3.0
|
52 |
multidict==6.0.5
|
|
|
57 |
orjson==3.9.15
|
58 |
packaging==23.2
|
59 |
pandas==2.2.1
|
60 |
+
parso==0.8.3
|
61 |
+
pathspec==0.12.1
|
62 |
+
pexpect==4.9.0
|
63 |
pillow==10.2.0
|
64 |
pinecone-client==3.1.0
|
65 |
+
platformdirs==4.2.0
|
66 |
+
prompt-toolkit==3.0.43
|
67 |
protobuf==4.25.3
|
68 |
+
ptyprocess==0.7.0
|
69 |
+
pure-eval==0.2.2
|
70 |
pyarrow==15.0.0
|
71 |
pydantic==2.6.2
|
72 |
pydantic_core==2.16.3
|
|
|
88 |
smmap==5.0.1
|
89 |
sniffio==1.3.0
|
90 |
SQLAlchemy==2.0.27
|
91 |
+
stack-data==0.6.3
|
92 |
streamlit==1.31.1
|
93 |
sympy==1.12
|
94 |
tenacity==8.2.3
|
|
|
96 |
tiktoken==0.6.0
|
97 |
tokenizers==0.15.2
|
98 |
toml==0.10.2
|
99 |
+
tomli==2.0.1
|
100 |
toolz==0.12.1
|
101 |
torch==2.2.1
|
102 |
tornado==6.4
|
103 |
tqdm==4.66.2
|
104 |
+
traitlets==5.14.1
|
105 |
transformers==4.38.1
|
106 |
typing-inspect==0.9.0
|
107 |
typing_extensions==4.9.0
|
|
|
109 |
tzlocal==5.2
|
110 |
urllib3==2.2.1
|
111 |
validators==0.22.0
|
112 |
+
watchdog==4.0.0
|
113 |
+
wcwidth==0.2.13
|
114 |
yarl==1.9.4
|
115 |
zipp==3.17.0
|