File size: 4,418 Bytes
6a7e3a3
 
 
 
e05602f
 
6a7e3a3
 
e05602f
6a7e3a3
 
 
e05602f
6a7e3a3
24e4881
6a7e3a3
24e4881
 
 
 
 
6a7e3a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import time
from dataclasses import dataclass
from typing import Literal

import pinecone  # module-level client needed for pinecone.init(...) in initialize_session_state
import streamlit as st
import streamlit.components.v1 as components
from langchain.chains import ConversationalRetrievalChain
from langchain.chains import RetrievalQA
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.output_parsers.rail_parser import GuardrailsOutputParser
from langchain_community.vectorstores import Pinecone
from langchain_groq import ChatGroq
from transformers import AutoModel

# Authenticate against the Hugging Face Hub using the token from the
# environment, and resolve the current user identity.
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
from huggingface_hub import HfApi
api = HfApi()
# HfApi.set_access_token() was removed from huggingface_hub; pass the token
# directly to the API call instead.
user = api.whoami(token=HUGGINGFACEHUB_API_TOKEN)

@dataclass
class Message:
    """One entry in the chat transcript: who spoke, and what was said."""
    # Which participant produced the message; the emoji labels double as
    # the display prefix when the history is rendered.
    origin: Literal["πŸ‘€ Human", "πŸ‘¨πŸ»β€βš–οΈ Ai"]
    # Raw message text as typed by the user or returned by the LLM.
    message: str


def download_hugging_face_embeddings(model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
    """Load a sentence-transformers embedding model for vector-store queries.

    Args:
        model_name: HuggingFace checkpoint to load. Defaults to the
            lightweight MiniLM model the Pinecone index was built with, so
            existing callers are unaffected.

    Returns:
        A ``HuggingFaceEmbeddings`` instance.
    """
    return HuggingFaceEmbeddings(model_name=model_name)


def initialize_session_state():
    """Populate Streamlit session state on first run.

    Creates two keys if missing:
      * ``history`` — list of ``Message`` objects rendered in the chat pane.
      * ``conversation`` — a ``ConversationalRetrievalChain`` wired to a
        Groq-hosted LLM and a Pinecone index, built once and reused across
        Streamlit reruns.
    """
    if "history" not in st.session_state:
        st.session_state.history = []
    if "conversation" not in st.session_state:
        # Groq-hosted Mixtral chat model; moderate temperature for varied
        # but grounded recommendations.
        chat = ChatGroq(
            temperature=0.5,
            groq_api_key=st.secrets["Groq_api"],
            model_name="mixtral-8x7b-32768",
        )

        embeddings = download_hugging_face_embeddings()

        # Initializing Pinecone (requires the module-level `pinecone` import).
        pinecone.init(
            api_key=st.secrets["PINECONE_API_KEY"],  # find at app.pinecone.io
            environment=st.secrets["PINECONE_API_ENV"],  # next to api key in console
        )
        index_name = "book-recommendations"  # updated index name for books

        docsearch = Pinecone.from_existing_index(index_name, embeddings)

        prompt_template = """
            You are an AI trained to recommend books. You will suggest books based on the user's preferences and previous likes. 
            Please provide insightful recommendations and explain why each book might be of interest to the user.
            Context: {context}
            User Preference: {question}
            Suggested Books:
            """

        PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        message_history = ChatMessageHistory()
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            chat_memory=message_history,
            return_messages=True,
        )
        # NOTE: "recommendation" is not a valid chain_type — LangChain only
        # accepts "stuff" / "map_reduce" / "refine" / "map_rerank" here.
        # "stuff" is the one that honors combine_docs_chain_kwargs={"prompt": ...}.
        retrieval_chain = ConversationalRetrievalChain.from_llm(
            llm=chat,
            chain_type="stuff",
            retriever=docsearch.as_retriever(search_kwargs={'k': 5}),
            return_source_documents=True,
            combine_docs_chain_kwargs={"prompt": PROMPT},
            memory=memory,
        )

        st.session_state.conversation = retrieval_chain


def on_click_callback():
    """Form-submit handler: run the user's prompt through the chain.

    Reads the text_input bound to the ``human_prompt`` key, clears it so the
    box is empty on the next rerun, and appends both the human turn and the
    model's answer to ``st.session_state.history``.
    """
    human_prompt = st.session_state.human_prompt
    st.session_state.human_prompt = ""
    # Skip blank / whitespace-only submissions instead of burning an LLM call
    # and appending empty turns to the transcript.
    if not human_prompt or not human_prompt.strip():
        return
    response = st.session_state.conversation(
        human_prompt
    )
    llm_response = response['answer']
    st.session_state.history.append(
        Message("πŸ‘€ Human", human_prompt)
    )
    st.session_state.history.append(
        Message("πŸ‘¨πŸ»β€βš–οΈ Ai", llm_response)
    )

# Build session state (history + retrieval chain) before rendering the UI.
initialize_session_state()

st.title("AI Book Recommender")

st.markdown(
   """ 
   πŸ‘‹ **Welcome to the AI Book Recommender!**
   Share your favorite genres or books, and I'll recommend your next reads!
   """
)

# Layout: transcript container above, input form below.
chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")

# Replay the full transcript on every rerun — Streamlit re-executes this
# script top-to-bottom after each interaction.
with chat_placeholder:
    for chat in st.session_state.history:
        st.markdown(f"{chat.origin} : {chat.message}")

with prompt_placeholder:
    st.markdown("**Chat**")
    # Wide text input next to a narrow submit button.
    cols = st.columns((6, 1))
    cols[0].text_input(
        "Chat",
        label_visibility="collapsed",
        key="human_prompt",  # read (and cleared) by on_click_callback
    )
    cols[1].form_submit_button(
        "Submit",
        type="primary",
        on_click=on_click_callback,  # runs before the rerun that redraws history
    )