rairo committed on
Commit
dc2e0ab
1 Parent(s): 33f1573

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. .env +1 -0
  2. README.md +2 -8
  3. demo.py +166 -0
  4. requirements.txt +114 -0
.env ADDED
@@ -0,0 +1 @@
 
 
1
+ PALM="AIzaSyBVojf3nBKO_UITOwZtDVyAejW_2Qne1KY"
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Sonitycom
3
- emoji: 👁
4
- colorFrom: yellow
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.7.1
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: sonitycom
3
+ app_file: demo.py
 
 
4
  sdk: gradio
5
  sdk_version: 4.7.1
 
 
6
  ---
 
 
demo.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import time
4
+ from sentence_transformers import SentenceTransformer
5
+ from redis.commands.search.field import VectorField
6
+ from redis.commands.search.field import TextField
7
+ from redis.commands.search.field import TagField
8
+ from redis.commands.search.query import Query
9
+ import redis
10
+ from tqdm import tqdm
11
+ import google.generativeai as palm
12
+ import pandas as pd
13
+ from langchain.chains import LLMChain
14
+
15
+
16
+ from langchain.prompts import PromptTemplate
17
+
18
+ import os
19
+
20
+ import gradio as gr
21
+ import io
22
+
23
+ from langchain.llms import GooglePalm
24
+ import pandas as pd
25
+ #from yolopandas import pd
26
+
27
+ from langchain.embeddings import GooglePalmEmbeddings
28
+ from langchain.memory import ConversationBufferMemory
29
+ from dotenv import load_dotenv
30
+
31
+ load_dotenv()
32
+
33
+ redis_conn = redis.Redis(
34
+ host='redis-15860.c322.us-east-1-2.ec2.cloud.redislabs.com',
35
+ port=15860,
36
+ password='PVnvSZI5nISPsrxxhCHZF3pfZWI7YAIG')
37
+
38
+ '''
39
+ df = pd.read_csv("coms3.csv")
40
+
41
+
42
+ print(list(df))
43
+
44
+ print(df['item_keywords'].sample(2))
45
+
46
+ company_metadata = df.to_dict(orient='index')
47
+
48
+
49
+ model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
50
+
51
+
52
+ item_keywords = [company_metadata[i]['item_keywords'] for i in company_metadata.keys()]
53
+ item_keywords_vectors = []
54
+ for sentence in tqdm(item_keywords):
55
+ s = model.encode(sentence)
56
+ item_keywords_vectors.append(s)
57
+
58
+ print(company_metadata[0])
59
+
60
+ def load_vectors(client, company_metadata, vector_dict, vector_field_name):
61
+ p = client.pipeline(transaction=False)
62
+ for index in company_metadata.keys():
63
+ #hash key
64
+ #print(index)
65
+ #print(company_metadata[index]['company_l_id'])
66
+ try:
67
+ key=str('company:'+ str(index)+ ':' + company_metadata[index]['primary_key'])
68
+ except:
69
+ print(key)
70
+ continue
71
+
72
+
73
+ #hash values
74
+ item_metadata = company_metadata[index]
75
+ item_keywords_vector = vector_dict[index].astype(np.float32).tobytes()
76
+ item_metadata[vector_field_name]=item_keywords_vector
77
+
78
+ # HSET
79
+ p.hset(key,mapping=item_metadata)
80
+
81
+ p.execute()
82
+
83
+ def create_flat_index (redis_conn,vector_field_name,number_of_vectors, vector_dimensions=512, distance_metric='L2'):
84
+ redis_conn.ft().create_index([
85
+ VectorField(vector_field_name, "FLAT", {"TYPE": "FLOAT32", "DIM": vector_dimensions, "DISTANCE_METRIC": distance_metric, "INITIAL_CAP": number_of_vectors, "BLOCK_SIZE":number_of_vectors }),
86
+ TagField("company_l_id"),
87
+ TextField("company_name"),
88
+ TextField("item_keywords"),
89
+ TagField("industry")
90
+ ])
91
+
92
+ ITEM_KEYWORD_EMBEDDING_FIELD='item_keyword_vector'
93
+ TEXT_EMBEDDING_DIMENSION=768
94
+ NUMBER_COMPANIES=1000
95
+
96
+ print ('Loading and Indexing + ' + str(NUMBER_COMPANIES) + 'companies')
97
+
98
+ #flush all data
99
+ redis_conn.flushall()
100
+
101
+ #create flat index & load vectors
102
+ create_flat_index(redis_conn, ITEM_KEYWORD_EMBEDDING_FIELD,NUMBER_COMPANIES,TEXT_EMBEDDING_DIMENSION,'COSINE')
103
+ load_vectors(redis_conn,company_metadata,item_keywords_vectors,ITEM_KEYWORD_EMBEDDING_FIELD)
104
+ '''
105
+ model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
106
+ ITEM_KEYWORD_EMBEDDING_FIELD='item_keyword_vector'
107
+ TEXT_EMBEDDING_DIMENSION=768
108
+ NUMBER_PRODUCTS=1000
109
+
110
+ prompt = PromptTemplate(
111
+ input_variables=["company_description"],
112
+ template='Create comma seperated company keywords to perform a query on a company dataset for this user input'
113
+ )
114
+
115
+ template = """You are a chatbot. Be kind, detailed and nice. Present the given queried search result in a nice way as answer to the user input. dont ask questions back! just take the given context
116
+
117
+ {chat_history}
118
+ Human: {user_question}
119
+ Chatbot:
120
+ """
121
+
122
+ prompt = PromptTemplate(
123
+ input_variables=["chat_history", "user_question"],
124
+ template=template
125
+ )
126
+ chat_history= ""
127
+ def answer(user_question):
128
+ llm = GooglePalm(temperature=0, google_api_key=os.environ['PALM'])
129
+ chain = LLMChain(llm=llm, prompt=prompt)
130
+ keywords = chain.run({'user_question':user_question, 'chat_history':chat_history})
131
+
132
+ topK=3
133
+ #vectorize the query
134
+ query_vector = model.encode(keywords).astype(np.float32).tobytes()
135
+
136
+ q = Query(f'*=>[KNN {topK} @{ITEM_KEYWORD_EMBEDDING_FIELD} $vec_param AS vector_score]').sort_by('vector_score').paging(0,topK).return_fields('vector_score','item_name','item_id','item_keywords').dialect(2)
137
+ params_dict = {"vec_param": query_vector}
138
+
139
+ #Execute the query
140
+ results = redis_conn.ft().search(q, query_params = params_dict)
141
+
142
+ full_result_string = ''
143
+ for company in results.docs:
144
+ full_result_string += company.company_name + ' ' + company.item_keywords + ' ' + company.company_l_id + "\n\n\n"
145
+
146
+ memory = ConversationBufferMemory(memory_key="chat_history")
147
+ llm_chain = LLMChain(
148
+ llm=llm,
149
+ prompt=prompt,
150
+ verbose=False,
151
+ memory=memory,
152
+ )
153
+
154
+
155
+ ans = llm_chain.predict(user_msg= f"{full_result_string} ---\n\n {user_question}")
156
+
157
+ return ans
158
+
159
+ demo = gr.Interface(
160
+
161
+ fn=answer,
162
+ inputs=["text"],
163
+ outputs=["text"],
164
+ title="Ask Sonity",
165
+ )
166
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.8.6
3
+ aiosignal==1.3.1
4
+ annotated-types==0.6.0
5
+ anyio==3.7.1
6
+ appdirs==1.4.4
7
+ async-timeout==4.0.3
8
+ attrs==23.1.0
9
+ beautifulsoup4==4.12.2
10
+ Brotli==1.1.0
11
+ cachetools==5.3.2
12
+ certifi==2023.7.22
13
+ charset-normalizer==3.3.2
14
+ click==8.1.7
15
+ dataclasses-json==0.6.2
16
+ duckduckgo-search==3.9.4
17
+ exceptiongroup==1.1.3
18
+ filelock==3.13.1
19
+ frozendict==2.3.8
20
+ frozenlist==1.4.0
21
+ fsspec==2023.10.0
22
+ google-ai-generativelanguage==0.1.0
23
+ google-api-core==2.14.0
24
+ google-auth==2.23.4
25
+ google-generativeai==0.1.0rc1
26
+ googleapis-common-protos==1.61.0
27
+ greenlet==3.0.1
28
+ grpcio==1.59.3
29
+ grpcio-status==1.59.3
30
+ h11==0.14.0
31
+ h2==4.1.0
32
+ hpack==4.0.0
33
+ html5lib==1.1
34
+ httpcore==1.0.2
35
+ httpx==0.25.1
36
+ huggingface-hub==0.19.3
37
+ hyperframe==6.0.1
38
+ idna==3.4
39
+ install==1.3.5
40
+ Jinja2==3.1.2
41
+ joblib==1.3.2
42
+ jsonpatch==1.33
43
+ jsonpointer==2.4
44
+ langchain==0.0.333
45
+ langsmith==0.0.63
46
+ lxml==4.9.3
47
+ MarkupSafe==2.1.3
48
+ marshmallow==3.20.1
49
+ mpmath==1.3.0
50
+ multidict==6.0.4
51
+ multitasking==0.0.11
52
+ mypy-extensions==1.0.0
53
+ networkx==3.1
54
+ nltk==3.8.1
55
+ numpy==1.24.4
56
+ nvidia-cublas-cu12==12.1.3.1
57
+ nvidia-cuda-cupti-cu12==12.1.105
58
+ nvidia-cuda-nvrtc-cu12==12.1.105
59
+ nvidia-cuda-runtime-cu12==12.1.105
60
+ nvidia-cudnn-cu12==8.9.2.26
61
+ nvidia-cufft-cu12==11.0.2.54
62
+ nvidia-curand-cu12==10.3.2.106
63
+ nvidia-cusolver-cu12==11.4.5.107
64
+ nvidia-cusparse-cu12==12.1.0.106
65
+ nvidia-nccl-cu12==2.18.1
66
+ nvidia-nvjitlink-cu12==12.3.52
67
+ nvidia-nvtx-cu12==12.1.105
68
+ packaging==23.2
69
+ pandas==2.0.3
70
+ peewee==3.17.0
71
+ Pillow==10.1.0
72
+ proto-plus==1.22.3
73
+ protobuf==4.25.1
74
+ psycopg2-binary==2.9.9
75
+ pyasn1==0.5.1
76
+ pyasn1-modules==0.3.0
77
+ pydantic==2.4.2
78
+ pydantic-core==2.10.1
79
+ python-dateutil==2.8.2
80
+ pytz==2023.3.post1
81
+ PyYAML==6.0.1
82
+ redis==5.0.1
83
+ regex==2023.10.3
84
+ requests==2.31.0
85
+ rsa==4.9
86
+ safetensors==0.4.0
87
+ scikit-learn==1.3.2
88
+ scipy==1.10.1
89
+ sentence-transformers==2.2.2
90
+ sentencepiece==0.1.99
91
+ simplejson==3.19.2
92
+ six==1.16.0
93
+ sniffio==1.3.0
94
+ socksio==1.0.0
95
+ soupsieve==2.5
96
+ SQLAlchemy==2.0.23
97
+ sympy==1.12
98
+ tenacity==8.2.3
99
+ threadpoolctl==3.2.0
100
+ tokenizers==0.15.0
101
+ torch==2.1.1
102
+ torchvision==0.16.1
103
+ tqdm==4.66.1
104
+ transformers==4.35.2
105
+ triton==2.1.0
106
+ typing-extensions==4.8.0
107
+ typing-inspect==0.9.0
108
+ tzdata==2023.3
109
+ urllib3==2.0.7
110
+ webencodings==0.5.1
111
+ wikipedia==1.4.0
112
+ yahoo-finance==1.4.0
113
+ yarl==1.9.2
114
+ yfinance==0.2.31