Manikandan-Alagu commited on
Commit
8ebebd8
1 Parent(s): f309879

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
4
+ from langchain.chains import create_retrieval_chain, create_history_aware_retriever
5
+ from langchain.chains.combine_documents import create_stuff_documents_chain
6
+ from langchain_chroma import Chroma
7
+ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
10
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores.utils import filter_complex_metadata
12
+
13
+ # Document Processor Class
14
+ class DocumentProcessor:
15
+ def __init__(self):
16
+ self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
17
+ self.vectorstore = None
18
+ self.retriever = None
19
+
20
+ def process_documents(self, directory_path):
21
+
22
+ all_splits = []
23
+ try:
24
+ loader = DirectoryLoader(directory_path, glob="*.pdf", loader_cls=PyPDFLoader)
25
+ data = loader.load()
26
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
27
+ all_splits += text_splitter.split_documents(data)
28
+ except Exception as e:
29
+ print(f"Error loading documents: {e}")
30
+ return
31
+
32
+ doc = filter_complex_metadata(all_splits)
33
+ self.vectorstore = Chroma.from_documents(documents=doc, embedding=self.embeddings)
34
+ self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
35
+
36
+ def get_retriever(self):
37
+ return self.retriever
38
+
39
+ # Model Handler Class
40
+ class ModelHandler:
41
+ def __init__(self):
42
+ script_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script
43
+ self.model_cache_dir = os.path.join(script_dir, "model_cache") # Cache in the script directory
44
+ self.llm = None
45
+
46
+ def load_model(self):
47
+ model_name = "HuggingFaceH4/zephyr-7b-beta"
48
+
49
+ if os.path.exists(self.model_cache_dir):
50
+ print("Loading model from cache...")
51
+ model = AutoModelForCausalLM.from_pretrained(self.model_cache_dir)
52
+ tokenizer = AutoTokenizer.from_pretrained(self.model_cache_dir)
53
+ else:
54
+ print("Downloading and caching model...")
55
+ model = AutoModelForCausalLM.from_pretrained(model_name)
56
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
57
+ os.makedirs(self.model_cache_dir, exist_ok=True)
58
+ model.save_pretrained(self.model_cache_dir) # Cache the model in the script directory
59
+ tokenizer.save_pretrained(self.model_cache_dir)
60
+
61
+ tokenizer.pad_token = tokenizer.eos_token
62
+ tokenizer.padding_side = "right"
63
+
64
+ text_generation_pipeline = pipeline(
65
+ model=model,
66
+ tokenizer=tokenizer,
67
+ task="text-generation",
68
+ temperature=0.2,
69
+ do_sample=True,
70
+ repetition_penalty=1.1,
71
+ return_full_text=False,
72
+ max_new_tokens=400,
73
+ )
74
+
75
+ self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
76
+
77
+ def get_llm(self):
78
+ return self.llm
79
+
80
+ # News Detector Class
81
+ class NewsDetector:
82
+ def __init__(self, retriever, llm):
83
+ self.retriever = retriever
84
+ self.llm = llm
85
+ self.chat_history = []
86
+
87
+ # System prompt for detecting fake news based on verified documents
88
+ system_prompt = (
89
+ "You are an assistant for detecting fake news. You have access to a set of documents that contain only verified and true news. "
90
+ "When a user asks a question or provides a statement, your task is to search these documents to verify the authenticity of the input.\n\n"
91
+ "If the input matches the true news, respond: 'The statement appears to be true based on verified information.'\n"
92
+ "If the input contradicts the true news, respond: 'The statement appears to be false based on verified information.'\n"
93
+ "If there is not enough information to verify the statement, respond: 'I'm unable to verify the statement with the available data.'"
94
+ )
95
+
96
+ self.qa_prompt = ChatPromptTemplate.from_messages([
97
+ ("system", system_prompt),
98
+ ("human", "{input}"),
99
+ ])
100
+
101
+ self.question_answer_chain = create_stuff_documents_chain(self.llm, self.qa_prompt)
102
+ self.rag_chain = create_retrieval_chain(self.retriever, self.question_answer_chain)
103
+
104
+
105
+
106
+ def respond(self, message):
107
+ response = self.rag_chain.invoke(
108
+ {"input": message})
109
+
110
+ return response["answer"]
111
+
112
+
113
+ # Create a Gradio Interface for the chatbot
114
+ def chatbot_response(user_input):
115
+ response = news_detector.respond(user_input)
116
+ return response
117
+
118
+ # Main Execution
119
+ if __name__ == "__main__":
120
+ # Initialize and process documents
121
+ processor = DocumentProcessor()
122
+ processor.process_documents("data/") # Path to the directory containing PDF files
123
+
124
+ # Initialize and load the model
125
+ model_handler = ModelHandler()
126
+ model_handler.load_model()
127
+
128
+ # Create the news detector with the retriever and the language model
129
+ news_detector = NewsDetector(retriever=processor.get_retriever(), llm=model_handler.get_llm())
130
+
131
+ # Gradio Interface
132
+ with gr.Blocks() as demo:
133
+ gr.Markdown("# News Verification")
134
+ with gr.Row():
135
+ with gr.Column():
136
+ user_input = gr.Textbox(label="Enter your statement:")
137
+ with gr.Column():
138
+ output_text = gr.Textbox(label="Response")
139
+ submit_button = gr.Button("Submit")
140
+
141
+ submit_button.click(fn=chatbot_response, inputs=user_input, outputs=output_text)
142
+
143
+ demo.launch()
144
+