rocioadlc committed on
Commit
eca7db4
1 Parent(s): dce00f5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import theme
5
+
6
+ theme = theme.Theme()
7
+
8
+ import os
9
+ import sys
10
+ sys.path.append('../..')
11
+
12
+ #langchain
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
14
+ from langchain.embeddings import HuggingFaceEmbeddings
15
+ from langchain.prompts import PromptTemplate
16
+ from langchain.chains import RetrievalQA
17
+ from langchain.prompts import ChatPromptTemplate
18
+ from langchain.schema import StrOutputParser
19
+ from langchain.schema.runnable import Runnable
20
+ from langchain.schema.runnable.config import RunnableConfig
21
+ from langchain.chains import (
22
+ LLMChain, ConversationalRetrievalChain)
23
+ from langchain.vectorstores import Chroma
24
+ from langchain.memory import ConversationBufferMemory
25
+ from langchain.chains import LLMChain
26
+ from langchain.prompts.prompt import PromptTemplate
27
+ from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
28
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, MessagesPlaceholder
29
+ from langchain.document_loaders import PyPDFDirectoryLoader
30
+ from pydantic import BaseModel, Field
31
+ from langchain.output_parsers import PydanticOutputParser
32
+ from langchain_community.llms import HuggingFaceHub
33
+ from langchain_community.document_loaders import WebBaseLoader
34
+
35
+ from pydantic import BaseModel
36
+ import shutil
37
+
38
+
39
+
40
# HTML title shared by the Gradio interfaces: "Green Greta" wrapped in a
# <span> carrying an inline colour style.
custom_title = "<span style='color: rgb(243, 239, 224);'>Green Greta</span>"


# Cell 1: Image Classification Model
# transformers image-classification pipeline loaded from the
# "guillen/vit-basura-test1" checkpoint.
image_pipeline = pipeline(
    task="image-classification",
    model="guillen/vit-basura-test1",
)
45
+
46
def predict_image(input_img):
    """Classify an image and map every predicted label to its score.

    Args:
        input_img: Image passed straight to ``image_pipeline``.

    Returns:
        dict: ``{label: score}`` with one entry per prediction returned by
        the pipeline.
    """
    scores_by_label = {}
    for prediction in image_pipeline(input_img):
        scores_by_label[prediction["label"]] = prediction["score"]
    return scores_by_label
49
+
50
# Gradio interface around predict_image: takes an uploaded or webcam image
# (handed to the function as a PIL image) and renders the scores in a Label.
image_gradio_app = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(
        label="Image",
        sources=['upload', 'webcam'],
        type="pil",
    ),
    outputs=[gr.Label(label="Result")],
    title=custom_title,
    theme=theme,
)
57
+
58
# Load the web pages that form the chatbot's knowledge base.
loader = WebBaseLoader(
    [
        "https://www.epa.gov/recycle/frequent-questions-recycling",
        "https://www.whitehorsedc.gov.uk/vale-of-white-horse-district-council/recycling-rubbish-and-waste/lets-get-real-about-recycling/",
        "https://www.teimas.com/blog/13-preguntas-y-respuestas-sobre-la-ley-de-residuos-07-2022",
        "https://www.molok.com/es/blog/gestion-de-residuos-solidos-urbanos-rsu-10-dudas-comunes",
    ]
)
data = loader.load()

# Split the loaded documents into overlapping chunks, measured with len().
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=150,
    length_function=len,
)
docs = text_splitter.split_documents(data)

# Embedding model used to index the chunks.
embeddings = HuggingFaceEmbeddings(model_name='thenlper/gte-small')

# Directory where the Chroma vector database is persisted.
persist_directory = 'docs/chroma/'

# Remove old database files if any, so the store is rebuilt from scratch on
# every start-up (errors such as a missing directory are ignored).
shutil.rmtree(persist_directory, ignore_errors=True)
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory=persist_directory,
)

# Retriever over the vector store, configured with search_type="mmr" and k=2.
retriever = vectordb.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 2},
)
81
+
82
# Structured output requested from the LLM: the question that was asked plus
# the answer extracted for it. This class is passed to PydanticOutputParser.
# NOTE(review): documented with comments instead of a class docstring because
# a docstring would presumably appear in the pydantic schema and therefore in
# the generated format instructions -- confirm before adding one.
class FinalAnswer(BaseModel):
    # The user's question, echoed back by the model.
    question: str = Field(description="the original question")
    # The answer text produced by the model.
    answer: str = Field(description="the extracted answer")
