# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run DD+AR or AlphaGeometry solver.
Please refer to README.md for detailed instructions.
"""
import time
import traceback
from absl import app
from absl import flags
from absl import logging
import ddar
import graph as gh
import lm_inference as lm
import pretty as pt
import problem as pr
#=============
import os
import re
import sys
import multiprocessing
import warnings
warnings.filterwarnings("ignore")
model = None # global variable used in multi-processing workers
_GIN_SEARCH_PATHS = flags.DEFINE_list(
'gin_search_paths',
['third_party/py/meliad/transformer/configs'],
'List of paths where the Gin config files are located.',
)
_GIN_FILE = flags.DEFINE_multi_string(
'gin_file', ['base_htrans.gin'], 'List of Gin config files.'
)
_GIN_PARAM = flags.DEFINE_multi_string(
'gin_param', None, 'Newline separated list of Gin parameter bindings.'
)
_PROBLEMS_FILE = flags.DEFINE_string(
'problems_file',
'imo_ag_30.txt',
    'text file containing the problem strings. See imo_ag_30.txt for example.',
)
_PROBLEM_NAME = flags.DEFINE_string(
'problem_name',
'imo_2000_p1',
    'name of the problem to solve; must be in the problems_file.',
)
_MODE = flags.DEFINE_string(
'mode', 'ddar', 'either `ddar` (DD+AR) or `alphageometry`')
_DEFS_FILE = flags.DEFINE_string(
'defs_file',
'defs.txt',
'definitions of available constructions to state a problem.',
)
_RULES_FILE = flags.DEFINE_string(
'rules_file', 'rules.txt', 'list of deduction rules used by DD.'
)
_CKPT_PATH = flags.DEFINE_string('ckpt_path', '', 'checkpoint of the LM model.')
_VOCAB_PATH = flags.DEFINE_string(
'vocab_path', '', 'path to the LM vocab file.'
)
_OUT_FILE = flags.DEFINE_string(
'out_file', '', 'path to the solution output file.'
) # pylint: disable=line-too-long
_BEAM_SIZE = flags.DEFINE_integer(
'beam_size', 1, 'beam size of the proof search.'
) # pylint: disable=line-too-long
_SEARCH_DEPTH = flags.DEFINE_integer(
'search_depth', 1, 'search depth of the proof search.'
) # pylint: disable=line-too-long
#===================================
_N_WORKERS = flags.DEFINE_integer(
    'n_workers', 1, 'number of worker processes.'
)  # pylint: disable=line-too-long
DEFINITIONS = None # contains definitions of construction actions
RULES = None # contains rules of deductions
def natural_language_statement(logical_statement: pr.Dependency) -> str:
"""Convert logical_statement to natural language.
Args:
logical_statement: pr.Dependency with .name and .args
Returns:
a string of (pseudo) natural language of the predicate for human reader.
"""
names = [a.name.upper() for a in logical_statement.args]
names = [(n[0] + '_' + n[1:]) if len(n) > 1 else n for n in names]
return pt.pretty_nl(logical_statement.name, names)
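# A minimal illustration of the renaming above (the final rendering comes from
# pt.pretty_nl, so the full output string is an assumption): a Dependency with
# args named 'a' and 'b1' yields the point labels 'A' and 'B_1'.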
def proof_step_string(
proof_step: pr.Dependency, refs: dict[tuple[str, ...], int], last_step: bool
) -> str:
"""Translate proof to natural language.
Args:
proof_step: pr.Dependency with .name and .args
refs: dict(hash: int) to keep track of derived predicates
last_step: boolean to keep track whether this is the last step.
Returns:
a string of (pseudo) natural language of the proof step for human reader.
"""
premises, [conclusion] = proof_step
premises_nl = ' & '.join(
[
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
for p in premises
]
)
if not premises:
premises_nl = 'similarly'
refs[conclusion.hashed()] = len(refs)
conclusion_nl = natural_language_statement(conclusion)
if not last_step:
conclusion_nl += ' [{:02}]'.format(refs[conclusion.hashed()])
return f'{premises_nl} \u21d2 {conclusion_nl}'
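# For example, a step with two premises referenced as [01] and [02] renders
# roughly as (the predicate text itself is illustrative, from pt.pretty_nl):
#   'OA = OB [01] & OB = OC [02] ⇒ OA = OC [03]'
# with the trailing '[03]' reference omitted when last_step is True.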
def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
"""Output the solution to out_file.
Args:
g: gh.Graph object, containing the proof state.
p: pr.Problem object, containing the theorem.
out_file: file to write to, empty string to skip writing to file.
"""
setup, aux, proof_steps, refs = ddar.get_proof_steps(
g, p.goal, merge_trivials=False
)
solution = ''
  solution += 'From the problem statement we have:\n'
premises_nl = []
for premises, [points] in setup:
solution += ' '.join([p.name.upper() for p in points]) + ' '
if not premises:
continue
premises_nl += [
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
for p in premises
]
solution += ': Points\n' + '\n'.join(premises_nl)
  solution += '\n\nAuxiliary points to construct:\n'
aux_premises_nl = []
  if not aux:
    solution += 'No additional points need to be constructed.'
else:
for premises, [points] in aux:
solution += ' '.join([p.name.upper() for p in points]) + ' '
aux_premises_nl += [
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
for p in premises
]
solution += ': Points\n' + '\n'.join(aux_premises_nl)
  # Special cases where the deduction rule has a well-known name.
r2name = {
'r32': '(SSS)',
'r33': '(SAS)',
'r34': '(Similar Triangles)',
'r35': '(Similar Triangles)',
'r36': '(ASA)',
'r37': '(ASA)',
'r38': '(Similar Triangles)',
'r39': '(Similar Triangles)',
'r40': '(Congruent Triangles)',
'a00': '(Distance chase)',
'a01': '(Ratio chase)',
'a02': '(Angle chase)',
}
  solution += '\n\nProof steps:\n'
for i, step in enumerate(proof_steps):
_, [con] = step
nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1)
rule_name = r2name.get(con.rule_name, '')
nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
solution += '{:03}. '.format(i + 1) + nl + '\n'
logging.info(solution)
if out_file:
with open(out_file, 'w') as f:
f.write(solution)
logging.info('Solution written to %s.', out_file)
def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
lm.parse_gin_configuration(
_GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value
)
return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
"""Run DD+AR.
Args:
g: gh.Graph object, containing the proof state.
p: pr.Problem object, containing the problem statement.
out_file: path to output file if solution is found.
Returns:
Boolean, whether DD+AR finishes successfully.
"""
ddar.solve(g, RULES, p, max_level=1000)
goal_args = g.names2nodes(p.goal.args)
if not g.check(p.goal.name, goal_args):
logging.info('DD+AR failed to solve the problem.')
return False
write_solution(g, p, out_file)
  gh.nm.draw(
      g.type2nodes[gh.Point],
      g.type2nodes[gh.Line],
      g.type2nodes[gh.Circle],
      g.type2nodes[gh.Segment],
      save_to='ag4mout/output.png',
  )
return True
def translate_constrained_to_constructive(
point: str, name: str, args: list[str]
) -> tuple[str, list[str]]:
"""Translate a predicate from constraint-based to construction-based.
Args:
point: str: name of the new point
name: str: name of the predicate, e.g., perp, para, etc.
args: list[str]: list of predicate args.
Returns:
(name, args): translated to constructive predicate.
"""
if name in ['T', 'perp']:
a, b, c, d = args
if point in [c, d]:
a, b, c, d = c, d, a, b
if point == b:
a, b = b, a
if point == d:
c, d = d, c
if a == c and a == point:
return 'on_dia', [a, b, d]
return 'on_tline', [a, b, c, d]
elif name in ['P', 'para']:
a, b, c, d = args
if point in [c, d]:
a, b, c, d = c, d, a, b
if point == b:
a, b = b, a
return 'on_pline', [a, b, c, d]
elif name in ['D', 'cong']:
a, b, c, d = args
if point in [c, d]:
a, b, c, d = c, d, a, b
if point == b:
a, b = b, a
if point == d:
c, d = d, c
if a == c and a == point:
return 'on_bline', [a, b, d]
if b in [c, d]:
if b == d:
c, d = d, c # pylint: disable=unused-variable
return 'on_circle', [a, b, d]
return 'eqdistance', [a, b, c, d]
elif name in ['C', 'coll']:
a, b, c = args
if point == b:
a, b = b, a
if point == c:
a, b, c = c, a, b
return 'on_line', [a, b, c]
elif name in ['^', 'eqangle']:
a, b, c, d, e, f = args
if point in [d, e, f]:
a, b, c, d, e, f = d, e, f, a, b, c
x, b, y, c, d = b, c, e, d, f
if point == b:
a, b, c, d = b, a, d, c
if point == d and x == y: # x p x b = x c x p
return 'angle_bisector', [point, b, x, c]
if point == x:
return 'eqangle3', [x, a, b, y, c, d]
return 'on_aline', [a, x, b, c, y, d]
elif name in ['cyclic', 'O']:
a, b, c = [x for x in args if x != point]
return 'on_circum', [point, a, b, c]
return name, args
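# A couple of worked cases, traced through the branches above (the concrete
# point names are hypothetical):
#   translate_constrained_to_constructive('d', 'perp', ['a', 'b', 'c', 'd'])
#     -> ('on_tline', ['d', 'c', 'a', 'b'])  # d on the line through c perp. to ab
#   translate_constrained_to_constructive('c', 'coll', ['a', 'b', 'c'])
#     -> ('on_line', ['c', 'a', 'b'])        # c on line ab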
def check_valid_args(name: str, args: list[str]) -> bool:
"""Check whether a predicate is grammarically correct.
Args:
name: str: name of the predicate
args: list[str]: args of the predicate
Returns:
    bool: whether the predicate args are well-formed (arity and distinctness).
"""
if name == 'perp':
if len(args) != 4:
return False
a, b, c, d = args
if len({a, b}) < 2:
return False
if len({c, d}) < 2:
return False
elif name == 'para':
if len(args) != 4:
return False
a, b, c, d = args
if len({a, b, c, d}) < 4:
return False
elif name == 'cong':
if len(args) != 4:
return False
a, b, c, d = args
if len({a, b}) < 2:
return False
if len({c, d}) < 2:
return False
elif name == 'coll':
if len(args) != 3:
return False
a, b, c = args
if len({a, b, c}) < 3:
return False
elif name == 'cyclic':
if len(args) != 4:
return False
a, b, c, d = args
if len({a, b, c, d}) < 4:
return False
elif name == 'eqangle':
if len(args) != 8:
return False
a, b, c, d, e, f, g, h = args
if len({a, b, c, d}) < 3:
return False
if len({e, f, g, h}) < 3:
return False
return True
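# For instance (traced through the checks above):
#   check_valid_args('perp', ['a', 'b', 'a', 'c'])  -> True   (two valid pairs)
#   check_valid_args('perp', ['a', 'a', 'b', 'c'])  -> False  (degenerate pair aa)
#   check_valid_args('coll', ['a', 'b', 'b'])       -> False  (points not distinct)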
def try_translate_constrained_to_construct(string: str, g: gh.Graph) -> str:
"""Whether a string of aux construction can be constructed.
Args:
string: str: the string describing aux construction.
g: gh.Graph: the current proof state.
Returns:
    str: the construction clause if valid, else a string starting with "ERROR:".
"""
if string[-1] != ';':
return 'ERROR: must end with ;'
  logging.info('PID=%d: !! try_translate_constrained_to_construct: string=%s',
               os.getpid(), string)
# sometimes the LM may return ill-formed result with multiple colons.
# example:
#
# napoleon2
# a1 a2 a3 = triangle; c3 = s_angle a1 a2 c3 30, s_angle a2 a1 c3 150; c1 = s_angle a2 a3 c1 30, s_angle a3 a2 c1 150; c2 = s_angle a3 a1 c2 30, s_angle a1 a3 c2 150 ? cong c1 c2 c1 c3
#
# in the process,
# I0210 17:58:01.513668 140016515833856 alphageometry.py:550] Decoding from {S} a : ; b : ; c : ; d : ^ a d a b 5. pi / 6. 00 ^ b d b a 1. pi / 6. 01 ; e : ^ b e b c 5. pi / 6. 02 ^ c e c b 1. pi / 6. 03 ; f : ^ a f a c 1. pi / 6. 04 ^ c f c a 5. pi / 6. 05 ? D e f e d {F1} x00 g : C a b g 06 D a g b g 07 ; x00 h : C c b h 08 D c h b h 09 ; x00
# I0210 18:01:38.182158 140016515833856 alphageometry.py:384] !! try_translate_constrained_to_construct: string=i : C a c i 10 D a i c i 11 ? V d f {F1} x00 j : D g j h j 12 D h j i j 13 ;
  mch = re.match(r'(.*?)( \? | \. \{)', string)
  if mch:
    str_fixed = mch.group(1) + ';'
    logging.info(f'PID={os.getpid()}: Bad LM output: {string}. Changed to {str_fixed}')
    string = str_fixed
# sometimes the constraint in string is empty:
# 0407 17:11:35.470240 126383800963072 alphageometry.py:394] !! try_translate_constrained_to_construct: string=j : ;
  hdprem = string.split(' : ')
  if len(hdprem) != 2 or hdprem[1].strip() == ';':
    logging.info(f'PID={os.getpid()}: Bad LM output: {string}. ERROR')
    return f'ERROR: Bad LM output: {string}'
head, prem_str = hdprem
point = head.strip()
if len(point) != 1 or point == ' ':
return f'ERROR: invalid point name {point}'
existing_points = [p.name for p in g.all_points()]
if point in existing_points:
return f'ERROR: point {point} already exists.'
prem_toks = prem_str.split()[:-1] # remove the EOS ' ;'
prems = [[]]
for i, tok in enumerate(prem_toks):
if tok.isdigit():
if i < len(prem_toks) - 1:
prems.append([])
else:
prems[-1].append(tok)
if len(prems) > 2:
return 'ERROR: there cannot be more than two predicates.'
clause_txt = point + ' = '
constructions = []
for prem in prems:
name, *args = prem
if point not in args:
return f'ERROR: {point} not found in predicate args.'
if not check_valid_args(pt.map_symbol(name), args):
return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)
for a in args:
if a != point and a not in existing_points:
return f'ERROR: point {a} does not exist.'
try:
name, args = translate_constrained_to_constructive(point, name, args)
except: # pylint: disable=bare-except
return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)
if name == 'on_aline':
if args.count(point) > 1:
        return f'ERROR: on_aline involves {point} twice'
constructions += [name + ' ' + ' '.join(args)]
clause_txt += ', '.join(constructions)
clause = pr.Clause.from_txt(clause_txt)
try:
g.copy().add_clause(clause, 0, DEFINITIONS)
except: # pylint: disable=bare-except
return 'ERROR: ' + traceback.format_exc()
return clause_txt
def insert_aux_to_premise(pstring: str, auxstring: str) -> str:
"""Insert auxiliary constructs from proof to premise.
Args:
pstring: str: describing the problem to solve.
    auxstring: str: describing the auxiliary construction.
Returns:
str: new pstring with auxstring inserted before the conclusion.
"""
setup, goal = pstring.split(' ? ')
return setup + '; ' + auxstring + ' ? ' + goal
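# Example (problem strings follow the 'premises ? goal' format used throughout
# this file; the concrete statement is hypothetical):
#   insert_aux_to_premise('a b c = triangle a b c ? perp a b a c',
#                         'd = on_line d b c')
#   -> 'a b c = triangle a b c; d = on_line d b c ? perp a b a c'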
class BeamQueue:
"""Keep only the top k objects according to their values."""
def __init__(self, max_size: int = 512):
self.queue = []
self.max_size = max_size
def add(self, node: object, val: float) -> None:
"""Add a new node to this queue."""
if len(self.queue) < self.max_size:
self.queue.append((val, node))
return
# Find the minimum node:
min_idx, (min_val, _) = min(enumerate(self.queue), key=lambda x: x[1])
# replace it if the new node has higher value.
if val > min_val:
self.queue[min_idx] = (val, node)
def __iter__(self):
for val, node in self.queue:
yield val, node
def __len__(self) -> int:
return len(self.queue)
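# A minimal usage sketch of BeamQueue (values are illustrative):
#   q = BeamQueue(max_size=2)
#   q.add('n1', 0.5)
#   q.add('n2', 0.9)
#   q.add('n3', 0.7)   # evicts 'n1', the current minimum
#   list(q)  ->  [(0.7, 'n3'), (0.9, 'n2')]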
# Beam-search worker routines (run in separate processes when n_workers > 1).
def bqsearch_init():
global model
logging.info('Worker initializing. PID=%d', os.getpid())
model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
def bqsearch(i_nd, srch_inputs, out_file) -> tuple[int, bool, list]:
  # Returns (i_node, solved, [(node, score), ...]).
pid = os.getpid()
logging.info(f'Worker PID={pid} called for beam search node {i_nd}')
prev_score, (g, string, pstring) = srch_inputs
logging.info(f'Worker PID={pid}: Decoding from {string}')
outputs = model.beam_decode(string, eos_tokens=[';'])
# translate lm output to the constructive language.
# so that we can update the graph representing proof states:
translations = [
try_translate_constrained_to_construct(o, g)
for o in outputs['seqs_str']
]
  # Couple the LM outputs with their translations.
  candidates = zip(outputs['seqs_str'], translations, outputs['scores'])
  # Bring the highest-scoring candidate first.
  candidates = reversed(list(candidates))
ret = []
for lm_out, translation, score in candidates:
logging.info(f'Worker PID={pid}: LM output (score={score}): "{lm_out}"')
logging.info(f'Worker PID={pid}: Translation: "{translation}"')
if translation.startswith('ERROR:'):
# the construction is invalid.
continue
# Update the constructive statement of the problem with the aux point:
candidate_pstring = insert_aux_to_premise(pstring, translation)
logging.info(f'Worker PID={pid}: string=|{string}| lm_out=|{lm_out}|')
logging.info(f'Worker PID={pid}: Solving: "{candidate_pstring}"')
p_new = pr.Problem.from_txt(candidate_pstring)
# This is the new proof state graph representation:
g_new, _ = gh.Graph.build_problem(p_new, DEFINITIONS)
try:
if run_ddar(g_new, p_new, out_file):
        logging.info(f'Worker PID={pid}: Solved.')
return (i_nd, True, None)
except Exception as e:
logging.info(f'Worker PID={pid}: Error in run_ddar: {e}')
# Add the candidate to the beam queue.
    ret.append([
        # The string for the new node is the old string + the LM output +
        # the special token asking for a new auxiliary point ' x00':
        (g_new, string + ' ' + lm_out + ' x00', candidate_pstring),
        # The score of a node is the sum of scores of all nodes on the path
        # to it; no length normalization is needed because all nodes in the
        # beam have the same path length.
        prev_score + score,
    ])
logging.info(f'Worker PID={pid} beam search node {i_nd}: returning')
return (i_nd, False, ret)
def run_alphageometry(
    # model is a module-level global, loaded per worker by bqsearch_init(),
    # rather than passed in as an argument.
p: pr.Problem,
search_depth: int,
beam_size: int,
out_file: str,
) -> bool:
"""Simplified code to run AlphaGeometry proof search.
We removed all optimizations that are infrastructure-dependent, e.g.
  parallelized model inference on multiple GPUs,
parallelized DD+AR on multiple CPUs,
parallel execution of LM and DD+AR,
shared pool of CPU workers across different problems, etc.
Many other speed optimizations and abstractions are also removed to
better present the core structure of the proof search.
Args:
p: pr.Problem object describing the problem to solve.
search_depth: max proof search depth.
beam_size: beam size of the proof search.
out_file: path to output file if solution is found.
Returns:
boolean of whether this is solved.
"""
# translate the problem to a string of grammar that the LM is trained on.
string = p.setup_str_from_problem(DEFINITIONS)
# special tokens prompting the LM to generate auxiliary points.
string += ' {F1} x00'
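  # For instance (taken from the napoleon2 log excerpt earlier in this file),
  # the resulting LM prompt looks like:
  #   '{S} a : ; b : ; ... ? D e f e d {F1} x00'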
# the graph to represent the proof state.
g, _ = gh.Graph.build_problem(p, DEFINITIONS)
# First we run the symbolic engine DD+AR:
if run_ddar(g, p, out_file):
return True
  # When pickling the graph for some problems, the default recursion limit
  # (1000) is not enough and raises 'maximum recursion depth exceeded while
  # pickling an object'.
  sys.setrecursionlimit(10000)
# beam search for the proof
# each node in the search tree is a 3-tuple:
# (<graph representation of proof state>,
# <string for LM to decode from>,
# <original problem string>)
beam_queue = BeamQueue(max_size=beam_size)
  # The beam search tree starts with a single node (a 3-tuple):
beam_queue.add(
node=(g, string, p.txt()), val=0.0 # value of the root node is simply 0.
)
pool = None
  if _N_WORKERS.value == 1:
    bqsearch_init()
  else:
    pool = multiprocessing.Pool(_N_WORKERS.value, bqsearch_init)
for depth in range(search_depth):
logging.info(
'Depth %s. There are %i nodes to expand:', depth, len(beam_queue)
)
for _, (_, string, _) in beam_queue:
logging.info(string)
new_queue = BeamQueue(max_size=beam_size) # to replace beam_queue.
    if _N_WORKERS.value == 1:
for i, srch_inputs in enumerate(beam_queue):
_, solved, res = bqsearch(i, srch_inputs, out_file)
if solved:
return True
for node, val in res:
# Add the candidate to the beam queue.
new_queue.add(node, val)
          # Note that the queue maintains at most beam_size nodes, so this
          # new node may be dropped depending on its value.
else:
      jobs = [
          pool.apply_async(bqsearch, (i, srch_inputs, out_file))
          for i, srch_inputs in enumerate(beam_queue)
      ]
n_done = 0
while n_done < len(beam_queue):
for i, jobres in enumerate(jobs):
if jobres and jobres.ready():
n_done += 1
jobs[i] = None
_, solved, res = jobres.get()
if solved:
# Clean up resources
pool.terminate()
pool.join()
return True
for node, val in res:
# Add the candidate to the beam queue.
new_queue.add(node, val)
              # Note that the queue maintains at most beam_size nodes, so this
              # new node may be dropped depending on its value.
time.sleep(1) # Adjust wait time as needed
# replace the old queue with new queue before the new proof search depth.
beam_queue = new_queue
# Clean up resources
if pool:
pool.terminate()
pool.join()
return False
def main(_):
global DEFINITIONS
global RULES
# definitions of terms used in our domain-specific language.
DEFINITIONS = pr.Definition.from_txt_file(_DEFS_FILE.value, to_dict=True)
# load inference rules used in DD.
RULES = pr.Theorem.from_txt_file(_RULES_FILE.value, to_dict=True)
# when using the language model,
# point names will be renamed to alphabetical a, b, c, d, e, ...
# instead of staying with their original names,
# in order to match the synthetic training data generation.
need_rename = _MODE.value != 'ddar'
  # load problems from the problems_file.
problems = pr.Problem.from_txt_file(
_PROBLEMS_FILE.value, to_dict=True, translate=need_rename
)
if _PROBLEM_NAME.value not in problems:
raise ValueError(
f'Problem name `{_PROBLEM_NAME.value}` '
+ f'not found in `{_PROBLEMS_FILE.value}`'
)
this_problem = problems[_PROBLEM_NAME.value]
if _MODE.value == 'ddar':
g, _ = gh.Graph.build_problem(this_problem, DEFINITIONS)
run_ddar(g, this_problem, _OUT_FILE.value)
elif _MODE.value == 'alphageometry':
    # The LM is loaded per worker by bqsearch_init() rather than created here.
    run_alphageometry(
this_problem,
_SEARCH_DEPTH.value,
_BEAM_SIZE.value,
_OUT_FILE.value,
)
else:
raise ValueError(f'Unknown FLAGS.mode: {_MODE.value}')
if __name__ == '__main__':
app.run(main)