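"""My Amazon shopping buddy: a Streamlit chat app for product search.

Flow: a Falcon model (via HuggingFaceHub) distills the user's message into
search keywords, a SentenceTransformer embeds them, Redis runs a KNN vector
search over pre-indexed product embeddings, and a ChatOpenAI-backed chain
with conversation memory presents the matching products.

Configuration lives in constants.py; the Redis connection pool comes from
database.py. Run with Streamlit (the file name here is an assumption):

    streamlit run app.py
"""
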
import numpy as np
import redis
import streamlit as st
from langchain.llms import HuggingFaceHub
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer

from constants import (
    EMBEDDING_MODEL_NAME,
    FALCON_MAX_TOKENS,
    FALCON_REPO_ID,
    FALCON_TEMPERATURE,
    HUGGINGFACEHUB_API_TOKEN,
    ITEM_KEYWORD_EMBEDDING,
    OPENAI_API_KEY,
    OPENAI_MODEL_NAME,
    OPENAI_TEMPERATURE,
    TEMPLATE_1,
    TEMPLATE_2,
    TOPK,
)
from database import create_redis


# connect to Redis; st.cache_resource shares one connection pool across reruns
@st.cache_resource()
def connect_to_redis():
    pool = create_redis()
    return redis.Redis(connection_pool=pool)
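

# The KNN query in main() assumes an existing RediSearch index over the
# product hashes, with a FLOAT32 vector field named by ITEM_KEYWORD_EMBEDDING.
# A minimal sketch of creating such an index, in case database.py does not;
# the text fields and DIM=384 are assumptions tied to the embedding model:
#
#     from redis.commands.search.field import TextField, VectorField
#     redis_conn.ft().create_index([
#         TextField("item_name"),
#         TextField("item_keywords"),
#         VectorField(ITEM_KEYWORD_EMBEDDING, "FLAT",
#                     {"TYPE": "FLOAT32", "DIM": 384, "DISTANCE_METRIC": "COSINE"}),
#     ])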


# chain that distills a free-form product request into search keywords (Falcon via HuggingFaceHub)
@st.cache_resource()
def encode_keywords_chain():
    llm = HuggingFaceHub(
        repo_id=FALCON_REPO_ID,
        model_kwargs={"temperature": FALCON_TEMPERATURE, "max_new_tokens": FALCON_MAX_TOKENS},
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )
    prompt = PromptTemplate(
        input_variables=["product_description"],
        template=TEMPLATE_1,
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    return chain


# chain that presents the retrieved products to the user, keeping chat history in buffer memory
def present_products_chain():
    template = TEMPLATE_2
    memory = ConversationBufferMemory(memory_key="chat_history")
    prompt = PromptTemplate(input_variables=["chat_history", "user_msg"], template=template)
    chain = LLMChain(
        llm=ChatOpenAI(openai_api_key=OPENAI_API_KEY, temperature=OPENAI_TEMPERATURE, model=OPENAI_MODEL_NAME),
        prompt=prompt,
        verbose=False,
        memory=memory,
    )
    return chain


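# load the sentence-transformers embedding model once and cache it across reruns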
@st.cache_resource()
def instance_embedding_model():
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
    return embedding_model


def main():
    st.title("My Amazon shopping buddy 🏷️")
    st.caption("🤖 Powered by Falcon Open Source AI model")
    redis_conn = connect_to_redis()
    keywords_chain = encode_keywords_chain()

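    # build the presenting chain once per session so its conversation memory persists across reruns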
    if "window_refreshed" not in st.session_state:
        st.session_state.window_refreshed = True
        st.session_state.present_products_chain = present_products_chain()

    embedding_model = instance_embedding_model()

    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hey, I'm your online shopping buddy! How can I help you today?"}
        ]
    for msg in st.session_state["messages"]:
        st.chat_message(msg["role"]).write(msg["content"])

    prompt = st.chat_input(key="user_input")

    if prompt:
        st.session_state["messages"].append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        # distill the user's request into search keywords with the Falcon chain
        keywords = keywords_chain.run(prompt)
        # vectorize the query
        query_vector = embedding_model.encode(keywords)
        query_vector_bytes = np.array(query_vector).astype(np.float32).tobytes()
        # prepare the query
        q = (
            Query(f"*=>[KNN {TOPK} @{ITEM_KEYWORD_EMBEDDING} $vec_param AS vector_score]")
            .sort_by("vector_score")
            .paging(0, TOPK)
            .return_fields("vector_score", "item_name", "item_keywords")
            .dialect(2)
        )
        params_dict = {"vec_param": query_vector_bytes}
        # execute the query (ft() with no argument uses redis-py's default index name, "idx")
        results = redis_conn.ft().search(q, query_params=params_dict)
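        # format the retrieved products for the presenting chain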
        result_output = ""
        for product in results.docs:
            result_output += f"product_name:{product.item_name}, product_description:{product.item_keywords} \n"
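        # have the ChatOpenAI chain phrase the results, given the chat history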
        result = st.session_state.present_products_chain.predict(user_msg=f"{result_output}\n{prompt}")
        st.session_state["messages"].append({"role": "assistant", "content": result})
        st.chat_message("assistant").write(result)


if __name__ == "__main__":
    main()