JoThanos committed
Commit
77b4cf5
1 Parent(s): 201901e

Initialize RAG

Files changed (1)
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
+ import warnings
+ 
+ import torch
+ import gradio as gr
+ 
+ from langchain import PromptTemplate, HuggingFacePipeline
+ from langchain.vectorstores import Chroma
+ from langchain.memory import ConversationBufferMemory
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.document_loaders import UnstructuredURLLoader
+ from langchain.chains import ConversationalRetrievalChain
+ 
+ from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
+ 
+ warnings.filterwarnings('ignore')
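+ # NOTE: these imports follow the legacy (pre-0.1) langchain package layout; on
+ # newer releases the vectorstore, embedding, and loader classes moved to
+ # langchain_community, so pin langchain accordingly.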
+ 
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+ EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+ 
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.float16,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True,
+ )
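+ # Rough memory budget: 7B params at 2 bytes each is about 14 GB in fp16; at
+ # 4 bits per weight that drops to roughly 3.5 GB, so the quantized model fits
+ # on a single consumer GPU.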
+ 
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
+ # Mistral ships without a pad token, so reuse EOS for padding.
+ tokenizer.pad_token = tokenizer.eos_token
+ 
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     torch_dtype=torch.float16,
+     trust_remote_code=True,
+     device_map="auto",
+     quantization_config=quantization_config,
+ )
+ 
+ generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
+ generation_config.max_new_tokens = 1024
+ generation_config.temperature = 0.0001  # near-greedy decoding while do_sample stays True
+ generation_config.top_p = 0.95
+ generation_config.do_sample = True
+ generation_config.repetition_penalty = 1.15
+ 
+ # HuggingFacePipeline expects an instantiated transformers pipeline, not the
+ # `pipeline` factory function itself, so build a text-generation pipeline first.
+ text_pipeline = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     generation_config=generation_config,
+ )
+ llm = HuggingFacePipeline(pipeline=text_pipeline)
+ embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
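+ # Optional sanity check: the multilingual MPNet encoder emits 768-dimensional vectors.
+ # assert len(embeddings.embed_query("hola")) == 768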
+ 
+ urls = [
+     "https://www.boe.es/diario_boe/txt.php?id=BOE-A-2024-9523"
+ ]
+ 
+ loader = UnstructuredURLLoader(urls=urls)
+ documents = loader.load()
+ 
+ # Split the page into 1000-character chunks with 200 characters of overlap so
+ # the retriever returns focused passages without hard cuts mid-sentence.
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+ texts_chunks = text_splitter.split_documents(documents)
+ 
+ db = Chroma.from_documents(texts_chunks, embeddings, persist_directory="db")
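+ # Optional sanity check before wiring the chain (the query string is only an example):
+ # docs = db.similarity_search("bono digital", k=2)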
+ 
+ template = """Act as an expert legal assistant. Use the following information to answer the question at the end.
+ You must always answer in Spanish. If you do not know the answer, reply with 'I am sorry, I don't have enough information'.
+ Chat History:
+ {chat_history}
+ Follow Up Input: {question}
+ Standalone question:
+ """
+ 
+ CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(template)
+ 
+ def querying(query, history):
+     # Gradio supplies `history`, but the chain rebuilds its own empty memory on
+     # every call, so each query is condensed against a blank chat history.
+     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=False)
+ 
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm=llm,
+         retriever=db.as_retriever(search_kwargs={"k": 2}),
+         memory=memory,
+         condense_question_prompt=CUSTOM_QUESTION_PROMPT,
+     )
+ 
+     result = qa_chain({"question": query})
+     return result["answer"].strip()
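+ # Example call, matching how gr.ChatInterface invokes the function:
+ # querying("¿Cuántos segmentos hay?", [])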
+ 
+ iface = gr.ChatInterface(
+     fn=querying,
+     chatbot=gr.Chatbot(height=600),
+     textbox=gr.Textbox(placeholder="¿Cuántos segmentos hay y en qué consisten?", container=False, scale=7),
+     title="LawyerBot",
+     theme="soft",
+     examples=[
+         "¿Cuántos segmentos hay?",
+         "¿Qué importe del bono digital corresponde a cada uno de los 5 segmentos?",
+         "¿Cuál es el importe de la ayuda para el segmento III en cuanto a dispositivo hardware?",
+         "Si tengo una microempresa de 2 empleados, ¿qué importe del bono digital me corresponde?",
+         "¿Qué nuevos segmentos de beneficiarios se han introducido?",
+     ],
+     cache_examples=True,
+     retry_btn="Repetir",
+     undo_btn="Deshacer",
+     clear_btn="Borrar",
+     submit_btn="Enviar",
+ )
+ 
+ iface.launch(share=True)
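+ # share=True serves the app locally and also prints a temporary public
+ # *.gradio.live URL in the logs.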