Chris4K committed on
Commit
fe6b125
1 Parent(s): 6bfe6c9

Update app.py

Files changed (1)
  1. app.py +205 -0
app.py CHANGED
@@ -0,0 +1,205 @@
+ #####################################
+ ## BitsAndBytes
+ #####################################
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+ model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"
+
+ ###### other models:
+ # "Trelis/Llama-2-7b-chat-hf-sharded-bf16"
+ # "bn22/Mistral-7B-Instruct-v0.1-sharded"
+ # "HuggingFaceH4/zephyr-7b-beta"
+
+ # function for loading a 4-bit quantized model
+ def load_quantized_model(model_name: str):
+     """
+     :param model_name: Name or path of the model to be loaded.
+     :return: Loaded quantized model.
+     """
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+
+     # load_in_4bit is already set via bnb_config, so it is not passed again here
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.bfloat16,
+         quantization_config=bnb_config
+     )
+     return model
+
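+ # NOTE (sketch, not part of the original wiring): the LangChain helpers used further down
+ # (create_history_aware_retriever, create_stuff_documents_chain) expect a LangChain LLM,
+ # while load_quantized_model returns a raw transformers model. One way to bridge that,
+ # assuming langchain_community and its HuggingFacePipeline wrapper are available, is to
+ # pair the model with its tokenizer in a text-generation pipeline. The helper name
+ # load_quantized_llm is illustrative only.
+ #
+ # from transformers import pipeline
+ # from langchain_community.llms import HuggingFacePipeline
+ #
+ # def load_quantized_llm(model_name: str):
+ #     model = load_quantized_model(model_name)
+ #     tokenizer = AutoTokenizer.from_pretrained(model_name)
+ #     gen_pipeline = pipeline(
+ #         "text-generation",
+ #         model=model,
+ #         tokenizer=tokenizer,
+ #         max_new_tokens=512,
+ #         return_full_text=False,
+ #     )
+ #     return HuggingFacePipeline(pipeline=gen_pipeline)
+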
+ ##################################################
+ ## vs chat
+ ##################################################
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
+
+ from langchain_core.messages import AIMessage, HumanMessage
+ from langchain_community.document_loaders import WebBaseLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import Chroma
+
+ #from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
+ from langchain.vectorstores.faiss import FAISS
+
+
+ from dotenv import load_dotenv
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+
+
+ load_dotenv()
+
+ def get_vectorstore_from_url(url):
+     # get the text in document form
+     loader = WebBaseLoader(url)
+     document = loader.load()
+
+     # split the document into chunks
+     text_splitter = RecursiveCharacterTextSplitter()
+     document_chunks = text_splitter.split_documents(document)
+     #######
+     '''
+     FAISS
+     A FAISS vector store containing the embeddings of the text chunks.
+     '''
+     model = "BAAI/bge-base-en-v1.5"
+     encode_kwargs = {
+         "normalize_embeddings": True
+     }  # set True to compute cosine similarity
+     embeddings = HuggingFaceBgeEmbeddings(
+         model_name=model, encode_kwargs=encode_kwargs, model_kwargs={"device": "cpu"}
+     )
+     # load from disk (kept commented out: the store is rebuilt from the fetched page below,
+     # which would immediately overwrite this one)
+     #vector_store = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
+
+     #vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
+     # create a vector store from the chunks and persist it to disk
+     vector_store = Chroma.from_documents(document_chunks, embeddings, persist_directory="./chroma_db")
+
+     print("-----")
+     print(vector_store.similarity_search("What is ALiBi?"))
+     print("-----")
+
+     #######
+     return vector_store
+
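+ # --- optional sketch: once ./chroma_db has been persisted by a previous run, the store
+ # --- could be reopened without re-scraping the page. Assumes the same embeddings setup
+ # --- as above; the helper name is illustrative only, not part of the original app.
+ # def load_persisted_vectorstore(embeddings):
+ #     return Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
+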
+
+
+ def get_context_retriever_chain(vector_store):
+
+     # specify the Hugging Face model name
+     model_name = "anakin87/zephyr-7b-alpha-sharded"
+     # model_name = "bn22/Mistral-7B-Instruct-v0.1-sharded"
+
+     ###### other models:
+     # "Trelis/Llama-2-7b-chat-hf-sharded-bf16"
+     # "bn22/Mistral-7B-Instruct-v0.1-sharded"
+     # "HuggingFaceH4/zephyr-7b-beta"
+
+     # load the 4-bit quantized model
+     llm = load_quantized_model(model_name)
+
+     retriever = vector_store.as_retriever()
+
+     prompt = ChatPromptTemplate.from_messages([
+         MessagesPlaceholder(variable_name="chat_history"),
+         ("user", "{input}"),
+         ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation")
+     ])
+
+     retriever_chain = create_history_aware_retriever(llm, retriever, prompt)
+
+     return retriever_chain
+
+ def get_conversational_rag_chain(retriever_chain):
+
+     # uses the module-level model_name defined at the top of the file
+     llm = load_quantized_model(model_name)
+
+     prompt = ChatPromptTemplate.from_messages([
+         ("system", "Answer the user's questions based on the below context:\n\n{context}"),
+         MessagesPlaceholder(variable_name="chat_history"),
+         ("user", "{input}"),
+     ])
+
+     stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
+
+     return create_retrieval_chain(retriever_chain, stuff_documents_chain)
+
+ def get_response(user_input):
+     # NOTE: this Streamlit-based variant is shadowed by the Gradio get_response defined below;
+     # it would also need `import streamlit as st` to run.
+     retriever_chain = get_context_retriever_chain(st.session_state.vector_store)
+     conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
+
+     response = conversation_rag_chain.invoke({
+         "chat_history": st.session_state.chat_history,
+         "input": user_input
+     })
+
+     return response['answer']
+
+
+
+ ###################
+
+ ###################
+ import gradio as gr
+
+ ##from langchain_core.runnables.base import ChatPromptValue
+ #from torch import tensor
+
+ # Create Gradio interface
+ #vector_store = None  # Set your vector store here
+ chat_history = []  # Set your chat history here
+
+ # Define your function here
+ def get_response(user_input):
+
+     # Define the prompt as a ChatPromptValue object
+     #user_input = ChatPromptValue(user_input)
+
+     # Convert the prompt to a tensor
+     #input_ids = user_input.tensor
+
+     #vs = get_vectorstore_from_url(user_url, all_domain)
+     vs = get_vectorstore_from_url("https://www.bofrost.de/shop/laenderkueche_5573/italienische-kueche_5576/linguine-mit-feinen-pilzen.html?position=1&clicked=")
+     print("------ here 22 ")
+     chat_history = []
+     retriever_chain = get_context_retriever_chain(vs)
+     conversation_rag_chain = get_conversational_rag_chain(retriever_chain)
+
+     response = conversation_rag_chain.invoke({
+         "chat_history": chat_history,
+         "input": user_input
+     })
+
+     return response['answer']
+
+ def simple(text: str):
+     return text + " hhhmmm "
+
+ app = gr.Interface(
+     fn=get_response,
+     #fn=simple,
+     inputs=["text"],
+     outputs="text",
+     title="Chat with Websites",
+     description="Type your message and chat with websites.",
+     #allow_flagging=False
+ )
+
+ app.launch(debug=True, share=True)  # e.g. ask: "How do I register with bofrost?" or "How much do the linguine cost?"
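
A rough dependency sketch inferred from the imports above (versions unpinned; bitsandbytes and accelerate are assumptions needed for the 4-bit loading path, faiss-cpu only if the commented FAISS route is used):

# requirements.txt (sketch, not part of this commit)
torch
transformers
bitsandbytes
accelerate
langchain
langchain-community
chromadb
sentence-transformers
gradio
python-dotenv
beautifulsoup4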