Manikandan-Alagu
committed on
Commit 8ebebd8 • 1 Parent(s): f309879

Create app.py
app.py
ADDED
@@ -0,0 +1,144 @@
import os
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.utils import filter_complex_metadata

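# Dependency note (inferred from the imports above, not part of the original
# file): gradio, transformers, torch, pypdf, chromadb, sentence-transformers,
# langchain, langchain-community, langchain-chroma, langchain-core,
# langchain-huggingface, langchain-text-splitters.
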
# Document Processor Class
class DocumentProcessor:
    def __init__(self):
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.vectorstore = None
        self.retriever = None

    def process_documents(self, directory_path):
        all_splits = []
        try:
            loader = DirectoryLoader(directory_path, glob="*.pdf", loader_cls=PyPDFLoader)
            data = loader.load()
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            all_splits += text_splitter.split_documents(data)
        except Exception as e:
            print(f"Error loading documents: {e}")
            return

        # Chroma only accepts str/int/float/bool metadata values, so strip
        # anything more complex before indexing.
        docs = filter_complex_metadata(all_splits)
        self.vectorstore = Chroma.from_documents(documents=docs, embedding=self.embeddings)
        self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})

    def get_retriever(self):
        return self.retriever

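# NOTE (suggestion, not in the original file): the Chroma index above lives in
# memory and is rebuilt on every start. langchain_chroma can persist it to disk
# by passing a persist_directory, e.g. inside process_documents:
#
#     self.vectorstore = Chroma.from_documents(
#         documents=docs, embedding=self.embeddings, persist_directory="chroma_db")
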
# Model Handler Class
class ModelHandler:
    def __init__(self):
        script_dir = os.path.dirname(os.path.abspath(__file__))  # Get the directory of the current script
        self.model_cache_dir = os.path.join(script_dir, "model_cache")  # Cache in the script directory
        self.llm = None

    def load_model(self):
        model_name = "HuggingFaceH4/zephyr-7b-beta"

        if os.path.exists(self.model_cache_dir):
            print("Loading model from cache...")
            model = AutoModelForCausalLM.from_pretrained(self.model_cache_dir)
            tokenizer = AutoTokenizer.from_pretrained(self.model_cache_dir)
        else:
            print("Downloading and caching model...")
            model = AutoModelForCausalLM.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            os.makedirs(self.model_cache_dir, exist_ok=True)
            model.save_pretrained(self.model_cache_dir)  # Cache the model in the script directory
            tokenizer.save_pretrained(self.model_cache_dir)

        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        text_generation_pipeline = pipeline(
            model=model,
            tokenizer=tokenizer,
            task="text-generation",
            temperature=0.2,
            do_sample=True,
            repetition_penalty=1.1,
            return_full_text=False,
            max_new_tokens=400,
        )

        self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

    def get_llm(self):
        return self.llm

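# NOTE (suggestion, not in the original file): zephyr-7b-beta has ~7B
# parameters; from_pretrained with the defaults above loads it in float32 on
# CPU, which needs on the order of 28 GB of RAM and generates slowly. On GPU
# hardware one would typically load it in half precision instead (requires the
# accelerate package):
#
#     import torch
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name, torch_dtype=torch.float16, device_map="auto")
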
# News Detector Class
class NewsDetector:
    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm
        self.chat_history = []  # currently unused; reserved for multi-turn support

        # System prompt for detecting fake news based on verified documents.
        # create_stuff_documents_chain requires a "{context}" placeholder,
        # which is where the retrieved document chunks are stuffed.
        system_prompt = (
            "You are an assistant for detecting fake news. You have access to a set of documents that contain only verified and true news. "
            "When a user asks a question or provides a statement, your task is to search these documents to verify the authenticity of the input.\n\n"
            "If the input matches the true news, respond: 'The statement appears to be true based on verified information.'\n"
            "If the input contradicts the true news, respond: 'The statement appears to be false based on verified information.'\n"
            "If there is not enough information to verify the statement, respond: 'I'm unable to verify the statement with the available data.'\n\n"
            "Verified news context:\n{context}"
        )

        self.qa_prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}"),
        ])

        self.question_answer_chain = create_stuff_documents_chain(self.llm, self.qa_prompt)
        self.rag_chain = create_retrieval_chain(self.retriever, self.question_answer_chain)

    def respond(self, message):
        response = self.rag_chain.invoke({"input": message})
        return response["answer"]

+
# Create a Gradio Interface for the chatbot
|
114 |
+
def chatbot_response(user_input):
|
115 |
+
response = news_detector.respond(user_input)
|
116 |
+
return response
|
117 |
+
|
# Main Execution
if __name__ == "__main__":
    # Initialize and process documents
    processor = DocumentProcessor()
    processor.process_documents("data/")  # Path to the directory containing PDF files

    # Initialize and load the model
    model_handler = ModelHandler()
    model_handler.load_model()

    # Create the news detector with the retriever and the language model
    news_detector = NewsDetector(retriever=processor.get_retriever(), llm=model_handler.get_llm())

    # Gradio Interface
    with gr.Blocks() as demo:
        gr.Markdown("# News Verification")
        with gr.Row():
            with gr.Column():
                user_input = gr.Textbox(label="Enter your statement:")
            with gr.Column():
                output_text = gr.Textbox(label="Response")
        submit_button = gr.Button("Submit")

        submit_button.click(fn=chatbot_response, inputs=user_input, outputs=output_text)

    demo.launch()
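# NOTE (suggestion, not in the original file): if "data/" is missing or holds
# no PDFs, process_documents() returns early, get_retriever() gives None, and
# create_retrieval_chain fails with an unhelpful error. A small guard before
# building NewsDetector keeps the failure readable:
#
#     if processor.get_retriever() is None:
#         raise SystemExit("No documents indexed; add verified-news PDFs to data/.")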