taratrankennedy commited on
Commit
6a7e3a3
β€’
1 Parent(s): 77e9475

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+ import streamlit as st
4
+ import os
5
+ from llamaapi import LlamaAPI
6
+ from langchain_experimental.llms import ChatLlamaAPI
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+ import pinecone
9
+ from langchain.vectorstores import Pinecone
10
+ from langchain.prompts import PromptTemplate
11
+ from langchain.chains import RetrievalQA
12
+ import streamlit.components.v1 as components
13
+ from langchain_groq import ChatGroq
14
+ from langchain.chains import ConversationalRetrievalChain
15
+ from langchain.memory import ChatMessageHistory, ConversationBufferMemory
16
+ import time
17
+
18
+ HUGGINGFACEHUB_API_TOKEN = st.secrets['HUGGINGFACEHUB_API_TOKEN']
19
+
20
+ @dataclass
21
+ class Message:
22
+ """Class for keeping track of a chat message."""
23
+ origin: Literal["πŸ‘€ Human", "πŸ‘¨πŸ»β€βš–οΈ Ai"]
24
+ message: str
25
+
26
+
27
+ def download_hugging_face_embeddings():
28
+ embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
29
+ return embeddings
30
+
31
+
32
+ def initialize_session_state():
33
+ if "history" not in st.session_state:
34
+ st.session_state.history = []
35
+ if "conversation" not in st.session_state:
36
+ chat = ChatGroq(temperature=0.5, groq_api_key=st.secrets["Groq_api"], model_name="mixtral-8x7b-32768")
37
+
38
+ embeddings = download_hugging_face_embeddings()
39
+
40
+ # Initializing Pinecone
41
+ pinecone.init(
42
+ api_key=st.secrets["PINECONE_API_KEY"], # find at app.pinecone.io
43
+ environment=st.secrets["PINECONE_API_ENV"] # next to api key in console
44
+ )
45
+ index_name = "book-recommendations" # updated index name for books
46
+
47
+ docsearch = Pinecone.from_existing_index(index_name, embeddings)
48
+
49
+ prompt_template = """
50
+ You are an AI trained to recommend books. You will suggest books based on the user's preferences and previous likes.
51
+ Please provide insightful recommendations and explain why each book might be of interest to the user.
52
+ Context: {context}
53
+ User Preference: {question}
54
+ Suggested Books:
55
+ """
56
+
57
+ PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
58
+
59
+ message_history = ChatMessageHistory()
60
+ memory = ConversationBufferMemory(
61
+ memory_key="chat_history",
62
+ output_key="answer",
63
+ chat_memory=message_history,
64
+ return_messages=True,
65
+ )
66
+ retrieval_chain = ConversationalRetrievalChain.from_llm(llm=chat,
67
+ chain_type="recommendation",
68
+ retriever=docsearch.as_retriever(
69
+ search_kwargs={'k': 5}),
70
+ return_source_documents=True,
71
+ combine_docs_chain_kwargs={"prompt": PROMPT},
72
+ memory=memory
73
+ )
74
+
75
+ st.session_state.conversation = retrieval_chain
76
+
77
+
78
+ def on_click_callback():
79
+ human_prompt = st.session_state.human_prompt
80
+ st.session_state.human_prompt=""
81
+ response = st.session_state.conversation(
82
+ human_prompt
83
+ )
84
+ llm_response = response['answer']
85
+ st.session_state.history.append(
86
+ Message("πŸ‘€ Human", human_prompt)
87
+ )
88
+ st.session_state.history.append(
89
+ Message("πŸ‘¨πŸ»β€βš–οΈ Ai", llm_response)
90
+ )
91
+
92
+ initialize_session_state()
93
+
94
+ st.title("AI Book Recommender")
95
+
96
+ st.markdown(
97
+ """
98
+ πŸ‘‹ **Welcome to the AI Book Recommender!**
99
+ Share your favorite genres or books, and I'll recommend your next reads!
100
+ """
101
+ )
102
+
103
+ chat_placeholder = st.container()
104
+ prompt_placeholder = st.form("chat-form")
105
+
106
+ with chat_placeholder:
107
+ for chat in st.session_state.history:
108
+ st.markdown(f"{chat.origin} : {chat.message}")
109
+
110
+ with prompt_placeholder:
111
+ st.markdown("**Chat**")
112
+ cols = st.columns((6, 1))
113
+ cols[0].text_input(
114
+ "Chat",
115
+ label_visibility="collapsed",
116
+ key="human_prompt",
117
+ )
118
+ cols[1].form_submit_button(
119
+ "Submit",
120
+ type="primary",
121
+ on_click=on_click_callback,
122
+ )