Upload 8 files
- .gitattributes +1 -0
- app.py +28 -0
- env.py +9 -0
- output.csv +0 -0
- output.db +3 -0
- ques_1.py +27 -0
- ques_2.py +51 -0
- requirements.txt +85 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+output.db filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,28 @@
+import gradio as gr
+from ques_1 import driver1
+# from ques_2 import driver2
+
+q1 = gr.Interface(
+    fn=driver1,
+    inputs=gr.Textbox(
+        label="Enter your industry to know more about the customers that use salesforce: ",
+        info="Food, Health care, Sports, etc...",
+        lines=1,
+    ),
+    outputs="text",
+)
+
+q2 = gr.Interface(
+    fn=driver1,
+    inputs=gr.Textbox(
+        label="Enter your customer to know more about how they leverage salesforce",
+        info="Williams-Sonoma Inc., ReserveBar, Christy Sports, etc...",
+        lines=1,
+    ),
+    outputs="text",
+)
+
+app = gr.TabbedInterface([q1, q2], ["Industry based Customers", "Customer stories leveraging Salesforce"])
+
+if __name__ == "__main__":
+    app.launch()
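
Note: both tabs above are wired to driver1, and the `from ques_2 import driver2` line is commented out, so the "Customer stories" tab never reaches the RAG pipeline in ques_2.py. A minimal sketch of the presumably intended wiring, assuming driver2 imports cleanly once its dependencies are installed (this is an editor's sketch, not part of the upload):

import gradio as gr

from ques_1 import driver1
from ques_2 import driver2  # assumes ques_2's dependencies are available

# Tab 1: SQL-agent lookup of customers by industry
q1 = gr.Interface(
    fn=driver1,
    inputs=gr.Textbox(label="Industry", info="Food, Health care, Sports, etc...", lines=1),
    outputs="text",
)

# Tab 2: RAG lookup of a single customer's story, wired to driver2 rather than driver1
q2 = gr.Interface(
    fn=driver2,
    inputs=gr.Textbox(label="Customer", info="Williams-Sonoma Inc., ReserveBar, etc...", lines=1),
    outputs="text",
)

app = gr.TabbedInterface([q1, q2], ["Industry based Customers", "Customer stories leveraging Salesforce"])

if __name__ == "__main__":
    app.launch()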
env.py
ADDED
@@ -0,0 +1,9 @@
+OPENAI_API_KEY = "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"
+# snowflake_account = "JGIYHFR.BW66671"
+# snowflake_account = "jgiyhfr-bw66671"
+# username = "srini047"
+# password = "Football1234%"
+# database = "salesforce-scraped"
+# schema = "public"
+# warehouse = "compute_wh"
+# role = "accountadmin"
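
The OpenAI key and the commented-out Snowflake credentials above are committed in plain text. A possible alternative sketch (an editor's assumption, not part of the upload) reads them from the process environment, e.g. Hugging Face Space secrets; the SNOWFLAKE_* variable names are illustrative:

import os

# Fail fast if the OpenAI secret is not configured
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]

# Snowflake settings are currently unused, so default to empty strings
SNOWFLAKE_ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT", "")
SNOWFLAKE_USERNAME = os.getenv("SNOWFLAKE_USERNAME", "")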
output.csv
ADDED
The diff for this file is too large to render.
See raw diff
output.db
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:701848626db7db581c8cee7699f7e813143fe100ddbdd49dd10b666008bdd14a
+size 1265664
ques_1.py
ADDED
@@ -0,0 +1,27 @@
+from langchain import OpenAI, SQLDatabase
+from langchain.agents.agent_toolkits import SQLDatabaseToolkit
+from langchain.agents import create_sql_agent
+from langchain.agents import AgentExecutor
+from langchain.agents.agent_types import AgentType
+from env import OPENAI_API_KEY
+
+dburi = "sqlite:///output.db"
+db = SQLDatabase.from_uri(dburi)
+llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
+
+toolkit = SQLDatabaseToolkit(db=db, llm=llm)
+agent_executor = create_sql_agent(
+    llm=llm,
+    toolkit=toolkit,
+    verbose=True,
+    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+)
+
+
+def driver1(industry_type):
+    response = agent_executor.run(
+        "What are the customers in"
+        + industry_type
+        + " industry that chose salesforce? Give the answer in form of bullet points ->"
+    )
+    return response
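
The concatenated prompt above drops the space after "in", producing e.g. "customers inFood industry". A hedged variant of driver1 that builds the question with an f-string instead (same agent_executor as defined above):

def driver1(industry_type: str) -> str:
    # Keep spaces around the industry name when building the question
    prompt = (
        f"What are the customers in the {industry_type} industry that chose "
        "salesforce? Give the answer in form of bullet points ->"
    )
    return agent_executor.run(prompt)

# Example call (assumes output.db and the OpenAI key are in place):
# print(driver1("Health care"))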
ques_2.py
ADDED
@@ -0,0 +1,51 @@
+from langchain.document_loaders.csv_loader import CSVLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.chat_models import ChatOpenAI
+from langchain.schema.runnable import RunnablePassthrough
+from langchain.prompts import PromptTemplate
+from langchain import hub
+from env import OPENAI_API_KEY
+
+def main():
+    loader = CSVLoader(file_path="output.csv")
+    data = loader.load()
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
+    splits = text_splitter.split_documents(data)
+    vectorstore = Chroma.from_documents(
+        documents=splits,
+        embedding=OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY),
+    )
+    retriever = vectorstore.as_retriever()
+    rag_prompt = hub.pull("rlm/rag-prompt")
+    llm = ChatOpenAI(
+        model_name="gpt-3.5-turbo",
+        temperature=0,
+        openai_api_key=OPENAI_API_KEY,
+    )
+    template = """Use the following pieces of context to answer the question at the end.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+Use three sentences maximum and keep the answer as concise as possible.
+{context}
+Question: {question}
+Helpful Answer:"""
+    rag_prompt_custom = PromptTemplate.from_template(template)
+
+    rag_chain = (
+        {"context": retriever, "question": RunnablePassthrough()}
+        | rag_prompt_custom
+        | llm
+    )
+    return rag_chain
+
+
+def driver2(customer_name):
+    rag_chain = main()
+    response = rag_chain.invoke(
+        "Can you tell me more about customer"
+        + customer_name
+        + " and how they benefited from salesforce?"
+    )
+    print(response)
+    return response.content
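
Two things worth noting in the file above: rag_prompt is pulled from the hub but never used (the chain uses rag_prompt_custom), and driver2 calls main() on every request, so output.csv is re-embedded into a fresh Chroma store per question. A hedged refinement that builds the chain once and keeps the spaces around the customer name; the names _rag_chain and get_chain are illustrative, not part of the upload:

_rag_chain = None

def get_chain():
    # Build the RAG chain lazily and reuse it across calls
    global _rag_chain
    if _rag_chain is None:
        _rag_chain = main()  # main() as defined above
    return _rag_chain

def driver2(customer_name: str) -> str:
    question = (
        f"Can you tell me more about customer {customer_name} "
        "and how they benefited from salesforce?"
    )
    response = get_chain().invoke(question)
    return response.content  # ChatOpenAI returns an AIMessage; .content is the text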
requirements.txt
ADDED
@@ -0,0 +1,85 @@
+aiofiles==23.2.1
+aiohttp==3.9.1
+aiosignal==1.3.1
+altair==5.2.0
+annotated-types==0.6.0
+anyio==3.7.1
+async-timeout==4.0.3
+attrs==23.1.0
+certifi==2023.11.17
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+contourpy==1.2.0
+cycler==0.12.1
+dataclasses-json==0.6.3
+distro==1.8.0
+exceptiongroup==1.2.0
+fastapi==0.104.1
+ffmpy==0.3.1
+filelock==3.13.1
+fonttools==4.45.1
+frozenlist==1.4.0
+fsspec==2023.10.0
+gradio==4.7.1
+gradio_client==0.7.0
+h11==0.14.0
+httpcore==1.0.2
+httpx==0.25.2
+huggingface-hub==0.19.4
+idna==3.6
+importlib-resources==6.1.1
+Jinja2==3.1.2
+jsonpatch==1.33
+jsonpointer==2.4
+jsonschema==4.20.0
+jsonschema-specifications==2023.11.2
+kiwisolver==1.4.5
+langchain==0.0.344
+langchain-core==0.0.8
+langsmith==0.0.68
+markdown-it-py==3.0.0
+MarkupSafe==2.1.3
+marshmallow==3.20.1
+matplotlib==3.8.2
+mdurl==0.1.2
+multidict==6.0.4
+mypy-extensions==1.0.0
+numpy==1.26.2
+openai==1.3.7
+orjson==3.9.10
+packaging==23.2
+pandas==2.1.3
+Pillow==10.1.0
+pydantic==2.5.2
+pydantic_core==2.14.5
+pydub==0.25.1
+Pygments==2.17.2
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-multipart==0.0.6
+pytz==2023.3.post1
+PyYAML==6.0.1
+referencing==0.31.1
+requests==2.31.0
+rich==13.7.0
+rpds-py==0.13.2
+semantic-version==2.10.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.0
+SQLAlchemy==2.0.23
+starlette==0.27.0
+tenacity==8.2.3
+tomlkit==0.12.0
+toolz==0.12.0
+tqdm==4.66.1
+typer==0.9.0
+typing-inspect==0.9.0
+typing_extensions==4.8.0
+tzdata==2023.3
+urllib3==2.1.0
+uvicorn==0.24.0.post1
+websockets==11.0.3
+yarl==1.9.3
+zipp==3.17.0
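
One gap worth flagging: ques_2.py imports Chroma and calls hub.pull, but neither chromadb nor langchainhub appears in the pins above. If the second question is enabled, additions along these lines would likely be needed (left unpinned because no versions are given in the upload):

# assumed additions, not part of the upload
chromadb
langchainhub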