# deepa2-demo / aaac_util.py
# author: ggbetz — "Update app." (commit 73c1565)
### utility for T5-aaac
import re
import ast
import logging
from string import Template
import random
import pyparsing as pp
import z3
# Maps each argumentation scheme name to the number of premises it takes.
PREM_ARG_SCHEMES = {
    'modus ponens': 2,
    'chain rule': 2,
    'adjunction': 2,
    'case analysis': 3,
    'disjunctive syllogism': 2,
    'biconditional elimination': 1,
    'instantiation': 1,
    'hypothetical syllogism': 2,
    'generalized biconditional elimination': 1,
    'generalized adjunction': 2,
    'generalized dilemma': 3,
    'generalized disjunctive syllogism': 2
}
# Module-level logger.
# NOTE(review): logger name references transformer_tools — presumably kept
# from the project this module was adapted from; confirm before renaming.
util_logger = logging.getLogger('transformer_tools.util.t5_util')
#######################################
# Layouter class #
#######################################
# Defines how to present AAAC raw data to model (as text)
class AAACLayouter:
    """Defines how to present AAAC raw data to the model (as text).

    All methods are stateless; they are declared as static methods so they
    can be called both on the class and on an instance (previously, calling
    them on an instance would have passed the instance as the first
    positional argument and failed).
    """

    # characters available as predicate / entity placeholder symbols
    PRED_CHARS = "FGHIJKLMNOPQRSTUVWABCDE"
    ENT_CHARS = "abcdeklmnopqrstuwfgh"
    # token used in place of a field that is masked out
    MASK_STRING = "??"

    @staticmethod
    def substitutions():
        """Map template keys "F1".."F20" / "a1".."a20" to display characters.

        NOTE(review): predicate substitutions carry a trailing space
        ("F1" -> "F "), so formatted substitution lists read "F : ...".
        """
        substitutions = {"F"+str(i+1): AAACLayouter.PRED_CHARS[i]+" " for i in range(20)}
        substitutions.update({"a"+str(i+1): AAACLayouter.ENT_CHARS[i] for i in range(20)})
        return substitutions

    # defines how to present reason and conclusion statements to the model
    @staticmethod
    def format_statements_list(statements: list, mask_prob: float = 0.0) -> str:
        """Render statements as "text (ref: (N)) | ...".

        Each ref is independently replaced by MASK_STRING with probability
        mask_prob. Returns the literal "None" for an empty list.
        """
        if len(statements) == 0:
            return "None"

        def ref_reco(sdict):
            # keep the reference with probability (1 - mask_prob)
            r = "(%s)" % sdict['ref_reco'] if random.random() > mask_prob else AAACLayouter.MASK_STRING
            return r

        list_as_string = ["%s (ref: %s)" % (sdict['text'].lower(), ref_reco(sdict)) for sdict in statements]
        list_as_string = " | ".join(list_as_string)
        return list_as_string

    # defines how to present argdown premise and conclusion statements to the model
    @staticmethod
    def format_ad_statements_list(statements: list, mask_prob: float = 0.0) -> str:
        """Render argdown statements as "text (ref: (N) explicit: True) | ...";
        the ref and explicit fields are masked independently with
        probability mask_prob.
        """
        if len(statements) == 0:
            return "None"

        def ref_reco(sdict):
            r = "(%s)" % sdict['ref_reco'] if random.random() > mask_prob else AAACLayouter.MASK_STRING
            return r

        def explicit(sdict):
            r = str(sdict['explicit']) if random.random() > mask_prob else AAACLayouter.MASK_STRING
            return r

        list_as_string = ["%s (ref: %s explicit: %s)" % (sdict['text'].lower(), ref_reco(sdict), explicit(sdict)) for sdict in statements]
        list_as_string = " | ".join(list_as_string)
        return list_as_string

    # defines how to present formalizations to the model
    @staticmethod
    def format_formalizations_list(formalizations: list, mask_prob: float = 0.0) -> str:
        """Render formalizations as "formula (ref: (N)) | ...", with
        ${F1}/${a1}-style placeholders substituted and "¬" spelled out
        as "not "."""
        if len(formalizations) == 0:
            return "None"

        def ref_reco(sdict):
            r = "(%s)" % sdict['ref_reco'] if random.random() > mask_prob else AAACLayouter.MASK_STRING
            return r

        def fform(sdict):
            # substitute placeholders with their display characters
            t = Template(sdict['form'])
            r = t.substitute(AAACLayouter.substitutions())
            r = r.replace("¬", "not ")
            return r

        list_as_string = ["%s (ref: %s)" % (fform(sdict), ref_reco(sdict)) for sdict in formalizations]
        list_as_string = " | ".join(list_as_string)
        return list_as_string

    # defines how to present placeholder substitutions to the model
    @staticmethod
    def format_plcd_subs(plcd_subs: dict, mask_prob: float = 0.0) -> str:
        """Render placeholder substitutions as "F : value | ..."; values are
        masked independently with probability mask_prob."""
        if len(plcd_subs.keys()) == 0:
            return "None"

        def mask(s):
            return s if random.random() > mask_prob else AAACLayouter.MASK_STRING

        list_as_string = ["%s: %s" % (AAACLayouter.substitutions()[k], mask(v.lower())) for k, v in plcd_subs.items()]
        list_as_string = " | ".join(list_as_string)
        return list_as_string

    # defines how to present argdown-snippet to the model
    @staticmethod
    def format_argdown(argdown: str, mask_prob: float = 0.0) -> str:
        """Rewrite inline yaml metadata blocks ("--\\nwith scheme {uses: ...}")
        as "from (i) (j)" clauses, then flatten to a single lower-case line.

        With probability mask_prob the inference info (including the scheme
        name) is replaced by "?? ".
        """
        # matches "--\nwith <scheme name> {<yaml metadata>}" inference patterns
        pattern = r"--\nwith ([^{}]*)({[^{}]*})"
        matches = re.findall(pattern, argdown)
        for match in matches:
            # quote the yaml keys so the metadata parses as a python dict
            m = match[1].replace('uses:', '"uses":')
            m = m.replace('variant:', '"variant":')
            d = ast.literal_eval(m)
            subst = ""
            mask_b = random.random() < mask_prob
            if mask_b:
                subst = "?? "
            elif "variant" in d:
                subst = "(%s) " % ", ".join(d['variant'])
            subst = subst + "from " + " ".join(["(%s)" % i for i in d['uses']])
            if mask_b:
                # when masking, also drop the scheme name
                argdown = argdown.replace(match[0]+match[1], subst)
            else:
                argdown = argdown.replace(match[1], subst)
        argdown = argdown.replace("\n", " ")  # replace line breaks
        argdown = argdown.lower()
        return argdown
