RomyMy committed on
Commit
c423312
•
1 Parent(s): fa8eee4

fix imports

Browse files
Files changed (8) hide show
  1. .env_example +3 -1
  2. .pre-commit-config.yaml +24 -0
  3. app.py +115 -118
  4. constants.py +14 -0
  5. database.py +10 -7
  6. preprocess.py +30 -34
  7. utilities.py +0 -32
  8. utils.py +38 -0
.env_example CHANGED
@@ -1,3 +1,5 @@
1
  REDIS_KEY = ''
2
  OPENAI_API_KEY = ''
3
- HUGGINGFACEHUB_API_TOKEN = ''
 
 
 
1
  REDIS_KEY = ''
2
  OPENAI_API_KEY = ''
3
+ HUGGINGFACEHUB_API_TOKEN = ''
4
+ REDIS_HOST = ''
5
+ REDIS_PORT = ''
.pre-commit-config.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v3.2.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: end-of-file-fixer
7
+ - id: check-yaml
8
+ - id: check-added-large-files
9
+ - repo: https://github.com/psf/black
10
+ rev: 22.10.0
11
+ hooks:
12
+ - id: black
13
+ args: ["--line-length=118"]
14
+ - repo: https://github.com/pycqa/isort
15
+ rev: 5.12.0
16
+ hooks:
17
+ - id: isort
18
+ name: isort (python)
19
+ args: ["--profile", "black", "--filter-files"]
20
+ - repo: https://github.com/pycqa/flake8
21
+ rev: 6.0.0
22
+ hooks:
23
+ - id: flake8
24
+ args: ["--max-line-length=118", "--ignore=E501,E266,E203,W503"]
app.py CHANGED
@@ -1,124 +1,121 @@
1
- import streamlit as st
2
- from sentence_transformers import SentenceTransformer
3
- from redis.commands.search.query import Query
4
  import redis
5
- from langchain.prompts import PromptTemplate
 
6
  from langchain import HuggingFaceHub
7
  from langchain.chains import LLMChain
8
- from langchain.memory import ConversationBufferMemory
9
  from langchain.chat_models import ChatOpenAI
