import os import gradio as gr from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM from langchain.chains import create_retrieval_chain, create_history_aware_retriever from langchain.chains.combine_documents import create_stuff_documents_chain from langchain_chroma import Chroma from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader from langchain_core.prompts import ChatPromptTemplate from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_community.vectorstores.utils import filter_complex_metadata # Document Processor Class class DocumentProcessor: def __init__(self): self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") self.vectorstore = None self.retriever = None def process_documents(self, directory_path): all_splits = [] try: loader = DirectoryLoader(directory_path, glob="*.pdf", loader_cls=PyPDFLoader) data = loader.load() text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) all_splits += text_splitter.split_documents(data) except Exception as e: print(f"Error loading documents: {e}") return doc = filter_complex_metadata(all_splits) self.vectorstore = Chroma.from_documents(documents=doc, embedding=self.embeddings) self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3}) def get_retriever(self): return self.retriever # Model Handler Class class ModelHandler: def __init__(self): script_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script self.model_cache_dir = os.path.join(script_dir, "model_cache") # Cache in the script directory self.llm = None def load_model(self): model_name = "HuggingFaceH4/zephyr-7b-beta" if os.path.exists(self.model_cache_dir): print("Loading model from cache...") model = AutoModelForCausalLM.from_pretrained(self.model_cache_dir) tokenizer = AutoTokenizer.from_pretrained(self.model_cache_dir) else: print("Downloading and caching model...") model = AutoModelForCausalLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) os.makedirs(self.model_cache_dir, exist_ok=True) model.save_pretrained(self.model_cache_dir) # Cache the model in the script directory tokenizer.save_pretrained(self.model_cache_dir) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" text_generation_pipeline = pipeline( model=model, tokenizer=tokenizer, task="text-generation", temperature=0.2, do_sample=True, repetition_penalty=1.1, return_full_text=False, max_new_tokens=400, ) self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline) def get_llm(self): return self.llm # News Detector Class class NewsDetector: def __init__(self, retriever, llm): self.retriever = retriever self.llm = llm self.chat_history = [] # System prompt for detecting fake news based on verified documents system_prompt = ( "You are an assistant for detecting fake news. You have access to a set of documents that contain only verified and true news. " "When a user asks a question or provides a statement, your task is to search these documents to verify the authenticity of the input.\n\n" "If the input matches the true news, respond: 'The statement appears to be true based on verified information.'\n" "If the input contradicts the true news, respond: 'The statement appears to be false based on verified information.'\n" "If there is not enough information to verify the statement, respond: 'I'm unable to verify the statement with the available data.'" ) self.qa_prompt = ChatPromptTemplate.from_messages([ ("system", system_prompt), ("human", "{input}"), ]) self.question_answer_chain = create_stuff_documents_chain(self.llm, self.qa_prompt) self.rag_chain = create_retrieval_chain(self.retriever, self.question_answer_chain) def respond(self, message): response = self.rag_chain.invoke( {"input": message}) return response["answer"] # Create a Gradio Interface for the chatbot def chatbot_response(user_input): response = news_detector.respond(user_input) return response # Main Execution if __name__ == "__main__": # Initialize and process documents processor = DocumentProcessor() processor.process_documents("data/") # Path to the directory containing PDF files # Initialize and load the model model_handler = ModelHandler() model_handler.load_model() # Create the news detector with the retriever and the language model news_detector = NewsDetector(retriever=processor.get_retriever(), llm=model_handler.get_llm()) # Gradio Interface with gr.Blocks() as demo: gr.Markdown("# News Verification") with gr.Row(): with gr.Column(): user_input = gr.Textbox(label="Enter your statement:") with gr.Column(): output_text = gr.Textbox(label="Response") submit_button = gr.Button("Submit") submit_button.click(fn=chatbot_response, inputs=user_input, outputs=output_text) demo.launch()