#######################################
# Parser classes #
#######################################
class AAACParser:
    """Parses model output (text) back into structured AAAC data.

    All methods are stateless; they are declared as static methods so they
    can be called both on the class and on an instance.
    """

    @staticmethod
    def parse_proposition_block(ad_raw: str, inf_args: dict = None):
        """Split a labeled block " (1) text (2) text ..." into proposition dicts.

        inf_args (scheme/variants/uses of the inference that precedes this
        block) is merged into the first proposition, which is the one the
        inference derives. Returns [] for empty input or if the block does
        not start with a label.
        """
        if not ad_raw:
            return []
        # ensure the first label is preceded by a space so the regex matches
        if ad_raw[0] != " ":
            ad_raw = " " + ad_raw
        regex = r" \(([0-9]*)\) "  # match labels
        proposition_list = []
        # the block must start with a label
        if not re.match(regex, ad_raw):
            return proposition_list
        matches = re.finditer(regex, ad_raw, re.MULTILINE)
        label = -1
        pointer = -1
        for matchNum, match in enumerate(matches, start=1):
            if label > -1:
                # text between the previous label and this one
                proposition = {
                    "text": ad_raw[pointer:match.start()].strip(),
                    "label": label,
                    "uses": [],
                    "scheme": "",
                    "variants": []
                }
                proposition_list.append(proposition)
            label = int(match.group(1))
            pointer = match.end()
        if label > -1:
            # trailing text after the last label
            proposition = {'text': ad_raw[pointer:].strip(), 'label': label, "uses": [], "scheme": "", "variants": []}
            proposition_list.append(proposition)
        if proposition_list and inf_args:
            proposition_list[0].update(inf_args)
        return proposition_list

    @staticmethod
    def parse_variants(variants_raw) -> list:
        """Parse a variant list like " (transposition, weak)" into
        ["transposition", "weak"]."""
        if not variants_raw:
            return []
        regex = r"(?! )[^\(\),]+"
        matches = re.finditer(regex, str(variants_raw), re.MULTILINE)
        return [match.group() for match in matches]

    @staticmethod
    def parse_uses(uses_raw) -> list:
        """Parse a uses list like "(1) (2)" into [1, 2]."""
        if not uses_raw:
            return []
        regex = r"\(([0-9]+)\)"
        matches = re.finditer(regex, str(uses_raw), re.MULTILINE)
        return [int(match.group(1)) for match in matches]

    @staticmethod
    def preprocess_ad(ad_raw: str):
        """Normalize an argdown snippet: flatten line breaks, collapse double
        spaces, and repair masked inference info ("with?? " -> "with ?? ")."""
        ad_raw = ad_raw.replace("\n", " ")
        ad_raw = ad_raw.replace("  ", " ")  # fix: was a no-op replace(" ", " ")
        ad_raw = ad_raw.replace("with?? ", "with ?? ")
        return ad_raw

    @staticmethod
    def parse_argdown_block(ad_raw: str):
        """Parse a flattened argdown snippet into a list of proposition dicts.

        Inference info ("-- with scheme (variants) from (i) (j) --") is
        attached to the proposition that immediately follows it. Returns
        None if a proposition block fails to parse.
        """
        ad_raw = AAACParser.preprocess_ad(ad_raw)
        regex = r"-- with ([^\(\)]*)( \([^-\(\))]*\))? from ([\(\), 0-9]+) --"  # matches inference patterns
        proposition_list = []
        inf_args = None
        pointer = 0
        matches = re.finditer(regex, ad_raw, re.MULTILINE)
        for matchNum, match in enumerate(matches, start=1):
            # parse all propositions before the matched inference
            propositions = AAACParser.parse_proposition_block(ad_raw[pointer:match.start()], inf_args=inf_args)
            if not propositions:
                return None
            proposition_list.extend(propositions)
            # update pointer and inf_args to be used for parsing the next propositions block
            pointer = match.end()
            inf_args = {
                'scheme': match.group(1),
                'variants': AAACParser.parse_variants(match.group(2)),
                'uses': AAACParser.parse_uses(match.group(3))
            }
        if pointer > 0:
            propositions = AAACParser.parse_proposition_block(ad_raw[pointer:], inf_args=inf_args)
            proposition_list.extend(propositions)
        return proposition_list

    @staticmethod
    def parse_statements(statements_raw: str):
        """Parse "text (ref: (N)) | ..." into [{'text':..., 'ref_reco':...}].

        Returns [] for the literal "None", and None if an item does not match
        the expected pattern. A masked ref ("??") yields ref_reco=None.
        """
        if not statements_raw:
            return None
        statements = []
        if statements_raw.strip() == "None":
            return statements
        list_raw = statements_raw.split(" | ")
        regex = r" \(ref: (?:\(([0-9]+)\)|\?\?)\)$"
        for s in list_raw:
            match = re.search(regex, s)
            if not match:
                return None
            item = {
                'text': s[:match.start()],
                'ref_reco': int(match.group(1)) if match.group(1) else match.group(1)
            }
            statements.append(item)
        return statements

    @staticmethod
    def parse_formalizations(forms_raw: str):
        """Parse a formalizations list; like parse_statements but the text
        field is renamed to 'form' and cleaned of "⁇" tokenizer artifacts."""
        parsed = AAACParser.parse_statements(forms_raw)
        if not parsed:
            return None
        formalizations = []
        for d in parsed:
            d['form'] = d.pop('text')
            formalizations.append(d)
        # post-process: cleanup "⁇"
        for f in formalizations:
            form = f['form']
            form = form.replace("⁇", "")
            form = form.replace("  ", " ")  # fix: was a no-op replace(" ", " ")
            form = form.strip()
            f['form'] = form
        return formalizations

    @staticmethod
    def parse_plcd_subs(subs_raw: str):
        """Parse "F : value | ..." into a list of single-entry dicts
        [{placeholder_char: value}, ...].

        NOTE(review): keys keep the trailing space produced by
        AAACLayouter.substitutions() (e.g. "F ").
        """
        if not subs_raw:
            return None
        plcd_subs = []
        if subs_raw.strip() == "None":
            return plcd_subs
        list_raw = subs_raw.split(" | ")
        regex = r"^(..?):\s(.+)"
        for s in list_raw:
            match = re.search(regex, s)
            if not match:
                return None
            k = match.group(1)
            item = {k: match.group(2)}
            plcd_subs.append(item)
        return plcd_subs
