DHEIVER committed · Commit 2073925 · verified · 1 Parent(s): 7e34d60

Update app.py

Files changed (1): app.py (+77 -170)
app.py CHANGED
@@ -7,9 +7,7 @@ from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFacePipeline
-from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
-from langchain_community.llms import HuggingFaceEndpoint
 
 from pathlib import Path
 import chromadb
@@ -18,22 +16,25 @@ from unidecode import unidecode
 from transformers import AutoTokenizer, pipeline
 import transformers
 import torch
-import tqdm
-import accelerate
 import re
 
-# List of free models that do not require an API key
+# List of 100% open and free models
 list_llm = [
-    "mistralai/Mistral-7B-Instruct-v0.2",
-    "mistralai/Mistral-7B-Instruct-v0.1",
-    "google/flan-t5-xxl",
-    "HuggingFaceH4/zephyr-7b-beta",
-    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-    "microsoft/phi-2"
+    "google/flan-t5-xxl",                  # text-to-text model
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # lightweight chat model
+    "microsoft/phi-2",                     # reasoning-oriented model
+    "facebook/opt-1.3b",                   # text-generation model
+    "EleutherAI/gpt-neo-1.3B",             # open-source GPT-3-style model
+    "bigscience/bloom-1b7",                # multilingual model
+    "RWKV/rwkv-4-169m-pile",               # RAM-efficient model
+    "gpt2-medium",                         # classic GPT-2 model
+    "databricks/dolly-v2-3b",              # instruction-tuned model
+    "mosaicml/mpt-7b-instruct"             # instruction-tuned model
 ]
+
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-# Function to load the PDF documents and split them into chunks
+# Function to load PDF documents
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
@@ -43,29 +44,26 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
         chunk_size=chunk_size,
         chunk_overlap=chunk_overlap
     )
-    doc_splits = text_splitter.split_documents(pages)
-    return doc_splits
+    return text_splitter.split_documents(pages)
 
 # Function to create the vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
-    new_client = chromadb.EphemeralClient()
-    vectordb = Chroma.from_documents(
+    return Chroma.from_documents(
         documents=splits,
         embedding=embedding,
-        client=new_client,
-        collection_name=collection_name,
+        client=chromadb.EphemeralClient(),
+        collection_name=collection_name
     )
-    return vectordb
 
-# Function to initialize the LLM chain
+# Function to initialize the LLM model
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    progress(0.1, desc="Initializing HF tokenizer...")
+    progress(0.1, desc="Loading tokenizer...")
 
-    # Load the model tokenizer and pipeline
     tokenizer = AutoTokenizer.from_pretrained(llm_model)
-    progress(0.5, desc="Initializing HF pipeline...")
-    pipeline_model = transformers.pipeline(
+
+    progress(0.4, desc="Initializing pipeline...")
+    pipeline_obj = pipeline(
         "text-generation",
         model=llm_model,
         tokenizer=tokenizer,
@@ -74,166 +72,75 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
         max_new_tokens=max_tokens,
         do_sample=True,
         top_k=top_k,
-        temperature=temperature,
+        temperature=temperature
     )
-    llm = HuggingFacePipeline(pipeline=pipeline_model)
-
-    progress(0.75, desc="Defining buffer memory...")
+
+    llm = HuggingFacePipeline(pipeline=pipeline_obj)
+
+    progress(0.7, desc="Setting up memory...")
     memory = ConversationBufferMemory(
         memory_key="chat_history",
-        output_key='answer',
+        output_key="answer",  # still needed: the chain returns both answer and source_documents
         return_messages=True
     )
-    retriever = vector_db.as_retriever()
-
-    progress(0.8, desc="Defining retrieval chain...")
-    qa_chain = ConversationalRetrievalChain.from_llm(
-        llm,
-        retriever=retriever,
-        chain_type="stuff",
+
+    progress(0.8, desc="Creating chain...")
+    return ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        retriever=vector_db.as_retriever(),
         memory=memory,
-        return_source_documents=True,
-        verbose=False,
+        chain_type="stuff",
+        return_source_documents=True
     )
-    progress(0.9, desc="Done!")
-    return qa_chain
-
-# Function to generate the vector database collection name
-def create_collection_name(filepath):
-    collection_name = Path(filepath).stem
-    collection_name = collection_name.replace(" ", "-")
-    collection_name = unidecode(collection_name)
-    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-    collection_name = collection_name[:50]
-    if len(collection_name) < 3:
-        collection_name = collection_name + 'xyz'
-    if not collection_name[0].isalnum():
-        collection_name = 'A' + collection_name[1:]
-    if not collection_name[-1].isalnum():
-        collection_name = collection_name[:-1] + 'Z'
-    return collection_name
-
-# Function to initialize the database
-def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-    list_file_path = [x.name for x in list_file_obj if x is not None]
-    progress(0.1, desc="Creating collection name...")
-    collection_name = create_collection_name(list_file_path[0])
-    progress(0.25, desc="Loading document...")
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-    progress(0.5, desc="Generating vector database...")
-    vector_db = create_db(doc_splits, collection_name)
-    progress(0.9, desc="Done!")
-    return vector_db, collection_name, "Complete!"
-
-# Function to initialize the QA chain
-def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    llm_name = list_llm[llm_option]
-    print("llm_name: ", llm_name)
-    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
-    return qa_chain, "Complete!"
-
-# Function to format the chat history
-def format_chat_history(message, chat_history):
-    formatted_chat_history = []
-    for user_message, bot_message in chat_history:
-        formatted_chat_history.append(f"User: {user_message}")
-        formatted_chat_history.append(f"Assistant: {bot_message}")
-    return formatted_chat_history
 
-# Function to generate the conversation
-def conversation(qa_chain, message, history):
-    formatted_chat_history = format_chat_history(message, history)
-    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-    response_answer = response["answer"]
-    if response_answer.find("Helpful Answer:") != -1:
-        response_answer = response_answer.split("Helpful Answer:")[-1]
-    response_sources = response["source_documents"]
-    response_source1 = response_sources[0].page_content.strip()
-    response_source2 = response_sources[1].page_content.strip()
-    response_source3 = response_sources[2].page_content.strip()
-    response_source1_page = response_sources[0].metadata["page"] + 1
-    response_source2_page = response_sources[1].metadata["page"] + 1
-    response_source3_page = response_sources[2].metadata["page"] + 1
-    new_history = history + [(message, response_answer)]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
-# Main function to run the interface
+# Gradio interface
 def demo():
-    with gr.Blocks(theme="base") as demo:
+    with gr.Blocks(theme=gr.themes.Soft()) as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
-        collection_name = gr.State()
-
-        gr.Markdown(
-            """<center><h2>PDF-based chatbot</h2></center>
-            <h3>Ask any questions about your PDF documents</h3>""")
-        gr.Markdown(
-            """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
-            The user interface explicitly shows multiple steps to help understand the RAG workflow.
-            This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity.<br>
-            <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
-            """)
 
-        with gr.Tab("Step 1 - Upload PDF"):
-            with gr.Row():
-                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+        gr.Markdown("## 🤖 PDF Chatbot with Free Models")
 
-        with gr.Tab("Step 2 - Process document"):
-            with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
-            with gr.Accordion("Advanced options - Document text splitter", open=False):
-                with gr.Row():
-                    slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-                with gr.Row():
-                    slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-            with gr.Row():
-                db_progress = gr.Textbox(label="Vector database initialization", value="None")
-            with gr.Row():
-                db_btn = gr.Button("Generate vector database")
+        with gr.Tab("📤 Upload PDF"):
+            pdf_input = gr.Files(label="Select your PDFs", file_types=[".pdf"])
 
-        with gr.Tab("Step 3 - Initialize QA chain"):
-            with gr.Row():
-                llm_btn = gr.Radio(list_llm_simple, label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model")
-            with gr.Accordion("Advanced options - LLM model", open=False):
-                with gr.Row():
-                    slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-                with gr.Row():
-                    slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-                with gr.Row():
-                    slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-            with gr.Row():
-                llm_progress = gr.Textbox(value="None", label="QA chain initialization")
-            with gr.Row():
-                qachain_btn = gr.Button("Initialize Question Answering chain")
-
-        with gr.Tab("Step 4 - Chatbot"):
-            chatbot = gr.Chatbot(height=300)
-            with gr.Accordion("Advanced - Document references", open=False):
-                with gr.Row():
-                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                    source1_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                    source2_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                    source3_page = gr.Number(label="Page", scale=1)
-            with gr.Row():
-                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-            with gr.Row():
-                submit_btn = gr.Button("Submit message")
-                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+        with gr.Tab("⚙️ Processing"):
+            chunk_size = gr.Slider(100, 1000, value=500, label="Chunk size")
+            chunk_overlap = gr.Slider(0, 200, value=50, label="Overlap")
+            process_btn = gr.Button("Process PDFs")
 
-        # Preprocessing events
-        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
-        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-
-        # Chatbot events
-        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-        clear_btn.click(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-
-    demo.queue().launch(debug=True)
+        with gr.Tab("🧠 Model"):
+            model_selector = gr.Dropdown(list_llm_simple, label="Select model", value=list_llm_simple[0])
+            temperature = gr.Slider(0, 1, value=0.7, label="Creativity")
+            load_model_btn = gr.Button("Load model")
+
+        with gr.Tab("💬 Chat"):
+            chatbot = gr.Chatbot(height=400)
+            msg = gr.Textbox(label="Your message")
+            clear_btn = gr.ClearButton([msg, chatbot])
+
+        # Events
+        process_btn.click(
+            lambda files, cs, co: create_db(load_doc([f.name for f in files], cs, co), "docs"),
+            inputs=[pdf_input, chunk_size, chunk_overlap],
+            outputs=vector_db
+        )
+
+        load_model_btn.click(
+            # The vector database is threaded through inputs; reading
+            # vector_db.value here would only see the State's initial value
+            lambda model, temp, db: initialize_llmchain(list_llm[list_llm_simple.index(model)], temp, 512, 3, db),
+            inputs=[model_selector, temperature, vector_db],
+            outputs=qa_chain
+        )
+
+        def respond(message, chat_history, chain):
+            result = chain({"question": message, "chat_history": chat_history})
+            response = result["answer"]
+            sources = "\n".join(f"📄 Page {doc.metadata['page'] + 1}: {doc.page_content[:50]}..."
+                                for doc in result["source_documents"][:2])
+            # gr.Chatbot expects the full list of (user, bot) pairs, not a bare
+            # string; returning "" as the second output clears the textbox
+            return chat_history + [(message, f"{response}\n\n🔍 Sources:\n{sources}")], ""
+
+        msg.submit(respond, [msg, chatbot, qa_chain], [chatbot, msg])
+
+    demo.launch()
 
 if __name__ == "__main__":
     demo()
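Taken together, the updated helpers form one small RAG pipeline: load_doc splits the PDFs into chunks, create_db embeds them into an in-memory Chroma collection, and initialize_llmchain wraps a local transformers pipeline in a ConversationalRetrievalChain. A minimal sketch of how they compose outside the Gradio UI (the file name, model index, and generation settings below are illustrative, not part of the commit):

# Hypothetical driver script, assuming app.py's helpers are importable
# and that sample.pdf exists alongside it; all values are illustrative.
from app import load_doc, create_db, initialize_llmchain, list_llm

splits = load_doc(["sample.pdf"], chunk_size=500, chunk_overlap=50)   # PDFs -> text chunks
vector_db = create_db(splits, collection_name="docs")                 # chunks -> Chroma store
qa_chain = initialize_llmchain(list_llm[1], temperature=0.7,          # TinyLlama entry
                               max_tokens=512, top_k=3, vector_db=vector_db)

result = qa_chain({"question": "What is this document about?", "chat_history": []})
print(result["answer"])
for doc in result["source_documents"][:2]:
    print(f"Page {doc.metadata['page'] + 1}: {doc.page_content[:80]}...")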
 
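One design note on the new create_db: chromadb.EphemeralClient() keeps the collection purely in memory, so embeddings are rebuilt on every "Process PDFs" click and lost on restart. If persistence mattered, chromadb.PersistentClient(path="./chroma") would be the drop-in alternative (the path is illustrative).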