from typing import List, Optional, Tuple

import torch
import streamlit as st
import pandas as pd
import random
import time
import logging
from json import JSONDecodeError

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from haystack import Document
from haystack.document_stores import FAISSDocumentStore
from haystack.modeling.utils import initialize_device_settings
from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import Pipeline
from haystack.nodes.base import BaseComponent
from haystack.schema import Document

from Project.Fact_Checking_Blue_Amazon.config import (
    RETRIEVER_TOP_K,
    RETRIEVER_MODEL,
    NLI_MODEL,
)


class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between every document content and the statement.
    It enriches the documents' metadata with entailment information.
    It also returns aggregate entailment information.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 100,
        entailment_contradiction_consideration: float = 0.6,
        entailment_contradiction_threshold: float = 0.8,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub.
            Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as model)
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time.
        :param entailment_contradiction_consideration: A document is kept (and counted in
            the aggregate) only if its contradiction or entailment score is greater than
            this value.
        :param entailment_contradiction_threshold: Stop looking at further documents as soon
            as the mean contradiction or mean entailment score of the kept documents is
            greater than this value.
        :raises ValueError: If the model's label set does not contain "entailment".
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.entailment_contradiction_consideration = entailment_contradiction_consideration
        self.model.to(str(self.devices[0]))

        # Label names ordered by class id, so they line up with the logits' last axis.
        id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError("The model config must contain entailment value in the id2label dict.")

    def run(self, query: str, documents: List[Document]) -> Tuple[dict, str]:
        """
        Score every document against the statement and aggregate the evidence.

        Each document gets its per-label probabilities stored in
        ``doc.meta["entailment_info"]``. Only documents whose contradiction or
        entailment score exceeds ``entailment_contradiction_consideration`` are kept.

        :param query: The statement to verify (used as the NLI hypothesis).
        :param documents: Retrieved documents (their content is used as the NLI premise).
        :return: ``(output, "output_1")`` where ``output`` has the keys ``"documents"``
            (the kept documents) and ``"aggregate_entailment_info"`` (mean contradiction /
            neutral / entailment over the kept documents, rounded to 2 decimals; all
            zeros when no document was kept).
        """
        agg_con, agg_neu, agg_ent = 0.0, 0.0, 0.0
        considered_documents = []

        premise_batch = [doc.content for doc in documents]
        hypothesis_batch = [query] * len(documents)
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=premise_batch, hypothesis_batch=hypothesis_batch
        )

        consideration = self.entailment_contradiction_consideration
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info

            con, neu, ent = (
                entailment_info["contradiction"],
                entailment_info["neutral"],
                entailment_info["entailment"],
            )

            if con <= consideration and ent <= consideration:
                # Neither clearly supporting nor clearly contradicting: skip it.
                continue

            considered_documents.append(doc)
            agg_con += con
            agg_neu += neu
            agg_ent += ent

            # if in the first documents there is a strong evidence of entailment/contradiction,
            # there is no need to consider less relevant documents
            kept = len(considered_documents)
            if max(agg_con, agg_ent) / kept > self.entailment_contradiction_threshold:
                break

        kept = len(considered_documents)
        if kept:
            aggregate_entailment_info = {
                "contradiction": round(float(agg_con) / kept, 2),
                "neutral": round(float(agg_neu) / kept, 2),
                "entailment": round(float(agg_ent) / kept, 2),
            }
        else:
            # No document passed the consideration threshold (or none was retrieved):
            # report zeros instead of dividing by zero.
            aggregate_entailment_info = {
                "contradiction": 0.0,
                "neutral": 0.0,
                "entailment": 0.0,
            }

        entailment_checker_result = {
            "documents": considered_documents,
            "aggregate_entailment_info": aggregate_entailment_info,
        }

        # Haystack pipeline nodes return the output dict together with the outgoing edge name.
        return entailment_checker_result, "output_1"

    def get_entailment_dict(self, probs):
        """Map one row of class probabilities to ``{label: probability}`` using ``self.labels``."""
        return dict(zip(self.labels, probs))

    def get_entailment_batch(self, premise_batch: List[str], hypothesis_batch: List[str]):
        """
        Compute label probabilities for each (premise, hypothesis) pair.

        Pairs are joined with the tokenizer's separator token and run through the model
        in chunks of ``self.batch_size``.

        :param premise_batch: Premise texts.
        :param hypothesis_batch: Hypothesis texts, paired positionally with the premises.
        :return: One ``{label: probability}`` dict per pair, in input order
            (an empty list for empty input).
        """
        formatted_texts = [
            f"{premise}{self.tokenizer.sep_token}{hypothesis}"
            for premise, hypothesis in zip(premise_batch, hypothesis_batch)
        ]

        step = max(1, int(self.batch_size))
        entailment_infos = []
        with torch.inference_mode():
            for start in range(0, len(formatted_texts), step):
                chunk = formatted_texts[start:start + step]
                inputs = self.tokenizer(
                    chunk, return_tensors="pt", padding=True, truncation=True
                ).to(self.devices[0])
                logits = self.model(**inputs).logits
                probs_batch = torch.nn.functional.softmax(logits, dim=-1).detach().cpu().numpy()
                entailment_infos.extend(self.get_entailment_dict(probs) for probs in probs_batch)
        return entailment_infos


