RomyMy's picture
update readme
3a91716
import numpy as np
import redis
import streamlit as st
from langchain import HuggingFaceHub
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer
from constants import (
EMBEDDING_MODEL_NAME,
FALCON_MAX_TOKENS,
FALCON_REPO_ID,
FALCON_TEMPERATURE,
HUGGINGFACEHUB_API_TOKEN,
ITEM_KEYWORD_EMBEDDING,
OPENAI_API_KEY,
OPENAI_MODEL_NAME,
OPENAI_TEMPERATURE,
TEMPLATE_1,
TEMPLATE_2,
TOPK,
)
from database import create_redis
@st.cache_resource()
def connect_to_redis():
    """Return a Redis client bound to the pool produced by ``create_redis``.

    The function is decorated with ``st.cache_resource`` so the client is
    treated as a shared Streamlit resource rather than rebuilt on each call.
    """
    return redis.Redis(connection_pool=create_redis())
@st.cache_resource()
def encode_keywords_chain():
    """Build the chain that feeds ``TEMPLATE_1`` to the Falcon model.

    The prompt exposes a single input variable, ``product_description``.
    The LLM is a ``HuggingFaceHub`` endpoint configured from the ``FALCON_*``
    constants and authenticated with ``HUGGINGFACEHUB_API_TOKEN``.

    Returns:
        An ``LLMChain`` combining the Falcon LLM with the keyword prompt.
    """
    falcon_kwargs = {
        "temperature": FALCON_TEMPERATURE,
        "max_new_tokens": FALCON_MAX_TOKENS,
    }
    falcon_llm = HuggingFaceHub(
        repo_id=FALCON_REPO_ID,
        model_kwargs=falcon_kwargs,
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )
    keyword_prompt = PromptTemplate(
        input_variables=["product_description"],
        template=TEMPLATE_1,
    )
    return LLMChain(llm=falcon_llm, prompt=keyword_prompt)
def present_products_chain():
    """Build the conversational chain that presents products to the user.

    The chain uses ``TEMPLATE_2`` with the ``chat_history`` and ``user_msg``
    variables, an OpenAI chat model configured from the ``OPENAI_*``
    constants, and a ``ConversationBufferMemory`` stored under the
    ``chat_history`` key.

    Returns:
        A non-verbose ``LLMChain`` with its own fresh conversation memory.
    """
    history = ConversationBufferMemory(memory_key="chat_history")
    presentation_prompt = PromptTemplate(
        input_variables=["chat_history", "user_msg"],
        template=TEMPLATE_2,
    )
    chat_llm = ChatOpenAI(
        openai_api_key=OPENAI_API_KEY,
        temperature=OPENAI_TEMPERATURE,
        model=OPENAI_MODEL_NAME,
    )
    return LLMChain(
        llm=chat_llm,
        prompt=presentation_prompt,
        verbose=False,
        memory=history,
    )
@st.cache_resource()
def instance_embedding_model():
    """Load the ``SentenceTransformer`` named by ``EMBEDDING_MODEL_NAME``.

    Decorated with ``st.cache_resource`` so the model is handled as a shared
    Streamlit resource.
    """
    return SentenceTransformer(EMBEDDING_MODEL_NAME)
def main():
    """Render the shopping-assistant chat page and answer one user message.

    On every run the function:

    1. Draws the title and caption and obtains the Redis client, the keyword
       chain and the embedding model.
    2. Creates the product-presentation chain once per session (guarded by
       the ``window_refreshed`` flag in ``st.session_state``) so that its
       conversation memory persists between reruns of the same session.
    3. Replays the stored message history.
    4. If the user submitted a message: extracts keywords with the keyword
       chain, embeds them, runs a KNN query for ``TOPK`` items against the
       ``ITEM_KEYWORD_EMBEDDING`` field, and passes the matched products
       together with the user's message to the presentation chain. The reply
       is appended to the history and displayed.
    """
    st.title("My Amazon shopping buddy 🏷️")
    # The caption previously contained a mis-decoded emoji; use the real one.
    st.caption("🤖 Powered by Falcon Open Source AI model")

    redis_conn = connect_to_redis()
    keywords_chain = encode_keywords_chain()

    # The presentation chain carries conversation memory, so it is kept in
    # the session state instead of being rebuilt on every rerun.
    if "window_refreshed" not in st.session_state:
        st.session_state.window_refreshed = True
        st.session_state.present_products_chain = present_products_chain()

    embedding_model = instance_embedding_model()

    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hey im your online shopping buddy, how can i help you today?"}
        ]

    # Replay the conversation so far.
    for msg in st.session_state["messages"]:
        st.chat_message(msg["role"]).write(msg["content"])

    prompt = st.chat_input(key="user_input")
    if not prompt:
        # Nothing submitted on this run.
        return

    st.session_state["messages"].append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)
    st.session_state.disabled = True

    keywords = keywords_chain.run(prompt)

    # Vectorize the keywords; the query parameter is sent as float32 bytes.
    query_vector = embedding_model.encode(keywords)
    query_vector_bytes = np.array(query_vector).astype(np.float32).tobytes()

    # Prepare the KNN query (dialect 2 is used for the parameterized vector).
    q = (
        Query(f"*=>[KNN {TOPK} @{ITEM_KEYWORD_EMBEDDING} $vec_param AS vector_score]")
        .sort_by("vector_score")
        .paging(0, TOPK)
        .return_fields("vector_score", "item_name", "item_keywords")
        .dialect(2)
    )
    params_dict = {"vec_param": query_vector_bytes}

    # Execute the query.
    results = redis_conn.ft().search(q, query_params=params_dict)

    # One line per matched product; an empty result set yields "".
    result_output = "".join(
        f"product_name:{product.item_name}, product_description:{product.item_keywords} \n"
        for product in results.docs
    )

    result = st.session_state.present_products_chain.predict(user_msg=f"{result_output}\n{prompt}")
    st.session_state.messages.append({"role": "assistant", "content": result})
    st.chat_message("assistant").write(result)
# Script entry point: start the chat app only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()