10
- from langchain.callbacks.base import BaseCallbackHandler
11
- import os
12
- from dotenv import load_dotenv
13
- import numpy as np
14
-
15
- load_dotenv()
16
- redis_key = os.getenv('REDIS_KEY')
17
- HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
18
- repo_id = 'tiiuae/falcon-7b-instruct'
19
-
20
- class StreamHandler(BaseCallbackHandler):
21
- def __init__(self, container, initial_text="", display_method='markdown'):
22
- self.container = container
23
- self.text = initial_text
24
- self.display_method = display_method
25
-
26
- def on_llm_new_token(self, token: str, **kwargs) -> None:
27
- self.text += token + " "
28
- display_function = getattr(self.container, self.display_method, None)
29
- if display_function is not None:
30
- display_function(self.text)
31
- else:
32
- raise ValueError(f"Invalid display_method: {self.display_method}")
33
-
34
-
35
- st.title('My Amazon shopping buddy 🏷️')
36
- st.caption('🤖 Powered by Falcon Open Source AI model')
37
-
38
- #connect to redis database
39
- @st.cache_resource()
40
- def redis_connect():
41
- redis_conn = redis.Redis(
42
- host='redis-12882.c259.us-central1-2.gce.cloud.redislabs.com',
43
- port=12882,
44
- password=redis_key)
45
- return redis_conn
46
-
47
- redis_conn = redis_connect()
48
-
49
- #the encoding keywords chain
50
- @st.cache_resource()
51
- def encode_keywords_chain():
52
- falcon_llm_1 = HuggingFaceHub(repo_id = repo_id, model_kwargs={'temperature':0.1,'max_new_tokens':500},huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN)
53
- prompt = PromptTemplate(
54
- input_variables=["product_description"],
55
- template="Create comma seperated product keywords to perform a query on a amazon dataset for this user input: {product_description}",
56
- )
57
- chain = LLMChain(llm=falcon_llm_1, prompt=prompt)
58
- return chain
59
- chain = encode_keywords_chain()
60
- #the present products chain
61
-
62
- @st.cache_resource()
63
- def present_products_chain():
64
- template = """You are a salesman. Be kind, detailed and nice. take the given context and Present the given queried search result in a nice way as answer to the user_msg. dont ask questions back or freestyle and invent followup conversation!
65
- {chat_history}
66
- user:{user_msg}
67
- Chatbot:"""
68
- prompt = PromptTemplate(
69
- input_variables=["chat_history", "user_msg"],
70
- template=template
71
- )
72
- memory = ConversationBufferMemory(memory_key="chat_history")
73
- llm_chain = LLMChain(
74
- llm = ChatOpenAI(openai_api_key=os.getenv('OPENAI_API_KEY'),temperature=0.8,model='gpt-3.5-turbo'),
75
- prompt=prompt,
76
- verbose=False,
77
- memory=memory,
78
- )
79
- return llm_chain
80
-
81
-
82
-
83
- llm_chain = present_products_chain()
84
-
85
- @st.cache_resource()
86
- def embedding_model():
87
- embedding_model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
88
- return embedding_model
89
-
90
- embedding_model = embedding_model()
91
-
92
- if "messages" not in st.session_state:
93
- st.session_state["messages"] = [{"role": "assistant", "content": "Hey im your online shopping buddy, how can i help you today?"}]
94
- for msg in st.session_state["messages"]:
95
- st.chat_message(msg["role"]).write(msg["content"])
96
-
97
- prompt = st.chat_input(key="user_input" )
98
-
99
- if prompt:
100
- st.session_state["messages"].append({"role": "user", "content": prompt})
101
- st.chat_message('user').write(prompt)
102
- st.session_state.disabled = True
103
- keywords = chain.run(prompt)
104
- #vectorize the query
105
- query_vector = embedding_model.encode(keywords)
106
- query_vector = np.array(query_vector).astype(np.float32).tobytes()
107
- #prepare the query
108
- ITEM_KEYWORD_EMBEDDING_FIELD = 'item_vector'
109
- topK=5
110
- q = Query(f'*=>[KNN {topK} @{ITEM_KEYWORD_EMBEDDING_FIELD} $vec_param AS vector_score]').sort_by('vector_score').paging(0,topK).return_fields('vector_score','item_name','item_id','item_keywords').dialect(2)
111
- params_dict = {"vec_param": query_vector}
112
- #Execute the query
113
- results = redis_conn.ft().search(q, query_params = params_dict)
114
-
115
- full_result_string = ''
116
- for product in results.docs:
117
- full_result_string += product.item_name + ' ' + product.item_keywords + "\n\n\n"
118
-
119
- result = llm_chain.predict(user_msg=f"{full_result_string} ---\n\n {prompt}")
120
- st.session_state.messages.append({"role": "assistant", "content": result})
121
- st.chat_message('assistant').write(result)
122
-
123
 
 
 
 
 
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
  import redis
5
+ import streamlit as st
6
+ from dotenv import load_dotenv
7
  from langchain import HuggingFaceHub
8
  from langchain.chains import LLMChain
 
9
  from langchain.chat_models import ChatOpenAI
10
+ from langchain.memory import ConversationBufferMemory
11
+ from langchain.prompts import PromptTemplate
12
+ from redis.commands.search.query import Query
13
+ from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ from constants import (
16
+ EMBEDDING_MODEL_NAME,
17
+ FALCON_MAX_TOKENS,
18
+ FALCON_REPO_ID,
19
+ FALCON_TEMPERATURE,
20
+ OPENAI_MODEL_NAME,
21
+ OPENAI_TEMPERATURE,
22
+ TEMPLATE_1,
23
+ TEMPLATE_2,
24
+ )
25
+ from database import create_redis
26
 
