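# Gradio RAG chatbot over PDF documents (Italian UI).
# Pipeline: PDF -> text chunks -> Chroma vector store -> ConversationalRetrievalChain
# backed by an LLM served through the Hugging Face Inference API.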
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from unidecode import unidecode
import re
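# Instruction-tuned LLMs available through the Hugging Face Inference API;
# the UI shows only the basename (e.g. "Mistral-7B-Instruct-v0.2").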
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "google/gemma-7b-it", "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "tiiuae/falcon-7b-instruct",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
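# Load each PDF page-by-page and split the pages into overlapping chunks
# sized for retrieval.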
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(pages)
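# Build an ephemeral (in-memory) Chroma vector store from the chunks, using
# either a multilingual MiniLM or an Italian BERT embedding model.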
def create_db(splits, collection_name, db_type):
    if db_type == 0:  # Multilingual MiniLM
        embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    else:  # Italian BERT
        embedding = HuggingFaceEmbeddings(model_name="dbmdz/bert-base-italian-xxl-uncased")
    new_client = chromadb.EphemeralClient()
    return Chroma.from_documents(documents=splits, embedding=embedding, client=new_client, collection_name=collection_name)
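# Assemble the conversational retrieval chain. HuggingFaceEndpoint calls the
# hosted Inference API, so a token is expected in the HUGGINGFACEHUB_API_TOKEN
# environment variable (or passed via `huggingfacehub_api_token`).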
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
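# Derive a Chroma-safe collection name from the file name: ASCII-only,
# 3-63 characters, starting and ending with an alphanumeric character.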
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = unidecode(collection_name.replace(" ", "-"))
    collection_name = re.sub("[^A-Za-z0-9]+", "-", collection_name)[:50]
    if len(collection_name) < 3:
        collection_name += "xyz"
    if not collection_name[0].isalnum():
        collection_name = "A" + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + "Z"
    return collection_name
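# Callback for "Genera database vettoriale": load, split, and index the
# uploaded PDFs, naming the collection after the first file.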
def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = create_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name, db_type)
    return vector_db, collection_name, "Completed!"
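# Callback for "Inizializza catena di Domanda e Risposta".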
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Completed!"
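# Flatten Gradio's (user, bot) history tuples into plain strings for the chain.
# `message` is accepted to match the call site but is not used here.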
def format_chat_history(message, chat_history):
    return [f"User: {user_message}\nAssistant: {bot_message}" for user_message, bot_message in chat_history]
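# Main chat callback: query the chain, keep only the text after the last
# "Helpful Answer:" marker emitted by the default QA prompt, and expose up
# to 5 retrieved passages with their (1-based) page numbers.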
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"].split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    sources = [(source.page_content.strip(), source.metadata["page"] + 1) for source in response_sources[:5]]
    new_history = history + [(message, response_answer)]
    # Ensure we always return 5 sources and 5 pages
    source_texts = [source[0] for source in sources] + [""] * (5 - len(sources))
    source_pages = [source[1] for source in sources] + [0] * (5 - len(sources))
    return (qa_chain, gr.update(value=""), new_history,
            *source_texts[:5],  # Unpack exactly 5 source texts
            *source_pages[:5])  # Unpack exactly 5 source pages
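# Reset the message box, the chat history, the 5 source textboxes, and the
# 5 page-number fields (12 output values in total).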
def clear_conversation():
    return gr.update(value=""), [], "", "", "", "", "", 0, 0, 0, 0, 0
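# Assemble the 4-step Gradio UI: upload PDFs, build the vector DB, initialize
# the QA chain, then chat. Note that `language_btn` is defined but not yet
# wired into any callback.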
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        gr.Markdown("# Creatore di Chatbot basato su PDF")
        with gr.Tab("Passo 1 - Carica PDF"):
            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carica i tuoi documenti PDF")
        with gr.Tab("Passo 2 - Elabora Documenti"):
            db_type = gr.Radio(["ChromaDB (Multilingual MiniLM Embedding)", "ChromaDB (Italian BERT Embedding)"], label="Tipo di database vettoriale", value="ChromaDB (Multilingual MiniLM Embedding)", type="index")
            with gr.Accordion("Opzioni Avanzate - Divisione del testo del documento", open=False):
                slider_chunk_size = gr.Slider(100, 1000, 1000, step=20, label="Dimensione del chunk")
                slider_chunk_overlap = gr.Slider(10, 200, 100, step=10, label="Sovrapposizione del chunk")
            db_progress = gr.Textbox(label="Inizializzazione del database vettoriale", value="Nessuna")
            db_btn = gr.Button("Genera database vettoriale")
        with gr.Tab("Passo 3 - Inizializza catena QA"):
            llm_btn = gr.Radio(list_llm_simple, label="Modelli LLM", value=list_llm_simple[4], type="index")
            with gr.Accordion("Opzioni avanzate - Modello LLM", open=False):
                slider_temperature = gr.Slider(0.01, 1.0, 0.3, step=0.1, label="Temperatura")
                slider_maxtokens = gr.Slider(224, 4096, 1024, step=32, label="Token massimi")
                slider_topk = gr.Slider(1, 10, 3, step=1, label="Campioni top-k")
            language_btn = gr.Radio(["Italiano", "Inglese"], label="Lingua", value="Italiano", type="index")
            llm_progress = gr.Textbox(value="Nessuna", label="Inizializzazione catena QA")
            qachain_btn = gr.Button("Inizializza catena di Domanda e Risposta")
        with gr.Tab("Passo 4 - Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Opzioni avanzate - Riferimenti ai documenti", open=False):
                doc_sources = [gr.Textbox(label=f"Riferimento {i+1}", lines=2, container=True, scale=20) for i in range(5)]
                source_pages = [gr.Number(label="Pagina", scale=1) for _ in range(5)]
            msg = gr.Textbox(placeholder="Inserisci il messaggio (es. 'Di cosa tratta questo documento?')", container=True)
            submit_btn = gr.Button("Invia messaggio")
            clear_btn = gr.Button("Cancella conversazione")
        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type], outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + doc_sources + source_pages)
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + doc_sources + source_pages)
        # clear_conversation returns 12 values (cleared message box, emptied
        # chat history, 5 source texts, 5 page numbers), so `msg` must be
        # included in the outputs for the arity to match.
        clear_btn.click(clear_conversation, inputs=[], outputs=[msg, chatbot] + doc_sources + source_pages)
    demo.queue().launch(debug=True)
if __name__ == "__main__":
    demo()