85
+
86
# Output parser bound to FinalAnswer; its format instructions are injected
# into the prompt below.
parser = PydanticOutputParser(pydantic_object=FinalAnswer)

# System prompt. {context} and {question} are filled in by the chain;
# {format_instructions} is pre-filled from the parser via partial_variables.
template = """
Your name is Greta and you are a recycling chatbot with the objective to anwer questions from user in English or Spanish /
Use the following pieces of context to answer the question /
If the question is English answer in English /
If the question is Spanish answer in Spanish /
Do not mention the word context when you answer a question /
Answer the question fully and provide as much relevant detail as possible. Do not cut your response short /
Context: {context}
User: {question}
{format_instructions}
"""

# Create the chat prompt templates: the system message above followed by a
# human message that carries the raw question.
sys_prompt = SystemMessagePromptTemplate.from_template(template)
qa_prompt = ChatPromptTemplate(
    messages=[
        sys_prompt,
        HumanMessagePromptTemplate.from_template("{question}"),
    ],
    partial_variables={
        "format_instructions": parser.get_format_instructions(),
    },
)
109
# Text-generation LLM accessed through HuggingFaceHub, using the
# "mistralai/Mixtral-8x7B-Instruct-v0.1" repository and the generation
# settings below.
llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 2000,
        "top_k": 30,
        "temperature": 0.1,
        "repetition_penalty": 1.03,
    },
)
119
+
120
# Conversational retrieval chain: retrieves documents with `retriever`,
# answers with `llm` using `qa_prompt`, and keeps the conversation in a
# ConversationBufferMemory stored under "chat_history".
# The chain's result is exposed under the 'output' key, which is also the key
# the memory records. The stored history is passed through unchanged
# (identity get_chat_history) and question rephrasing is disabled.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=ConversationBufferMemory(
        llm=llm,
        memory_key="chat_history",
        input_key='question',
        output_key='output',
    ),
    combine_docs_chain_kwargs={'prompt': qa_prompt},
    get_chat_history=lambda h: h,
    rephrase_question=False,
    output_key='output',
    verbose=True,
)
130
+
131
def _extract_answer(output_string):
    """Return the value of the last ``"answer"`` field found in LLM output.

    The value is decoded as a JSON string, so escaped quotes and escape
    sequences such as ``\\n`` inside the answer are handled correctly instead
    of truncating the answer at the first inner quote.

    Args:
        output_string: Raw text produced by the chain.

    Returns:
        str: The decoded answer. If no ``"answer":`` marker (or no opening
        quote after it) is present, the stripped raw output is returned. If
        the string value is malformed or unterminated, everything after the
        opening quote is returned, stripped.
    """
    import json  # function-scope import: the file does not import json

    marker = '"answer":'
    # Use the last occurrence, as the original implementation did.
    marker_index = output_string.rfind(marker)
    if marker_index == -1:
        return output_string.strip()

    remainder = output_string[marker_index + len(marker):]
    quote_index = remainder.find('"')
    if quote_index == -1:
        return output_string.strip()

    # strict=False tolerates literal control characters (e.g. raw newlines)
    # inside the string, which plain-text model output may contain.
    decoder = json.JSONDecoder(strict=False)
    try:
        answer_value, _ = decoder.raw_decode(remainder, quote_index)
    except json.JSONDecodeError:
        # Unterminated or malformed string: keep what is there rather than
        # dropping characters.
        return remainder[quote_index + 1:].strip()
    return answer_value


def chat_interface(question, history):
    """Answer a chat message with the retrieval chain.

    Args:
        question: The user's message.
        history: Chat history supplied by Gradio. It is not used here; the
            chain is invoked with the question only.

    Returns:
        str: The answer text extracted from the chain's ``'output'`` value.
    """
    result = qa_chain.invoke({'question': question})
    return _extract_answer(result['output'])
148
+
149
+
150
# Chat tab: Gradio ChatInterface driven by chat_interface.
chatbot_gradio_app = gr.ChatInterface(fn=chat_interface, title=custom_title)

# Combine both interfaces into a single app, one tab per interface.
app = gr.TabbedInterface(
    [image_gradio_app, chatbot_gradio_app],
    tab_names=[
        "Green Greta Image Classification",
        "Green Greta Chat",
    ],
    theme=theme,
)

# Enable request queuing, then start the server.
app.queue()
app.launch()