27
+ load_dotenv()
28
+ HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
29
+ ITEM_KEYWORD_EMBEDDING = "item_vector"
30
+ TOPK = 5
31
+
32
+
33
+ def main():
34
+ # connect to redis database
35
+ @st.cache_resource()
36
+ def connect_to_redis():
37
+ pool = create_redis()
38
+ return redis.Redis(connection_pool=pool)
39
+
40
+ # the encoding keywords chain
41
+ @st.cache_resource()
42
+ def encode_keywords_chain():
43
+ falcon_llm_1 = HuggingFaceHub(
44
+ repo_id=FALCON_REPO_ID,
45
+ model_kwargs={"temperature": FALCON_TEMPERATURE, "max_new_tokens": FALCON_MAX_TOKENS},
46
+ huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
47
+ )
48
+ prompt = PromptTemplate(
49
+ input_variables=["product_description"],
50
+ template=TEMPLATE_1,
51
+ )
52
+ chain = LLMChain(llm=falcon_llm_1, prompt=prompt)
53
+ return chain
54
+
55
+ # the present products chain
56
+ @st.cache_resource()
57
+ def present_products_chain():
58
+ template = TEMPLATE_2
59
+ prompt = PromptTemplate(input_variables=["chat_history", "user_msg"], template=template)
60
+ memory = ConversationBufferMemory(memory_key="chat_history")
61
+ llm_chain = LLMChain(
62
+ llm=ChatOpenAI(
63
+ openai_api_key=os.getenv("OPENAI_API_KEY"), temperature=OPENAI_TEMPERATURE, model=OPENAI_MODEL_NAME
64
+ ),
65
+ prompt=prompt,
66
+ verbose=False,
67
+ memory=memory,
68
+ )
69
+ return llm_chain
70
+
71
+ @st.cache_resource()
72
+ def instance_embedding_model():
73
+ embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
74
+ return embedding_model
75
+
76
+ st.title("My Amazon shopping buddy 🏷️")
77
+ st.caption("🤖 Powered by Falcon Open Source AI model")
78
+ redis_conn = connect_to_redis()
79
+ keywords_chain = encode_keywords_chain()
80
+ chat_chain = present_products_chain()
81
+ embedding_model = instance_embedding_model()
82
+
83
+ if "messages" not in st.session_state:
84
+ st.session_state["messages"] = [
85
+ {"role": "assistant", "content": "Hey im your online shopping buddy, how can i help you today?"}
86
+ ]
87
+ for msg in st.session_state["messages"]:
88
+ st.chat_message(msg["role"]).write(msg["content"])
89
+
90
+ prompt = st.chat_input(key="user_input")
91
+
92
+ if prompt:
93
+ st.session_state["messages"].append({"role": "user", "content": prompt})
94
+ st.chat_message("user").write(prompt)
95
+ st.session_state.disabled = True
96
+ keywords = keywords_chain.run(prompt)
97
+ # vectorize the query
98
+ query_vector = embedding_model.encode(keywords)
99
+ query_vector_bytes = np.array(query_vector).astype(np.float32).tobytes()
100
+ # prepare the query
101
+
102
+ q = (
103
+ Query(f"*=>[KNN {TOPK} @{ITEM_KEYWORD_EMBEDDING} $vec_param AS vector_score]")
104
+ .sort_by("vector_score")
105
+ .paging(0, TOPK)
106
+ .return_fields("vector_score", "item_name", "item_id", "item_keywords")
107
+ .dialect(2)
108
+ )
109
+ params_dict = {"vec_param": query_vector_bytes}
110
+ # Execute the query
111
+ results = redis_conn.ft().search(q, query_params=params_dict)
112
+ result_output = ""
113
+ for product in results.docs:
114
+ result_output += f"product_name:{product.item_name}, product_description:{product.item_keywords} \n"
115
+ result = chat_chain.predict(user_msg=f"{result_output}\n{prompt}")
116
+ st.session_state.messages.append({"role": "assistant", "content": result})
117
+ st.chat_message("assistant").write(result)
118
+
119
+
120
+ if __name__ == "__main__":
121
+ main()
constants.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FALCON_REPO_ID = "tiiuae/falcon-7b-instruct"
2
+ FALCON_TEMPERATURE = 0.1
3
+ FALCON_MAX_TOKENS = 500
4
+
5
+ OPENAI_MODEL_NAME = "gpt-3.5-turbo"
6
+ OPENAI_TEMPERATURE = 0.8
7
+
8
+ EMBEDDING_MODEL_NAME = "sentence-transformers/all-distilroberta-v1"
9
+
10
+ TEMPLATE_1 = "Create comma seperated product keywords to perform a query on a amazon dataset for this user input: {product_description}"
11
+ TEMPLATE_2 = """You are a salesman.Present the given product results in a nice way as answer to the user_msg. Dont ask questions back,
12
+ {chat_history}
13
+ user:{user_msg}
14
+ Chatbot:"""
database.py CHANGED
@@ -1,13 +1,16 @@
1
- import redis
2
  import os
 
 
