paloma99 committed on
Commit
1fecc0a
1 Parent(s): 922dead

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -29
app.py CHANGED
@@ -54,14 +54,12 @@ image_gradio_app = gr.Interface(
54
  theme=theme
55
  )
56
 
57
- # Cell 2: Chatbot Model
58
-
59
- loader = PyPDFDirectoryLoader('pdfs')
60
  data=loader.load()
61
  # split documents
62
  text_splitter = RecursiveCharacterTextSplitter(
63
- chunk_size=500,
64
- chunk_overlap=70,
65
  length_function=len
66
  )
67
  docs = text_splitter.split_documents(data)
@@ -78,56 +76,74 @@ vectordb = Chroma.from_documents(
78
  persist_directory=persist_directory
79
  )
80
  # define retriever
81
- retriever = vectordb.as_retriever(search_type="mmr")
 
 
 
 
 
 
 
 
82
  template = """
83
- Your name is Greta and you are a recycling chatbot with the objective to anwer questions from user in English or Spanish /
84
- Use the following pieces of context to answer the question if the question is related with recycling /
85
- No more than two chunks of context /
86
- Answer in the same language of the question /
87
- Always say "thanks for asking!" at the end of the answer /
88
- If the context is not relevant, please answer the question by using your own knowledge about the topic.
89
-
90
- context: {context}
91
- question: {question}
92
  """
93
 
94
  # Create the chat prompt templates
95
- system_prompt = SystemMessagePromptTemplate.from_template(template)
96
  qa_prompt = ChatPromptTemplate(
97
- messages=[
98
- system_prompt,
99
- MessagesPlaceholder(variable_name="chat_history"),
100
- HumanMessagePromptTemplate.from_template("{question}")
101
- ]
102
  )
103
  llm = HuggingFaceHub(
104
  repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
105
  task="text-generation",
106
  model_kwargs={
107
- "max_new_tokens": 1024,
108
  "top_k": 30,
109
  "temperature": 0.1,
110
- "repetition_penalty": 1.03,
111
  },
112
  )
113
 
114
- memory = ConversationBufferMemory(llm=llm, memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
115
-
116
  qa_chain = ConversationalRetrievalChain.from_llm(
117
  llm = llm,
118
- memory = memory,
119
  retriever = retriever,
120
  verbose = True,
121
  combine_docs_chain_kwargs={'prompt': qa_prompt},
122
  get_chat_history = lambda h : h,
123
  rephrase_question = False,
124
- output_key = 'answer'
125
  )
126
 
127
  def chat_interface(question,history):
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- result = qa_chain.invoke({"question": question})
130
- return result['answer'] # If the result is a string, return it directly
 
 
131
 
132
 
133
  chatbot_gradio_app = gr.ChatInterface(
 
54
  theme=theme
55
  )
56
 
57
+ loader = WebBaseLoader(["https://www.epa.gov/recycle/frequent-questions-recycling", "https://www.whitehorsedc.gov.uk/vale-of-white-horse-district-council/recycling-rubbish-and-waste/lets-get-real-about-recycling/", "https://www.teimas.com/blog/13-preguntas-y-respuestas-sobre-la-ley-de-residuos-07-2022", "https://www.molok.com/es/blog/gestion-de-residuos-solidos-urbanos-rsu-10-dudas-comunes"])
 
 
58
  data=loader.load()
59
  # split documents
60
  text_splitter = RecursiveCharacterTextSplitter(
61
+ chunk_size=1024,
62
+ chunk_overlap=150,
63
  length_function=len
64
  )
65
  docs = text_splitter.split_documents(data)
 
76
  persist_directory=persist_directory
77
  )
78
  # define retriever
79
+ retriever = vectordb.as_retriever(search_kwargs={"k": 2}, search_type="mmr")
80
+
81
+ class FinalAnswer(BaseModel):
82
+ question: str = Field(description="the original question")
83
+ answer: str = Field(description="the extracted answer")
84
+
85
+ # Assuming you have a parser for the FinalAnswer class
86
+ parser = PydanticOutputParser(pydantic_object=FinalAnswer)
87
+
88
  template = """
89
+ Your name is AngryGreta and you are a recycling chatbot with the objective to anwer questions from user in English or Spanish /
90
+ Use the following pieces of context to answer the question /
91
+ If the question is English answer in English /
92
+ If the question is Spanish answer in Spanish /
93
+ Do not mention the word context when you answer a question /
94
+ Answer the question fully and provide as much relevant detail as possible. Do not cut your response short /
95
+ Context: {context}
96
+ User: {question}
97
+ {format_instructions}
98
  """
99
 
100
  # Create the chat prompt templates
101
+ sys_prompt = SystemMessagePromptTemplate.from_template(template)
102
  qa_prompt = ChatPromptTemplate(
103
+ messages=[
104
+ sys_prompt,
105
+ HumanMessagePromptTemplate.from_template("{question}")],
106
+ partial_variables={"format_instructions": parser.get_format_instructions()}
 
107
  )
108
  llm = HuggingFaceHub(
109
  repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
110
  task="text-generation",
111
  model_kwargs={
112
+ "max_new_tokens": 2000,
113
  "top_k": 30,
114
  "temperature": 0.1,
115
+ "repetition_penalty": 1.03
116
  },
117
  )
118
 
 
 
119
  qa_chain = ConversationalRetrievalChain.from_llm(
120
  llm = llm,
121
+ memory = ConversationBufferMemory(llm=llm, memory_key="chat_history", input_key='question', output_key='output'),
122
  retriever = retriever,
123
  verbose = True,
124
  combine_docs_chain_kwargs={'prompt': qa_prompt},
125
  get_chat_history = lambda h : h,
126
  rephrase_question = False,
127
+ output_key = 'output',
128
  )
129
 
130
  def chat_interface(question,history):
131
+ result = qa_chain.invoke({'question': question})
132
+ output_string = result['output']
133
+
134
+ # Find the index of the last occurrence of "answer": in the string
135
+ answer_index = output_string.rfind('"answer":')
136
+
137
+ # Extract the substring starting from the "answer": index
138
+ answer_part = output_string[answer_index + len('"answer":'):].strip()
139
+
140
+ # Find the next occurrence of a double quote to get the start of the answer value
141
+ quote_index = answer_part.find('"')
142
 
143
+ # Extract the answer value between double quotes
144
+ answer_value = answer_part[quote_index + 1:answer_part.find('"', quote_index + 1)]
145
+
146
+ return answer_value
147
 
148
 
149
  chatbot_gradio_app = gr.ChatInterface(