#######################################
# Logic Evaluator Class #
#######################################
class AAACLogicEvaluator():
    """Evaluates generated arguments.

    Matches inferences against natural-language scheme templates
    (regex-based) and checks deductive validity of formalizations (z3).
    """

    def __init__(
        self,
        nl_schemes = None,
        domains = None,
        **kwargs
    ):
        # compile the natural-language scheme templates into regexes
        # NOTE(review): nl_schemes.copy() is a shallow copy, so the item
        # dicts passed in are modified in place — confirm callers do not
        # reuse nl_schemes afterwards.
        self.nl_schemes_re = []
        if nl_schemes:
            self.nl_schemes_re = nl_schemes.copy()
            for item in self.nl_schemes_re:
                item['scheme'] = [self.construct_regex(s) for s in item['scheme']]
        # construct de-paraphrasing rules from domain-config-file
        self.de_paraphrasing_rules = {}
        if domains:
            for domain in domains.get('domains'):
                rules = {}
                for k, v in domain.get("paraphrases", {}).items():
                    rules.update({repl.lower(): k.lower() for repl in v})  # all paraphrasing rules are cast as lower case
                self.de_paraphrasing_rules[domain['id']] = rules

    def construct_regex(self, statement: str):
        """Turn a scheme template like "Every ${F} is a ${G}." into a regex
        string with one named group per placeholder occurrence (F0, G1, ...);
        literal text is lower-cased."""
        regex = r"( a )?\$\{([A-Za-z])\}"
        regex_template = ""
        pointer = 0
        matches = re.finditer(regex, statement, re.MULTILINE)
        for matchNum, match in enumerate(matches):
            regex_template += statement[pointer:match.start()].lower()
            # negative lookahead: the next few literal characters must not be
            # swallowed by the placeholder group
            neg_la = statement[match.end():match.end()+5] if match.end()+5 <= len(statement) else statement[match.end():]
            if match.group(1):
                regex_template += " an? "  # allow both indefinite articles
            regex_template += "(?P<%s%s>.*(?!%s)*)" % (match.group(2), matchNum, neg_la)
            pointer = match.end()
        regex_template += statement[pointer:].lower()
        return regex_template

    def parse_inference_as_scheme(
        self,
        argument:list = None,
        nl_scheme_re:dict = None
    ):
        """Try to match an argument (list of statements, conclusion last)
        against one compiled scheme; returns truthy on a consistent match."""

        # recursively try to match premises to scheme, i.e. recursively
        # construct a consistent mapping.
        # matching[i]: regex match of scheme formula i (or None);
        # mapping: scheme formula index -> premise index
        def match_premises(matching: list = None, mapping: dict = None):
            unmapped_formulas = [i for i in range(len(argument)) if not i in mapping.keys()]
            unmapped_premises = [j for j in range(len(argument)) if not j in mapping.values()]
            # fix: initialize, so the function returns False (instead of
            # raising UnboundLocalError) when no candidate pair matches
            full_match = False
            for i in unmapped_formulas:
                for j in unmapped_premises:
                    try:
                        match = re.match(nl_scheme_re['scheme'][i], argument[j])
                    except IndexError:
                        match = False
                    if match:
                        # NOTE(review): matching/mapping are not reset on a
                        # failed branch, so backtracking is incomplete —
                        # preserved as-is to keep behavior unchanged.
                        matching[i] = match
                        mapping[i] = j
                        if any(m == None for m in matching):
                            full_match = match_premises(
                                matching=matching,
                                mapping=mapping
                            )
                        else:
                            full_match = matching_consistent(matching)
                        if full_match:
                            return True
                        else:
                            full_match = False
            return full_match

        # check whether a matching is consistent with respect to placeholders:
        # every named group of the same placeholder must capture the same text
        def matching_consistent(matching: list):
            if any(m == None for m in matching):
                return False

            def group_by_name(match=None, group_name=None):
                # None if the regex has no group of that name
                try:
                    g = match.group(group_name)
                except IndexError:
                    g = None
                return g

            all_plcds = (nl_scheme_re["predicate-placeholders"] + nl_scheme_re["entity-placeholders"])
            for plcd in all_plcds:
                all_subst = []
                for i in range(10):  # a placeholder may occur in up to 10 formulas
                    group_name = plcd + str(i)
                    subst = [group_by_name(match=match, group_name=group_name) for match in matching]
                    subst = [x for x in subst if x != None]
                    all_subst.extend(subst)
                if len(set(all_subst)) > 1:
                    return False
            return True

        # the conclusion (last statement) must match the last scheme formula
        c_match = re.match(nl_scheme_re['scheme'][-1], argument[-1])
        if c_match:
            full_match = match_premises(
                matching=[None]*(len(argument)-1) + [c_match],
                mapping={len(argument)-1: len(argument)-1}
            )
        else:
            full_match = False
        return full_match

    def parse_inference_as_base_scheme(
        self,
        argument:list = None,
        base_scheme_group:str = None,
        domain:str = None
    ):
        """Match an argument against every variant of a base scheme group.

        Returns (matches, variant): variant is the first matching
        scheme_variant, or None if nothing matches.
        """
        variant = None
        matches = False
        # make the entire argument lower case
        argument = [item.lower() for item in argument]
        argument = self.de_paraphrase(argument, domain=domain) if domain else argument
        for nl_scheme_re in self.nl_schemes_re:
            if nl_scheme_re['base_scheme_group'] == base_scheme_group:
                matches = self.parse_inference_as_scheme(
                    argument=argument,
                    nl_scheme_re=nl_scheme_re
                )
                if matches:
                    variant = nl_scheme_re['scheme_variant']
                    break
        return matches, variant

    def de_paraphrase(self, argument:list = None, domain:str = None):
        """Replace domain-specific paraphrases with their canonical wording
        (modifies the list in place and also returns it)."""
        rules = {}
        if domain in self.de_paraphrasing_rules:
            rules = self.de_paraphrasing_rules[domain]
        for i, statement in enumerate(argument):
            s = statement
            for k, v in rules.items():
                s = s.replace(k, v)
            argument[i] = s
        return argument

    def parse_string_formula(self, formula: str):
        """Parse a propositional formula over atoms like "F a" with operators
        not, &, v, ->, <->; returns a pyparsing parse tree or None."""
        # NOTE(review): the char class "[a-u|w-z]" also admits a literal "|";
        # presumably the intent is a-u plus w-z (skipping the operator "v")
        # — confirm before tightening.
        atom = pp.Regex("[A-Z]\s[a-u|w-z]").setName("atom")
        expr = pp.infixNotation(atom, [
            ("not", 1, pp.opAssoc.RIGHT, ),
            ("&", 2, pp.opAssoc.LEFT, ),
            ("v", 2, pp.opAssoc.LEFT, ),
            ("->", 2, pp.opAssoc.LEFT, ),
            ("<->", 2, pp.opAssoc.LEFT, )
        ])
        try:
            parsed = expr.parseString(formula, parseAll=True)[0]
        except pp.ParseException:
            parsed = None
        return parsed

    def c_bf(self, parse_tree):
        """Recursively compile a parse tree into a z3 boolean expression.

        Atoms ("F a") become unary predicate applications over an
        uninterpreted 'Object' sort.
        """
        if not parse_tree:
            return None
        functions = {}
        constants = {}
        Object = z3.DeclareSort('Object')
        bin_op = {
            "&": z3.And,
            "v": z3.Or,
            "->": z3.Implies
        }
        pt = parse_tree
        if pt[0] == "not":
            return z3.Not(self.c_bf(pt[1]))
        if pt[1] == "<->":
            # biconditional as conjunction of two implications
            return z3.And(z3.Implies(self.c_bf(pt[0]), self.c_bf(pt[2])), z3.Implies(self.c_bf(pt[2]), self.c_bf(pt[0])))
        if pt[1] in bin_op.keys():
            return bin_op[pt[1]](self.c_bf(pt[0]), self.c_bf(pt[2]))
        # atom: predicate symbol applied to a constant
        pred = parse_tree[0]
        if not pred in functions.keys():  # add predicate to dict if necessary
            functions[pred] = z3.Function(pred, Object, z3.BoolSort())
        const = parse_tree[-1]
        if not const in constants.keys():  # add constant to dict if necessary
            # fix: the constant was stored in (and read back from) the
            # `functions` dict; use the `constants` dict as intended
            constants[const] = z3.Const(const, Object)
        return functions[pred](constants[const])

    def to_z3(self, str_f: str):
        """Translate a string formula into a z3 expression.

        A leading "(x):" marks universal quantification over x.
        Returns None for empty input or if parsing fails.
        """
        if not str_f:
            return None
        f = str_f.strip()
        if f[:4] == "(x):":
            Object = z3.DeclareSort('Object')
            x = z3.Const('x', Object)
            parsed = self.parse_string_formula(f[4:].strip())
            if parsed:
                return (z3.ForAll(x, self.c_bf(parsed)))
            return None
        parsed = self.parse_string_formula(f)
        if parsed:
            return self.c_bf(parsed)
        return None

    def check_deductive_validity(self, scheme: list):
        """Check whether the premises (all but the last formula) entail the
        conclusion (last formula).

        Returns True/False, or None if any formula fails to translate.
        """
        premises = [self.to_z3(f) for f in scheme[:-1]]  # premises
        conclusion = self.to_z3(scheme[-1])  # conclusion
        if any(p is None for p in premises) or (conclusion is None):
            return None
        theory = premises  # premises
        theory.append(z3.Not(conclusion))  # negation of conclusion
        s = z3.Solver()
        s.add(theory)
        # valid iff premises + negated conclusion are jointly unsatisfiable
        valid = s.check() == z3.unsat
        return valid
