elia-waefler commited on
Commit
3bd2065
1 Parent(s): 2c3613b

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +68 -13
  2. app.py +238 -0
  3. html_templates.py +44 -0
  4. requirements.txt +10 -0
README.md CHANGED
@@ -1,13 +1,68 @@
1
- ---
2
- title: Classify ASH
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: streamlit
7
- sdk_version: 1.33.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Project Name
2
+ This repository contains the trained model and scripts for [brief description of what the model does or its purpose]. This model is designed to [describe applications or what the model can be used for, like generating text, classifying images, etc.].
3
+
4
+ Model Description
5
+ [Provide a detailed description of the model, including the architecture, training data, and any significant features that highlight its uniqueness or capabilities.]
6
+
7
+ Features
8
+ Feature 1: [Description]
9
+ Feature 2: [Description]
10
+ Feature 3: [Description]
11
+ Installation
12
+ To use this model, you first need to install the required packages. Run the following command:
13
+
14
+ bash
15
+ Copy code
16
+ pip install -r requirements.txt
17
+ Usage
18
+ Here's how to use the model in your project:
19
+
20
+ python
21
+ Copy code
22
+ from transformers import AutoModel, AutoTokenizer
23
+
24
+ model_name = "your-huggingface-model-identifier"
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+ model = AutoModel.from_pretrained(model_name)
27
+
28
+ def predict(text):
29
+ inputs = tokenizer(text, return_tensors="pt")
30
+ with torch.no_grad():
31
+ logits = model(**inputs).logits
32
+ predicted_class_id = logits.argmax().item()
33
+ return model.config.id2label[predicted_class_id]
34
+
35
+ # Example usage
36
+ text = "Your example text here"
37
+ print(predict(text))
38
+ Performance
39
+ Discuss the performance metrics, benchmarks, or comparisons here, showing how the model performs in various scenarios.
40
+
41
+ Contributing
42
+ We welcome contributions to improve the model or scripts. Please follow these steps to contribute:
43
+
44
+ Fork the repository.
45
+ Create a new branch (git checkout -b feature-branch).
46
+ Commit your changes (git commit -am 'Add some feature').
47
+ Push to the branch (git push origin feature-branch).
48
+ Open a new Pull Request.
49
+ License
50
+ This project is licensed under the [choose a license] - see the LICENSE file for details.
51
+
52
+ Citation
53
+ If you use this model in your research, please cite it as follows:
54
+
55
+ bibtex
56
+ Copy code
57
+ @inproceedings{author2023model,
58
+ title={Title of Your Model},
59
+ author={Author Names},
60
+ booktitle={Where it was published},
61
+ year={2023}
62
+ }
63
+ Acknowledgments
64
+ Mention any advisors, financial supporters, or data providers.
65
+ Any other recognition or credits.
66
+ Contact
67
+ For issues, questions, or collaborations, you can contact [email contact] or create an issue in this repository.
68
+
app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ stores vectores in session state, or locally.
3
+ loading from local does not yet work.
4
+ load funciton must recieve the uploaded file from fileuploader.
5
+ """
6
+
7
+ import time
8
+ from datetime import datetime
9
+ import openai
10
+ import tiktoken
11
+ import streamlit as st
12
+ from PyPDF2 import PdfReader
13
+ from langchain.text_splitter import CharacterTextSplitter
14
+ from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings
15
+ from langchain.vectorstores import FAISS
16
+ from langchain.chat_models import ChatOpenAI
17
+ from langchain.memory import ConversationBufferMemory
18
+ from langchain.chains import ConversationalRetrievalChain
19
+ from html_templates import css, bot_template, user_template
20
+ from langchain.llms import HuggingFaceHub
21
+ import os
22
+ import numpy as np
23
+
24
+
25
+ def merge_faiss_indices(index1, index2):
26
+ """
27
+ Merge two FAISS indices into a new index, assuming both are of the same type and dimensionality.
28
+
29
+ Args:
30
+ index1 (faiss.Index): The first FAISS index.
31
+ index2 (faiss.Index): The second FAISS index.
32
+
33
+ Returns:
34
+ faiss.Index: A new FAISS index containing all vectors from index1 and index2.
35
+ """
36
+
37
+ # Check if both indices are the same type
38
+ if type(index1) != type(index2):
39
+ raise ValueError("Indices are of different types")
40
+
41
+ # Check dimensionality
42
+ if index1.d != index2.d:
43
+ raise ValueError("Indices have different dimensionality")
44
+
45
+ # Determine type of indices
46
+ if isinstance(index1, FAISS.IndexFlatL2):
47
+ # Handle simple flat indices
48
+ d = index1.d
49
+ # Extract vectors from both indices
50
+ xb1 = FAISS.rev_swig_ptr(index1.xb.data(), index1.ntotal * d)
51
+ xb2 = FAISS.rev_swig_ptr(index2.xb.data(), index2.ntotal * d)
52
+
53
+ # Combine vectors
54
+ xb_combined = np.vstack((xb1, xb2))
55
+
56
+ # Create a new index and add combined vectors
57
+ new_index = FAISS.IndexFlatL2(d)
58
+ new_index.add(xb_combined)
59
+ return new_index
60
+
61
+ elif isinstance(index1, FAISS.IndexIVFFlat):
62
+ # Handle quantized indices (IndexIVFFlat)
63
+ d = index1.d
64
+ nlist = index1.nlist
65
+ quantizer = FAISS.IndexFlatL2(d) # Re-create the appropriate quantizer
66
+
67
+ # Create a new index with the same configuration
68
+ new_index = FAISS.IndexIVFFlat(quantizer, d, nlist, FAISS.METRIC_L2)
69
+
70
+ # If the indices are already trained, you can directly add the vectors
71
+ # Otherwise, you may need to train new_index using a representative subset of vectors
72
+ vecs1 = FAISS.rev_swig_ptr(index1.xb.data(), index1.ntotal * d)
73
+ vecs2 = FAISS.rev_swig_ptr(index2.xb.data(), index2.ntotal * d)
74
+ new_index.add(vecs1)
75
+ new_index.add(vecs2)
76
+ return new_index
77
+
78
+ else:
79
+ raise TypeError("Index type not supported for merging in this function")
80
+
81
+
82
+ def get_pdf_text(pdf_docs):
83
+ text = ""
84
+ for pdf in pdf_docs:
85
+ pdf_reader = PdfReader(pdf)
86
+ for page in pdf_reader.pages:
87
+ text += page.extract_text()
88
+ return text
89
+
90
+
91
+ def get_text_chunks(text):
92
+ text_splitter = CharacterTextSplitter(
93
+ separator="\n",
94
+ chunk_size=1000,
95
+ chunk_overlap=200,
96
+ length_function=len
97
+ )
98
+ chunks = text_splitter.split_text(text)
99
+ return chunks
100
+
101
+
102
+ def get_faiss_vectorstore(text_chunks):
103
+ if st.session_state.openai:
104
+ my_embeddings = OpenAIEmbeddings()
105
+ else:
106
+ my_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
107
+ vectorstore = FAISS.from_texts(texts=text_chunks, embedding=my_embeddings)
108
+ return vectorstore
109
+
110
+
111
+ def get_conversation_chain(vectorstore):
112
+ if st.session_state.openai:
113
+ llm = ChatOpenAI()
114
+ else:
115
+ llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.5, "max_length": 512})
116
+
117
+ memory = ConversationBufferMemory(
118
+ memory_key='chat_history', return_messages=True)
119
+ conversation_chain = ConversationalRetrievalChain.from_llm(
120
+ llm=llm,
121
+ retriever=vectorstore.as_retriever(),
122
+ memory=memory
123
+ )
124
+ return conversation_chain
125
+
126
+
127
+ def handle_userinput(user_question):
128
+ response = st.session_state.conversation({'question': user_question})
129
+ st.session_state.chat_history = response['chat_history']
130
+
131
+ for i, message in enumerate(st.session_state.chat_history):
132
+ # Display user message
133
+ if i % 2 == 0:
134
+ st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
135
+ else:
136
+ print(message)
137
+ # Display AI response
138
+ st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
139
+ # Display source document information if available in the message
140
+ if hasattr(message, 'source') and message.source:
141
+ st.write(f"Source Document: {message.source}", unsafe_allow_html=True)
142
+
143
+ def set_global_variables():
144
+ global BASE_URL
145
+ BASE_URL = "https://api.vectara.io/v1"
146
+ global OPENAI_API_KEY
147
+ OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
148
+ global OPENAI_ORG_ID
149
+ OPENAI_ORG_ID = os.environ["OPENAI_ORG_ID"]
150
+ global PINECONE_API_KEY
151
+ PINECONE_API_KEY = os.environ["PINECONE_API_KEY_LCBIM"]
152
+ global HUGGINGFACEHUB_API_TOKEN
153
+ HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
154
+ global VECTARA_API_KEY
155
+ VECTARA_API_KEY = os.environ["VECTARA_API_KEY"]
156
+ global VECTARA_CUSTOMER_ID
157
+ VECTARA_CUSTOMER_ID = os.environ["VECTARA_CUSTOMER_ID"]
158
+ global headers
159
+ headers = {"Authorization": f"Bearer {VECTARA_API_KEY}", "Content-Type": "application/json"}
160
+
161
+
162
+ def main():
163
+ set_global_variables()
164
+ st.set_page_config(page_title="Anna Seiler Haus KI-Assistent", page_icon=":hospital:")
165
+ st.write(css, unsafe_allow_html=True)
166
+ if "conversation" not in st.session_state:
167
+ st.session_state.conversation = None
168
+ if "chat_history" not in st.session_state:
169
+ st.session_state.chat_history = None
170
+ if "page" not in st.session_state:
171
+ st.session_state.page = "home"
172
+ if "openai" not in st.session_state:
173
+ st.session_state.openai = True
174
+ if "login" not in st.session_state:
175
+ st.session_state.login = False
176
+
177
+ st.header("Anna Seiler Haus KI-Assistent ASH :hospital:")
178
+ if st.text_input("ASK_ASH_PASSWORD: ", type="password") == ASK_ASH_PASSWORD:
179
+ if True:
180
+ OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
181
+ # ASK_ASH_PASSWORD = False
182
+ OPENAI_API_KEY = False
183
+ OPENAI_ORG_ID = os.environ["OPENAI_ORG_ID"]
184
+ PINECONE_API_KEY = os.environ["PINECONE_API_KEY_LCBIM"]
185
+ HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
186
+ VECTARA_CORPUS_ID = "3"
187
+ VECTARA_API_KEY = os.environ["VECTARA_API_KEY"]
188
+ VECTARA_CUSTOMER_ID = os.environ["VECTARA_CUSTOMER_ID"]
189
+
190
+ user_question = st.text_input("Ask a question about your documents:")
191
+
192
+ st.session_state.openai = st.toggle(label="use openai?")
193
+ # if st.session_state.openai:
194
+ # st.session_state.openai_key = st.text_input("openai api key", type="password")
195
+ # OPENAI_API_KEY = st.session_state.openai_key
196
+
197
+ if user_question:
198
+ handle_userinput(user_question)
199
+
200
+ with st.sidebar:
201
+ st.subheader("Your documents")
202
+ pdf_docs = st.file_uploader("Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
203
+ if st.button("Process"):
204
+ with st.spinner("Processing"):
205
+ raw_text = get_pdf_text(pdf_docs)
206
+ text_chunks = get_text_chunks(raw_text)
207
+ vec = get_faiss_vectorstore(text_chunks)
208
+ st.session_state.vectorstore = vec
209
+ st.session_state.conversation = get_conversation_chain(vec)
210
+
211
+ # Save and Load Embeddings
212
+ if st.button("Save Embeddings"):
213
+ if "vectorstore" in st.session_state:
214
+ st.session_state.vectorstore.save_local(str(datetime.now().strftime("%Y%m%d%H%M%S")) + "faiss_index")
215
+ st.sidebar.success("saved")
216
+ else:
217
+ st.sidebar.warning("No embeddings to save. Please process documents first.")
218
+
219
+ if st.button("Load Embeddings"):
220
+ if "vectorstore" in st.session_state:
221
+ new_db = FAISS.load_local()
222
+ if new_db is not None: # Check if this is working
223
+ combined_db = merge_faiss_indices(new_db, st.session_state.vectorstore)
224
+ st.session_state.vectorstore = combined_db
225
+ st.session_state.conversation = get_conversation_chain(combined_db)
226
+ else:
227
+ st.sidebar.warning("Couldn't load embeddings")
228
+ else:
229
+ new_db = FAISS.load_local("faiss_index")
230
+ if new_db is not None: # Check if this is working
231
+ st.session_state.vectorstore = new_db
232
+ st.session_state.conversation = get_conversation_chain(new_db)
233
+
234
+
235
+ if __name__ == '__main__':
236
+ set_global_variables()
237
+ ASK_ASH_PASSWORD = os.environ["ASK_ASH_PASSWORD"]
238
+ main()
html_templates.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ css = '''
2
+ <style>
3
+ .chat-message {
4
+ padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
5
+ }
6
+ .chat-message.user {
7
+ background-color: #2b313e
8
+ }
9
+ .chat-message.bot {
10
+ background-color: #475063
11
+ }
12
+ .chat-message .avatar {
13
+ width: 20%;
14
+ }
15
+ .chat-message .avatar img {
16
+ max-width: 78px;
17
+ max-height: 78px;
18
+ border-radius: 50%;
19
+ object-fit: cover;
20
+ }
21
+ .chat-message .message {
22
+ width: 80%;
23
+ padding: 0 1.5rem;
24
+ color: #fff;
25
+ }
26
+ '''
27
+
28
+ bot_template = '''
29
+ <div class="chat-message bot">
30
+ <div class="avatar">
31
+ <img src="https://www.insel.ch/_ari/115280/49841742b8afbc44928918244fb4c6f9b487d5b3/9f6e35f65cbd0d6c47c145f90b1d5a297eb50bcd/1400/0/og/20230704-Anna-Seiler-Haus-009-screen.jpg.webp" style="max-height: 78px; max-width: 78px; border-radius: 50%; object-fit: cover;">
32
+ </div>
33
+ <div class="message">{{MSG}}</div>
34
+ </div>
35
+ '''
36
+
37
+ user_template = '''
38
+ <div class="chat-message user">
39
+ <div class="avatar">
40
+ <img src="https://media.licdn.com/dms/image/C4D03AQHi5rJfheyUtQ/profile-displayphoto-shrink_800_800/0/1638174649461?e=2147483647&v=beta&t=KOsttcLGIwB9pBEVfceHj-ckv_zPHs-2COyrp7aYR-k">
41
+ </div>
42
+ <div class="message">{{MSG}}</div>
43
+ </div>
44
+ '''
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit~=1.33.0
2
+ bcrypt~=4.1.2
3
+ psycopg2-binary~=2.9.9
4
+ openai~=1.23.2
5
+ pypdf2~=3.0.1
6
+ langchain~=0.1.16
7
+ tiktoken~=0.6.0
8
+ numpy~=1.26.4
9
+ requests~=2.31.0
10
+ faiss-cpu