Spaces:
Running
Running
""" | |
Module for detecting fallacies in text. | |
""" | |
import os | |
import re | |
import time | |
import json | |
import csv | |
from ast import literal_eval | |
from collections import namedtuple | |
import requests | |
from langchain_community.llms import HuggingFaceEndpoint | |
from langchain_community.chat_models.huggingface import ChatHuggingFace | |
from langchain.agents import AgentExecutor, load_tools, create_react_agent | |
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser | |
from langchain.tools import Tool | |
from langchain.tools import DuckDuckGoSearchRun | |
from .templates import ( | |
REACT, | |
INCONTEXT, | |
SUMMARIZATION, | |
CONCLUDING, | |
CONCLUDING_INCONTEXT, | |
) | |
from .definitions import DEFINITIONS | |
from .examples import FALLACY_CLAIMS, DEBUNKINGS | |
os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.environ.get("HF_API_KEY") | |
class HamburgerStyle:
    """Build hamburger-style rebuttals (FACT / MYTH / FALLACY / FACT) for a claim.

    Layers are produced with LangChain chains/agents over a Mixtral endpoint
    and with Hugging Face Inference API models (see ``endpoint_query``).
    ``rebuttal_generator`` runs the layers in order and joins them.
    """

    def __init__(self):
        # hamburger-style structure: slot 0 holds the input claim, slots 1-4
        # hold the generated layers in the order they are emitted.
        self.heading = namedtuple("Heading", ["name", "content"])
        self.hamburger = [
            self.heading(name="Myth", content=None),
            self.heading(name="##FACT", content=None),
            self.heading(name="##MYTH", content=None),
            self.heading(name="##FALLACY", content=None),
            self.heading(name="##FACT", content=None),
        ]
        self.llm = HuggingFaceEndpoint(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            temperature=1,
            top_k=1,
            model_kwargs={
                "use_cache": False,
            },
        )
        self.chat_model = ChatHuggingFace(llm=self.llm)
        # Hugging Face model ids used with endpoint_query().
        self.flicc_model = "fzanartu/flicc"
        self.card_model = "crarojasca/BinaryAugmentedCARDS"
        self.semantic_textual_similarity = "sentence-transformers/all-MiniLM-L6-v2"
        self.taxonomy_cards = "crarojasca/TaxonomyAugmentedCARDS"
        # Data files are resolved against the current working directory.
        # NOTE(review): if the CSV lives next to this module instead, use
        # os.path.dirname(os.path.abspath(__file__)).
        self.dirname = os.getcwd()
        self.filename = os.path.join(self.dirname, "structured/climate_fever_cards.csv")

    def generate_st_layer(self, misinformation):
        """Generate the first FACT layer with a ReAct agent that has a web-search tool.

        Args:
            misinformation: the claim to rebut.

        Returns:
            The agent result's ``"output"`` value (``None`` if that key is absent).
        """
        ## FACT: ReAct
        # Stop at "\nObservation" so tool results come from the executor
        # rather than being generated by the model.
        chat_model_with_stop = self.chat_model.bind(stop=["\nObservation"])
        search = DuckDuckGoSearchRun()
        # NOTE(review): named/described as Google search but backed by
        # DuckDuckGo; left unchanged to avoid altering the agent's prompt.
        tools = [
            Tool(
                name="google_search",
                description="Search Google for recent results.",
                func=search.run,
            )
        ]
        agent = create_react_agent(chat_model_with_stop, tools, REACT)
        agent_executor = AgentExecutor(
            agent=agent, tools=tools, verbose=False, handle_parsing_errors=True
        )
        return agent_executor.invoke({"input": misinformation}).get("output")

    def generate_nd_layer(self, misinformation):
        """Generate the MYTH layer by running the SUMMARIZATION prompt through the LLM."""
        ## MYTH: Summ
        chain = SUMMARIZATION | self.llm
        return chain.invoke({"text": misinformation})

    def generate_rd_layer(self, misinformation):
        """Generate the FALLACY layer from an in-context example.

        Classifies the claim with the FLICC model. It then picks the example
        claim (same label) most similar to the input and extracts the
        "## FALLACY:" section of that example's debunking. Finally it prompts
        the chat model with that example and the FACT layer already stored in
        ``self.hamburger[1]``, so this must run after the first FACT layer.

        Raises:
            ValueError: no example claims exist for the detected label, the
                chosen example has no debunking, or the debunking has no
                "## FALLACY:" section.
        """
        ## FALLACY: Fallacy
        # 1 predict fallacy label in FLICC taxonomy (first entry of the
        #   classifier response, indexed as [0][0])
        detected_fallacy = self.endpoint_query(
            model=self.flicc_model, payload=misinformation
        )[0][0].get("label")
        fallacy_definition = DEFINITIONS.get(detected_fallacy)
        # 2 get all examples with the same label
        claims = FALLACY_CLAIMS.get(detected_fallacy)
        if not claims:
            raise ValueError(
                f"No example claims for fallacy label {detected_fallacy!r}"
            )
        # 3 pick the example claim most similar to the input
        example_myth = self._most_similar(misinformation, claims)
        example_response = DEBUNKINGS.get(example_myth)
        if example_response is None:
            raise ValueError(f"No debunking for example claim {example_myth!r}")
        # keep only the FALLACY section of the example (up to the next "##")
        match = re.search(r"## FALLACY:.*?(?=##)", example_response, re.DOTALL)
        if match is None:
            raise ValueError(
                f"Debunking for {example_myth!r} has no '## FALLACY:' section"
            )
        example_fallacy = match.group(0).replace("## FALLACY:", "")
        chain = INCONTEXT | self.chat_model
        content = chain.invoke(
            {
                "misinformation": misinformation,
                "detected_fallacy": detected_fallacy,
                "fallacy_definition": fallacy_definition,
                "example_response": example_fallacy,
                "example_myth": example_myth,
                "factual_information": self.hamburger[1].content,
            }
        ).content
        return content.replace("Response:", "")

    def generate_th_layer(self, misinformation):
        """Generate the closing FACT layer.

        Classifies the claim with the CARDS taxonomy model. When the CSV at
        ``self.filename`` has matching claims, it uses CONCLUDING_INCONTEXT
        with the evidence of the most similar claim. Otherwise it falls back
        to CONCLUDING. Reads the FACT layer from ``self.hamburger[1]``.
        """
        ## FACT: Concluding
        cards_label = self.endpoint_query(
            model=self.taxonomy_cards, payload=misinformation
        )[0][0].get("label")
        # 1 get all claims with same label from FEVER dataset
        claims = self.get_fever_claims(cards_label)
        prompt_inputs = {"fact": self.hamburger[1].content}
        if claims:
            prompt = CONCLUDING_INCONTEXT
            example_myth = self._most_similar(misinformation, claims)
            prompt_inputs["complementary_details"] = self.get_fever_evidence(
                example_myth
            )
        else:
            prompt = CONCLUDING
        chain = prompt | self.llm
        return chain.invoke(prompt_inputs)

    def _most_similar(self, source_sentence, sentences):
        """Return the element of ``sentences`` scored most similar to ``source_sentence``.

        Ties resolve to the first highest score.

        Raises:
            ValueError: if the similarity endpoint does not return a non-empty list.
        """
        scores = self.endpoint_query(
            payload={"source_sentence": source_sentence, "sentences": sentences},
            model=self.semantic_textual_similarity,
        )
        if not isinstance(scores, list) or not scores:
            raise ValueError(f"Unexpected sentence-similarity response: {scores!r}")
        return sentences[scores.index(max(scores))]

    def rebuttal_generator(self, misinformation):
        """Generate all four layers for ``misinformation`` and return them as text.

        Layers are generated strictly in order because the FALLACY and final
        FACT layers read the first FACT layer from ``self.hamburger[1]``.

        Returns:
            One "<heading name>: <content>" line per layer, joined by newlines.
        """
        self.hamburger[0] = self.hamburger[0]._replace(content=misinformation)
        layer_builders = (
            (1, self.generate_st_layer),  # FACT
            (2, self.generate_nd_layer),  # MYTH
            (3, self.generate_rd_layer),  # FALLACY
            (4, self.generate_th_layer),  # FACT
        )
        for slot, build in layer_builders:
            layer_text = build(misinformation).strip()
            self.hamburger[slot] = self.hamburger[slot]._replace(content=layer_text)
        return "\n".join(
            f"{heading.name}: {heading.content}" for heading in self.hamburger[1:]
        )

    def endpoint_query(self, payload, model):
        """POST ``payload`` to the Hugging Face Inference API for ``model``.

        The payload is sent as {"inputs": payload, "options": {...}} with the
        cache disabled and ``wait_for_model`` set. The HTTP status is not
        checked; whatever JSON the server returns is decoded and returned.

        Raises:
            KeyError: if HUGGINGFACEHUB_API_TOKEN is not set.
            ValueError: if the body is not valid UTF-8 JSON.
        """
        headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACEHUB_API_TOKEN']}"}
        options = {"use_cache": False, "wait_for_model": True}
        payload = {"inputs": payload, "options": options}
        api_url = f"https://api-inference.huggingface.co/models/{model}"
        response = requests.post(api_url, headers=headers, json=payload, timeout=120)
        return json.loads(response.content.decode("utf-8"))

    def retry_on_exceptions(self, function, *args, max_attempts=5, base_delay=5):
        """Call ``function(*args)``, retrying when it raises KeyError or ValueError.

        Before retry n (1-based) the call sleeps ``base_delay * n`` seconds;
        there is no sleep after the final failed attempt.

        Returns:
            The function's result, or ``None`` if every attempt raised.
        """
        for attempt in range(1, max_attempts + 1):
            try:
                return function(*args)
            except (KeyError, ValueError):
                print(f"retrying {attempt} out of {max_attempts}")
                if attempt < max_attempts:
                    time.sleep(base_delay * attempt)
        # Return None if no response after all attempts
        return None

    @staticmethod
    def _claim_label_is_one(row):
        """Return True when the CSV row's ``claim_label`` field parses to 1.

        csv.DictReader yields fields as strings, so comparing the raw value
        with the int 1 never matches; "1" and "1.0" are accepted here and
        anything unparsable is treated as not matching.
        """
        try:
            return float(row["claim_label"]) == 1
        except (TypeError, ValueError):
            return False

    def get_fever_claims(self, label):
        """Return the ``claim`` of every CSV row with claim_label 1 and CARDS_label ``label``."""
        with open(self.filename, "r", encoding="utf-8", newline="") as csvfile:
            return [
                row["claim"]
                for row in csv.DictReader(csvfile)
                if self._claim_label_is_one(row) and row["CARDS_label"] == label
            ]

    def get_fever_evidence(self, claim):
        """Return the evidence of rows matching ``claim`` (claim_label 1) as a "* " bullet list.

        Returns an empty string when no row matches.
        """
        evidences = []
        with open(self.filename, "r", encoding="utf-8", newline="") as csvfile:
            for row in csv.DictReader(csvfile):
                if self._claim_label_is_one(row) and row["claim"] == claim:
                    # literal_eval parses Python literals only (no code
                    # execution); the result is iterated as dicts with an
                    # "evidence" key.
                    for evidence_dict in literal_eval(row["evidences"]):
                        evidences.append(evidence_dict["evidence"])
        return "\n".join("* " + evidence for evidence in evidences)