#######################################
# Main eval function #
#######################################
def ad_valid_syntax(argdown):
    """Return 1 iff the parsed argdown has consecutive labels 1..n and no
    proposition whose text is the bare separator "--"; 0 otherwise."""
    if not argdown:
        return 0
    labels_consecutive = all(prop['label'] == idx + 1 for idx, prop in enumerate(argdown))
    no_bare_separator = all(prop['text'] != "--" for prop in argdown)
    return 1 if labels_consecutive and no_bare_separator else 0
def ad_last_st_concl(argdown):
    """Return 1 iff the final proposition is an inferred conclusion, i.e. has
    a non-empty 'uses' list; 0 otherwise (also for empty/missing input)."""
    if not argdown:
        return 0
    return 1 if len(argdown[-1]['uses']) > 0 else 0
# do all statements referenced in inference exist and do they occur before inference is drawn?
def used_prem_exist(argdown):
    """Return 1 if every label referenced in a 'uses' list was introduced by
    an earlier proposition, 0 on the first violation, None for falsy input."""
    if not argdown:
        return None
    seen_labels = []
    for prop in argdown:
        if any(ref not in seen_labels for ref in prop['uses']):
            return 0
        seen_labels.append(prop['label'])
    return 1
def ad_redundant_prem(argdown):
    """Return 1 iff two premises (propositions without 'uses') share the same
    stripped text; None for falsy input."""
    if not argdown:
        return None
    premise_texts = [prop['text'].strip() for prop in argdown if not prop['uses']]
    has_duplicates = len(set(premise_texts)) < len(premise_texts)
    return 1 if has_duplicates else 0
