import numpy as np
import redis
import streamlit as st
from langchain import HuggingFaceHub
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer

from constants import (
    EMBEDDING_MODEL_NAME,
    FALCON_MAX_TOKENS,
    FALCON_REPO_ID,
    FALCON_TEMPERATURE,
    HUGGINGFACEHUB_API_TOKEN,
    ITEM_KEYWORD_EMBEDDING,
    OPENAI_API_KEY,
    OPENAI_MODEL_NAME,
    OPENAI_TEMPERATURE,
    TEMPLATE_1,
    TEMPLATE_2,
    TOPK,
)
from database import create_redis


# Connect to the Redis database, reusing one connection pool across reruns.
@st.cache_resource()
def connect_to_redis():
    pool = create_redis()
    return redis.Redis(connection_pool=pool)


# Keyword-extraction chain: Falcon turns the user's free-form request into
# search keywords that can be embedded and matched against the product index.
@st.cache_resource()
def encode_keywords_chain():
    llm = HuggingFaceHub(
        repo_id=FALCON_REPO_ID,
        model_kwargs={"temperature": FALCON_TEMPERATURE, "max_new_tokens": FALCON_MAX_TOKENS},
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )
    prompt = PromptTemplate(
        input_variables=["product_description"],
        template=TEMPLATE_1,
    )
    return LLMChain(llm=llm, prompt=prompt)


# Presentation chain: ChatOpenAI formats the retrieved products as a reply.
# Not cached with @st.cache_resource because its ConversationBufferMemory must
# stay per-session; main() stores it in st.session_state instead.
def present_products_chain():
    memory = ConversationBufferMemory(memory_key="chat_history")
    prompt = PromptTemplate(input_variables=["chat_history", "user_msg"], template=TEMPLATE_2)
    return LLMChain(
        llm=ChatOpenAI(openai_api_key=OPENAI_API_KEY, temperature=OPENAI_TEMPERATURE, model=OPENAI_MODEL_NAME),
        prompt=prompt,
        verbose=False,
        memory=memory,
    )


# Load the sentence-transformers embedding model once and cache it.
@st.cache_resource()
def instance_embedding_model():
    return SentenceTransformer(EMBEDDING_MODEL_NAME)


def main():
    st.title("My Amazon shopping buddy 🏷️")
    st.caption("🤖 Powered by Falcon Open Source AI model")

    redis_conn = connect_to_redis()
    keywords_chain = encode_keywords_chain()
    if "window_refreshed" not in st.session_state:
        st.session_state.window_refreshed = True
        st.session_state.present_products_chain = present_products_chain()
    embedding_model = instance_embedding_model()

    # Seed the chat history with a greeting and replay it on every rerun.
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hey, I'm your online shopping buddy. How can I help you today?"}
        ]
    for msg in st.session_state["messages"]:
        st.chat_message(msg["role"]).write(msg["content"])

    prompt = st.chat_input(key="user_input")
    if prompt:
        st.session_state["messages"].append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        st.session_state.disabled = True

        # Extract search keywords from the user's message.
        keywords = keywords_chain.run(prompt)

        # Vectorize the keywords; RediSearch expects the vector as raw float32 bytes.
        query_vector = embedding_model.encode(keywords)
        query_vector_bytes = np.array(query_vector).astype(np.float32).tobytes()

        # Prepare the KNN query against the product index.
        q = (
            Query(f"*=>[KNN {TOPK} @{ITEM_KEYWORD_EMBEDDING} $vec_param AS vector_score]")
            .sort_by("vector_score")
            .paging(0, TOPK)
            .return_fields("vector_score", "item_name", "item_keywords")
            .dialect(2)
        )
        params_dict = {"vec_param": query_vector_bytes}

        # Execute the query and flatten the matches into plain text for the LLM.
        results = redis_conn.ft().search(q, query_params=params_dict)
        result_output = ""
        for product in results.docs:
            result_output += f"product_name:{product.item_name}, product_description:{product.item_keywords}\n"

        # Let the presentation chain turn the matches into a conversational answer.
        result = st.session_state.present_products_chain.predict(user_msg=f"{result_output}\n{prompt}")
        st.session_state["messages"].append({"role": "assistant", "content": result})
        st.chat_message("assistant").write(result)


if __name__ == "__main__":
    main()
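

# --- Assumed index setup (illustrative sketch) -------------------------------
# The KNN query in main() presupposes a RediSearch index whose schema includes
# a FLOAT32 vector field named ITEM_KEYWORD_EMBEDDING plus the item_name and
# item_keywords fields it returns. That index is created outside this file
# (create_redis only builds a connection pool), so the helper below is a
# hypothetical sketch of what such a setup might look like, not the project's
# actual code. The "product:" key prefix and the default vector_dim of 384
# (correct for e.g. sentence-transformers/all-MiniLM-L6-v2) are assumptions;
# vector_dim must match the output size of EMBEDDING_MODEL_NAME.
def create_product_index(redis_conn, vector_dim=384):
    from redis.commands.search.field import TextField, VectorField
    from redis.commands.search.indexDefinition import IndexDefinition, IndexType

    redis_conn.ft().create_index(
        fields=[
            TextField("item_name"),
            TextField("item_keywords"),
            VectorField(
                ITEM_KEYWORD_EMBEDDING,
                "FLAT",  # brute-force search; "HNSW" is the approximate alternative
                {"TYPE": "FLOAT32", "DIM": vector_dim, "DISTANCE_METRIC": "COSINE"},
            ),
        ],
        definition=IndexDefinition(prefix=["product:"], index_type=IndexType.HASH),
    )
    # Run once at setup time; redis-py raises ResponseError if the index
    # already exists.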