srini047 committed
Commit fde297f
1 Parent(s): 1093a4a

Upload 8 files

Files changed (8)
  1. .gitattributes +1 -0
  2. app.py +28 -0
  3. env.py +9 -0
  4. output.csv +0 -0
  5. output.db +3 -0
  6. ques_1.py +27 -0
  7. ques_2.py +51 -0
  8. requirements.txt +85 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ output.db filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,28 @@
+ import gradio as gr
+ from ques_1 import driver1
+ from ques_2 import driver2
+
+ q1 = gr.Interface(
+     fn=driver1,
+     inputs=gr.Textbox(
+         label="Enter your industry to know more about the customers that use salesforce: ",
+         info="Food, Health care, Sports, etc...",
+         lines=1,
+     ),
+     outputs="text",
+ )
+
+ q2 = gr.Interface(
+     fn=driver2,
+     inputs=gr.Textbox(
+         label="Enter your customer to know more about how they leverage salesforce",
+         info="Williams-Sonoma Inc., ReserveBar, Christy Sports, etc...",
+         lines=1,
+     ),
+     outputs="text",
+ )
+
+ app = gr.TabbedInterface([q1, q2], ["Industry based Customers", "Customer stories leveraging Salesforce"])
+
+ if __name__ == "__main__":
+     app.launch()
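Usage note: running app.py directly (python app.py) launches the two-tab app on Gradio's default local address, http://127.0.0.1:7860; gr.TabbedInterface pairs q1 and q2 with the two tab titles in order.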
env.py ADDED
@@ -0,0 +1,9 @@
+ OPENAI_API_KEY = "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"
+ # snowflake_account = "JGIYHFR.BW66671"
+ # snowflake_account = "jgiyhfr-bw66671"
+ # username = "srini047"
+ # password = "Football1234%"
+ # database = "salesforce-scraped"
+ # schema = "public"
+ # warehouse = "compute_wh"
+ # role = "accountadmin"
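env.py stores the OpenAI key as a hardcoded constant in the repository. A minimal alternative sketch, assuming the key is instead exported as an OPENAI_API_KEY environment variable before the app is launched:

    import os

    # Read the key from the process environment rather than committing it;
    # raises KeyError at import time if the variable is not set.
    OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]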
output.csv ADDED
The diff for this file is too large to render. See raw diff
 
output.db ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:701848626db7db581c8cee7699f7e813143fe100ddbdd49dd10b666008bdd14a
+ size 1265664
ques_1.py ADDED
@@ -0,0 +1,27 @@
+ from langchain import OpenAI, SQLDatabase
+ from langchain.agents.agent_toolkits import SQLDatabaseToolkit
+ from langchain.agents import create_sql_agent
+ from langchain.agents import AgentExecutor
+ from langchain.agents.agent_types import AgentType
+ from env import OPENAI_API_KEY
+
+ dburi = "sqlite:///output.db"
+ db = SQLDatabase.from_uri(dburi)
+ llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
+
+ toolkit = SQLDatabaseToolkit(db=db, llm=llm)
+ agent_executor = create_sql_agent(
+     llm=llm,
+     toolkit=toolkit,
+     verbose=True,
+     agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+ )
+
+
+ def driver1(industry_type):
+     response = agent_executor.run(
+         "What are the customers in "
+         + industry_type
+         + " industry that chose salesforce? Give the answer in form of bullet points ->"
+     )
+     return response
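For a quick check outside the Gradio UI, driver1 can be called directly; a sketch, where "Healthcare" is only an illustrative input and the answer depends on whatever tables output.db actually contains:

    from ques_1 import driver1

    print(driver1("Healthcare"))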
ques_2.py ADDED
@@ -0,0 +1,51 @@
+ from langchain.document_loaders.csv_loader import CSVLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import OpenAIEmbeddings
+ from langchain.chat_models import ChatOpenAI
+ from langchain.schema.runnable import RunnablePassthrough
+ from langchain.prompts import PromptTemplate
+ from langchain import hub
+ from env import OPENAI_API_KEY
+
+ def main():
+     loader = CSVLoader(file_path="output.csv")
+     data = loader.load()
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
+     splits = text_splitter.split_documents(data)
+     vectorstore = Chroma.from_documents(
+         documents=splits,
+         embedding=OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY),
+     )
+     retriever = vectorstore.as_retriever()
+     rag_prompt = hub.pull("rlm/rag-prompt")  # pulled but unused; the custom template below is used instead
+     llm = ChatOpenAI(
+         model_name="gpt-3.5-turbo",
+         temperature=0,
+         openai_api_key=OPENAI_API_KEY,
+     )
+     template = """Use the following pieces of context to answer the question at the end.
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ Use three sentences maximum and keep the answer as concise as possible.
+ {context}
+ Question: {question}
+ Helpful Answer:"""
+     rag_prompt_custom = PromptTemplate.from_template(template)
+
+     rag_chain = (
+         {"context": retriever, "question": RunnablePassthrough()}
+         | rag_prompt_custom
+         | llm
+     )
+     return rag_chain
+
+
+ def driver2(customer_name):
+     rag_chain = main()
+     response = rag_chain.invoke(
+         "Can you tell me more about customer "
+         + customer_name
+         + " and how they benefited from salesforce?"
+     )
+     print(response)
+     return response.content
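As committed, driver2 calls main() on every request, so output.csv is re-split and re-embedded into a fresh Chroma index for each question. A sketch of one way to build the chain once and reuse it across calls (get_rag_chain and driver2_cached are hypothetical names, not part of the commit):

    from functools import lru_cache

    from ques_2 import main


    @lru_cache(maxsize=1)
    def get_rag_chain():
        # Build the retriever, prompt, and model a single time; later calls reuse it.
        return main()


    def driver2_cached(customer_name):
        response = get_rag_chain().invoke(
            "Can you tell me more about customer "
            + customer_name
            + " and how they benefited from salesforce?"
        )
        return response.content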
requirements.txt ADDED
@@ -0,0 +1,85 @@
+ aiofiles==23.2.1
+ aiohttp==3.9.1
+ aiosignal==1.3.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==3.7.1
+ async-timeout==4.0.3
+ attrs==23.1.0
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.0
+ cycler==0.12.1
+ dataclasses-json==0.6.3
+ distro==1.8.0
+ exceptiongroup==1.2.0
+ fastapi==0.104.1
+ ffmpy==0.3.1
+ filelock==3.13.1
+ fonttools==4.45.1
+ frozenlist==1.4.0
+ fsspec==2023.10.0
+ gradio==4.7.1
+ gradio_client==0.7.0
+ h11==0.14.0
+ httpcore==1.0.2
+ httpx==0.25.2
+ huggingface-hub==0.19.4
+ idna==3.6
+ importlib-resources==6.1.1
+ Jinja2==3.1.2
+ jsonpatch==1.33
+ jsonpointer==2.4
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.11.2
+ kiwisolver==1.4.5
+ langchain==0.0.344
+ langchain-core==0.0.8
+ langsmith==0.0.68
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ marshmallow==3.20.1
+ matplotlib==3.8.2
+ mdurl==0.1.2
+ multidict==6.0.4
+ mypy-extensions==1.0.0
+ numpy==1.26.2
+ openai==1.3.7
+ orjson==3.9.10
+ packaging==23.2
+ pandas==2.1.3
+ Pillow==10.1.0
+ pydantic==2.5.2
+ pydantic_core==2.14.5
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ referencing==0.31.1
+ requests==2.31.0
+ rich==13.7.0
+ rpds-py==0.13.2
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.0
+ SQLAlchemy==2.0.23
+ starlette==0.27.0
+ tenacity==8.2.3
+ tomlkit==0.12.0
+ toolz==0.12.0
+ tqdm==4.66.1
+ typer==0.9.0
+ typing-inspect==0.9.0
+ typing_extensions==4.8.0
+ tzdata==2023.3
+ urllib3==2.1.0
+ uvicorn==0.24.0.post1
+ websockets==11.0.3
+ yarl==1.9.3
+ zipp==3.17.0