def ad_petitio(argdown):
    """Return 1 iff some premise text equals some conclusion text (begging
    the question); None for falsy input.

    Uses set intersection instead of the previous O(n*m) pairwise scan —
    same result, linear time.
    """
    if not argdown:
        return None
    premises = {item['text'].strip() for item in argdown if not item['uses']}
    conclusions = {item['text'].strip() for item in argdown if item['uses']}
    return 1 if premises & conclusions else 0
def prem_non_used(argdown):
    """Count propositions never referenced by any inference, minus one (the
    final conclusion is never referenced); None for falsy input."""
    if not argdown:
        return None
    referenced = set()
    for prop in argdown:
        referenced.update(prop['uses'])
    unreferenced = [prop['label'] for prop in argdown if prop['label'] not in referenced]
    return len(unreferenced) - 1
#######################################
# Evaluating reason_statements #
#######################################
def s_valid_syntax(output: list, raw: str = ""):
    """Return 1 iff parsing produced a result (output is not None) or the raw
    text was the literal "None"; 0 otherwise. (Uses `is not None` rather than
    `!= None`, per idiom; behavior unchanged.)"""
    return 1 if output is not None or raw.strip() == "None" else 0
def s_not_verb_quotes(output: list, source):
    """Count statements whose text does not occur verbatim in source."""
    return sum(1 for statement in output if statement['text'] not in source)
def s_ord_me_subsseq(output: list, source):
    """Check that every statement text occurs in source, consuming each
    occurrence as it is found.

    For each statement (in order), the first occurrence of its text is
    removed from the remaining source, so repeated statements require
    repeated occurrences. Returns True/False. (Removed a no-op
    `source = source` assignment; behavior unchanged.)
    """
    remainder = source
    for statement in output:
        text = statement['text']
        if text not in remainder:
            return False
        remainder = remainder.replace(text, "", 1)
    return True
#######################################
# Evaluating r-c-a consistency #
#######################################
# test: no reason statement is contained in a conclusion statement and vice versa
def reason_concl_mutually_exclusive(reasons, concl):
    """Return True iff no reason text is a substring of a conclusion text and
    no conclusion text is a substring of a reason text."""
    for reason in reasons:
        for conclusion in concl:
            overlaps = reason['text'] in conclusion['text'] or conclusion['text'] in reason['text']
            if overlaps:
                return False
    return True