3
  from dotenv import load_dotenv
4
 
5
  load_dotenv()
6
- redis_key = os.getenv('REDIS_KEY')
7
-
8
 
9
 
10
- redis_conn = redis.Redis(
11
- host='redis-12882.c259.us-central1-2.gce.cloud.redislabs.com',
12
- port=12882,
13
- password=redis_key)
 
 
 
 
 
 
1
  import os
2
+
3
+ import redis
4
  from dotenv import load_dotenv
5
 
6
  load_dotenv()
 
 
7
 
8
 
9
+ def create_redis():
10
+ return redis.ConnectionPool(
11
+ host=os.getenv("REDIS_HOST"),
12
+ port=os.getenv("REDIS_PORT"),
13
+ password=os.getenv("REDIS_KEY"),
14
+ db=0,
15
+ decode_responses=True,
16
+ )
preprocess.py CHANGED
@@ -1,48 +1,44 @@
1
- from langchain.embeddings import OpenAIEmbeddings
2
- from sentence_transformers import SentenceTransformer
3
- import os
4
- import pandas as pd
5
  import numpy as np
6
- from dotenv import load_dotenv
7
- from database import redis_conn
8
- from utilities import create_flat_index, load_vectors
9
-
10
 
 
 
11
 
12
- #set maximum length for text fields
 
 
13
  MAX_TEXT_LENGTH = 512
 
 
 
14
 
15
- def auto_truncate(text:str):
16
  return text[0:MAX_TEXT_LENGTH]
17
 
18
- data = pd.read_csv('product_data.csv',converters={'bullet_point':auto_truncate,'item_keywords':auto_truncate,'item_name':auto_truncate})
19
- data['primary_key'] = data['item_id'] + '-' + data['domain_name']
20
- data.drop(columns=['item_id','domain_name'],inplace=True)
21
- data['item_keywords'].replace('',np.nan,inplace=True)
22
- data.dropna(subset=['item_keywords'],inplace=True)
 
 
 
 
23
  data.reset_index(drop=True, inplace=True)
24
- data_metadata = data.head(500).to_dict(orient='index')
25
 
26
- #generating embeddings (vectors) for the item keywords
27
- embedding_model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
28
  # embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
29
 
30
- #get the item keywords attribute for each product and encode them into vector embeddings
31
- item_keywords = [data_metadata[i]['item_keywords'] for i in data_metadata.keys()]
32
  item_keywords_vectors = [embedding_model.encode(item) for item in item_keywords]
33
 
34
- TEXT_EMBEDDING_DIMENSION=768
35
- NUMBER_PRODUCTS=500
36
-
37
- print ('Loading and Indexing + ' + str(NUMBER_PRODUCTS) + ' products')
38
- #flush all data
39
  redis_conn.flushall()
40
- #create flat index & load vectors
41
- create_flat_index(redis_conn,NUMBER_PRODUCTS,TEXT_EMBEDDING_DIMENSION,'COSINE')
42
- load_vectors(redis_conn,data_metadata,item_keywords_vectors)
43
-
44
-
45
-
46
-
47
-
48
-
 
 
 
 
 
1
  import numpy as np
2
+ import pandas as pd
3
+ import redis
4
+ from sentence_transformers import SentenceTransformer
 
5
 
6
+ from database import create_redis
7
+ from utils import create_flat_index, load_vectors
8
 
9
+ pool = create_redis()
10
+ redis_conn = redis.Redis(connection_pool=pool)
11
+ # set maximum length for text fields
12
  MAX_TEXT_LENGTH = 512
13
+ TEXT_EMBEDDING_DIMENSION = 768
14
+ NUMBER_PRODUCTS = 10000
15
+
16
 
17
+ def auto_truncate(text: str):
18
  return text[0:MAX_TEXT_LENGTH]
19
 
