# DataChat / app.py
import os
import re
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from transformers import pipeline
import gradio as gr
HUGGING_FACE_TOKEN = os.environ["HUGGING_FACE_TOKEN"]
# You can choose from the many models available on Hugging Face (https://huggingface.co/models)
model_name = "llmware/industry-bert-insurance-v0.1"
def remove_special_characters(string):
    return re.sub(r"\n", " ", string)
def RAG_Langchain(query):
    embeddings = SentenceTransformerEmbeddings(model_name=model_name)
    repo_id = "llmware/bling-sheared-llama-1.3b-0.1"
    # Load the source .txt documents (TextLoader is assumed here so the loader matches the *.txt glob)
    loader = DirectoryLoader('', glob="**/*.txt", show_progress=True, loader_cls=TextLoader)
    documents = loader.load()
    # Chunk size is an important parameter for the quality of the retrieved information;
    # several methods exist for choosing its value.
    # The overlap is the number of characters shared between one chunk and the next.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_documents(documents)
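    # Hypothetical illustration of the settings above (numbers not taken from this project's
    # data): with chunk_size=1000 and chunk_overlap=100, a ~1900-character document yields
    # two chunks whose boundaries share about 100 characters, so text cut at a chunk edge
    # still appears with some surrounding context in the neighbouring chunk.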
    # Data preparation: strip newlines from every chunk
    for chunk in texts:
        chunk.page_content = remove_special_characters(chunk.page_content)
    # Load all the documents into the vector database so they can be used afterwards
    vector_stores = Chroma.from_documents(texts, embeddings,
                                          collection_metadata={"hnsw:space": "cosine"},
                                          persist_directory="stores/insurance_cosine")
    # Retrieval
    load_vector_store = Chroma(persist_directory="stores/insurance_cosine", embedding_function=embeddings)
    # For now k=1; how to select among several context results is left for later
    docs = load_vector_store.similarity_search_with_score(query=query, k=1)
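    # Possible refinement (a sketch only, not used here): retrieve several candidates and
    # keep the closest ones, e.g.
    #   candidates = load_vector_store.similarity_search_with_score(query=query, k=4)
    #   docs = [(d, s) for d, s in candidates if s < 0.5]
    # where 0.5 is an arbitrary cut-off; with cosine distance, lower scores mean closer matches.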
    results = {"Score": [], "Content": [], "Metadata": []}
    for doc, score in docs:
        results['Score'].append(score)
        results['Content'].append(doc.page_content)
        results['Metadata'].append(doc.metadata)
    return results
def generateResponseBasedOnContext(model_name, context_string, query):
    question_answerer = pipeline("question-answering", model=model_name)
    context_prompt = "You are a sports expert. Answer the user's question using the following context: "
    context = context_prompt + context_string
    print("context : ", context)
    result = question_answerer(question=query, context=context)
    return result['answer']
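# Note: the "question-answering" pipeline is extractive, so it returns a span copied from
# the context string rather than newly generated text; the instruction prefix above mainly
# lengthens the text being searched.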
def gradio_adapted_RAG(model_name, query):
    context = str(RAG_Langchain(query)['Content'])
    generated_answer = generateResponseBasedOnContext(str(model_name), context, query)
    return generated_answer
dropdown = gr.Dropdown(choices=["distilbert-base-uncased-distilled-squad",
                                "impira/layoutlm-document-qa",
                                "impira/layoutlm-invoices"], label="Choose a model")
iface = gr.Interface(fn=gradio_adapted_RAG, inputs=[dropdown, "text"], outputs="text")
iface.launch()
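# Running the script serves the Gradio interface locally (by default at http://127.0.0.1:7860);
# on a hosted Space the platform provides the public URL.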