# cached to make index and models load only at start
@st.cache_resource
def start_haystack():
    """
    load document store, retriever, entailment checker and create pipeline
    """
    document_store = FAISSDocumentStore(
        faiss_index_path="../data//my_faiss_index.faiss",
        faiss_config_path="../data/my_faiss_index.json",
    )
    print(f"Index size: {document_store.get_document_count()}")

    retriever = EmbeddingRetriever(
        document_store=document_store, embedding_model=RETRIEVER_MODEL
    )
    entailment_checker = EntailmentChecker(
        model_name_or_path=NLI_MODEL,
        use_gpu=False,
    )

    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])
    return pipe


pipe = start_haystack()


@st.cache_resource
def _run_pipeline_cached(_pipe, statement: str, retriever_top_k: int):
    """Run the pipeline; cached on (statement, retriever_top_k) only."""
    # The leading underscore tells Streamlit not to hash the pipeline object.
    params = {"retriever": {"top_k": retriever_top_k}}
    return _pipe.run(statement, params=params)


def check_statement(pipe, statement: str, retriever_top_k: int = 5):
    """
    Run query and verify statement.

    :param pipe: The pipeline to run (as built by ``start_haystack``).
    :param statement: The statement to verify.
    :param retriever_top_k: Number of documents the retriever should return.
    :return: The pipeline output for the statement.
    """
    return _run_pipeline_cached(pipe, statement, retriever_top_k)


def set_state_if_absent(key, value):
    """Initialise ``st.session_state[key]`` to ``value`` unless it is already set."""
    if key not in st.session_state:
        st.session_state[key] = value


# Small callback to reset the interface in case the text of the question changes
def reset_results(*args):
    """Clear the stored answer, results and raw JSON from the session state."""
    st.session_state.answer = None
    st.session_state.results = None
    st.session_state.raw_json = None


def create_df_for_relevant_snippets(docs):
    """
    Create a dataframe that contains all relevant snippets.

    :param docs: Documents carrying ``meta["entailment_info"]`` with the keys
        "contradiction", "neutral" and "entailment".
    :return: A styled dataframe with the columns Content / con / neu / ent
        (scores formatted with 2 decimals, content wrapped at 75 characters).
    """
    rows = []
    for doc in docs:
        info = doc.meta["entailment_info"]
        rows.append(
            {
                "Content": doc.content,
                "con": f"{info['contradiction']:.2f}",
                "neu": f"{info['neutral']:.2f}",
                "ent": f"{info['entailment']:.2f}",
            }
        )

    # Explicit columns keep the frame well-formed even when there are no rows.
    df = pd.DataFrame(rows, columns=["Content", "con", "neu", "ent"])
    df["Content"] = df["Content"].str.wrap(75)
    return df.style.apply(highlight_cols)


def highlight_cols(s):
    """Return per-cell CSS for column ``s``: a background colour for con / neu / ent."""
    coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
    colour = coldict.get(s.name)
    if colour is None:
        return [""] * len(s)
    return ["background-color: {}".format(colour)] * len(s)


def main():
    """Render the Streamlit page: read a statement, run the pipeline, show the evidence."""
    # Persistent state
    set_state_if_absent("statement", "")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)

    st.write("# Verificação de Sentenças sobre Amazônia Azul")
    st.write()
    st.markdown(
        """
##### Insira uma sentença sobre a amazônia azul.
"""
    )

    # Search bar
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )
    st.markdown("", unsafe_allow_html=True)
    run_pressed = st.button("Run")

    run_query = (
        run_pressed or statement != st.session_state.statement
    )

    # Get results for query
    if run_query and statement:
        time_start = time.time()
        reset_results()
        st.session_state.statement = statement
        with st.spinner("   Procurando a Similaridade no banco de sentenças..."):
            try:
                # The module-level pipeline must be passed explicitly as first argument.
                st.session_state.results = check_statement(pipe, statement, RETRIEVER_TOP_K)
                print(f"S: {statement}")
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {time_end - time_start}")
            except JSONDecodeError:
                st.error(
                    "👓    Erro na document store."
                )
                return
            except Exception as e:
                logging.exception(e)
                st.error("🐞    Erro Genérico.")
                return

    # Display results
    if st.session_state.results:
        docs = st.session_state.results["documents"]
        agg_entailment_info = st.session_state.results["aggregate_entailment_info"]

        st.markdown("###### Aggregate entailment information:")
        st.write(agg_entailment_info)

        st.markdown("###### Most Relevant snippets:")
        if docs:
            df = create_df_for_relevant_snippets(docs)
            st.dataframe(df)
        else:
            # Nothing passed the consideration threshold; avoid rendering an empty table.
            st.info("Nenhum trecho relevante encontrado.")


main()