20
+
21
+ data = pd.read_csv(
22
+ "product_data.csv",
23
+ converters={"bullet_point": auto_truncate, "item_keywords": auto_truncate, "item_name": auto_truncate},
24
+ )
25
+ data["primary_key"] = data["item_id"] + "-" + data["domain_name"]
26
+ data.drop(columns=["item_id", "domain_name"], inplace=True)
27
+ data["item_keywords"].replace("", np.nan, inplace=True)
28
+ data.dropna(subset=["item_keywords"], inplace=True)
29
  data.reset_index(drop=True, inplace=True)
30
+ data_metadata = data.head(10000).to_dict(orient="index")
31
 
32
+ # generating embeddings (vectors) for the item keywords
33
+ embedding_model = SentenceTransformer("sentence-transformers/all-distilroberta-v1")
34
  # embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
35
 
36
+ # get the item keywords attribute for each product and encode them into vector embeddings
37
+ item_keywords = [data_metadata[i]["item_keywords"] for i in data_metadata.keys()]
38
  item_keywords_vectors = [embedding_model.encode(item) for item in item_keywords]
39
 
40
+ # flush all data
 
 
 
 
41
  redis_conn.flushall()
42
+ # create flat index & load vectors
43
+ create_flat_index(redis_conn, NUMBER_PRODUCTS, TEXT_EMBEDDING_DIMENSION, "COSINE")
44
+ load_vectors(redis_conn, data_metadata, item_keywords_vectors)
 
 
 
 
 
 
utilities.py DELETED
@@ -1,32 +0,0 @@
1
- from redis import Redis
2
- from redis.commands.search.field import VectorField
3
- from redis.commands.search.field import TextField
4
- from redis.commands.search.field import TagField
5
- from redis.commands.search.result import Result
6
- import numpy as np
7
-
8
- def load_vectors(client:Redis, product_metadata, vector_dict):
9
- p = client.pipeline(transaction=False)
10
- for index in product_metadata.keys():
11
- #hash key
12
- key='product:'+ str(index)+ ':' + product_metadata[index]['primary_key']
13
-
14
- #hash values
15
- item_metadata = product_metadata[index]
16
- item_keywords_vector = np.array(vector_dict[index], dtype=np.float32).tobytes()
17
- item_metadata['item_vector']=item_keywords_vector
18
-
19
- # HSET
20
- p.hset(key,mapping=item_metadata)
21
-
22
- p.execute()
23
-
24
- def create_flat_index (redis_conn, number_of_vectors, vector_dimensions=512, distance_metric='L2'):
25
- redis_conn.ft().create_index([
26
- VectorField('item_vector', "FLAT", {"TYPE": "FLOAT32", "DIM": vector_dimensions, "DISTANCE_METRIC": distance_metric, "INITIAL_CAP": number_of_vectors, "BLOCK_SIZE":number_of_vectors }),
27
- TagField("product_type"),
28
- TextField("item_name"),
29
- TextField("item_keywords"),
30
- TagField("country")
31
- ])
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from redis import Redis
3
+ from redis.commands.search.field import TagField, TextField, VectorField
4
+
5
+
6
+ def load_vectors(client: Redis, product_metadata, vector_dict):
7
+ p = client.pipeline(transaction=False)
8
+ for index in product_metadata.keys():
9
+ # hash key
10
+ key = "product:" + str(index) + ":" + product_metadata[index]["primary_key"]
11
+ # hash values
12
+ item_metadata = product_metadata[index]
13
+ item_keywords_vector = np.array(vector_dict[index], dtype=np.float32).tobytes()
14
+ item_metadata["item_vector"] = item_keywords_vector
15
+ p.hset(key, mapping=item_metadata)
16
+ p.execute()
17
+
18
+
19
+ def create_flat_index(redis_conn, number_of_vectors, vector_dimensions=512, distance_metric="L2"):
20
+ redis_conn.ft().create_index(
21
+ [
22
+ VectorField(
23
+ "item_vector",
24
+ "FLAT",
25
+ {
26
+ "TYPE": "FLOAT32",
27
+ "DIM": vector_dimensions,
28
+ "DISTANCE_METRIC": distance_metric,
29
+ "INITIAL_CAP": number_of_vectors,
30
+ "BLOCK_SIZE": number_of_vectors,
31
+ },
32
+ ),
33
+ TagField("product_type"),
34
+ TextField("item_name"),
35
+ TextField("item_keywords"),
36
+ TagField("country"),
37
+ ]
38
+ )