Xueqing Wu
init
e20ef71
import ast
import importlib
import io
import os
import re
import string
import time
from functools import partial
from typing import List
import pysnooper
# Header that a generated program is expected to start with; formatted with
# ``input_type`` and ``output_type`` by run_program_with_trace.
FUNCTION_HEAD = "def execute_command({input_type}) -> {output_type}:"

# Header the program is actually executed with. The helper callables are added
# as parameters and supplied as arguments at call time. There is no
# ``{output_type}`` placeholder, so that keyword is simply ignored by ``format``.
EXEC_FUNCTION_HEAD = (
    'def execute_command({input_type}, possible_answers, query, ImagePatch, VideoSegment,'
    ' llm_query, bool_to_yesno, distance, best_image_match):'
)
class CompileTimeError:
    """Marker value returned (not raised) as the error when generated code fails to parse."""
class ProgramRuntimeError:
    """Marker value returned (not raised) as the error when a generated program raises while running."""
def process_trace(text, function_head, execution_function_head):
    """Reduce a raw pysnooper trace to the trace of the generated program only.

    Steps, applied in order:
      1. Keep everything from the LAST line that contains
         ``execution_function_head`` onwards. In that line, swap the header
         back to ``function_head``. If that line is indented, dedent the kept
         lines by its indentation.
      2. Cut at the first line that is either a ``Source path:`` line ending
         with this module's ``__file__`` or an ``Elapsed time`` line.
      3. Drop the first 16 characters (the timestamp) of every line that starts
         with a digit.
      4. Replace ``tensor([[[...]]])`` reprs with a short placeholder.

    Args:
        text: Raw trace text.
        function_head: Header to show in the cleaned trace.
        execution_function_head: Header the program was actually executed with.

    Returns:
        The cleaned trace joined with newlines, or ``''`` when
        ``execution_function_head`` never appears in ``text``.
    """
    tensor_pattern = re.compile(r"tensor\(\[\[\[.*?\]\]\]\)")

    def remove_indent(lines):
        # Dedent by the header line's indentation. A line with fewer leading
        # spaces loses only its own spaces (never real content), and empty
        # lines are kept instead of raising IndexError.
        header = lines[0]
        n_space = len(header) - len(header.lstrip(' '))
        dedented = []
        for line in lines:
            lead = len(line) - len(line.lstrip(' '))
            dedented.append(line[min(lead, n_space):])
        return dedented

    def remove_pre_context(lines: List[str]):
        # Search backwards so that the last occurrence of the header wins.
        for i in reversed(range(len(lines))):
            line = lines[i]
            if execution_function_head in line:
                content = [line.replace(execution_function_head, function_head)] + lines[i + 1:]
                return remove_indent(content) if line.startswith(' ') else content
        return []

    def remove_post_context(lines):
        # Presumably marks where the trace leaves the generated program: control
        # returning to this file's source, or the closing "Elapsed time" line.
        for i, line in enumerate(lines):
            if line.startswith("Elapsed time") or (
                    line.startswith("Source path:") and line.endswith(__file__)):
                return lines[:i]
        return lines

    def remove_timestamp(lines):
        # Assumes a fixed 16-character timestamp prefix
        # (presumably pysnooper's "HH:MM:SS.ffffff ").
        return [line[16:] if line and line[0] in string.digits else line for line in lines]

    lines = text.splitlines()
    lines = remove_pre_context(lines)
    lines = remove_post_context(lines)
    lines = remove_timestamp(lines)
    lines = [tensor_pattern.sub("tensor([[[...]]])", line) for line in lines]
    return '\n'.join(lines)
# Counter incremented by run_program_with_trace to give each generated module a unique name.
cnt: int = 0
def run_program_with_trace(code, image, input_type_, output_type_):
    """Parse, import and run a generated ``execute_command`` program under pysnooper.

    The program's header (``FUNCTION_HEAD``) is replaced by
    ``EXEC_FUNCTION_HEAD``, so the helper callables are passed in as arguments.
    The normalised source is written to a throw-away module ``x{cnt}.py`` in the
    current working directory, imported, and called with ``image``. The
    pysnooper trace of that call is cleaned by ``process_trace``.
    NOTE(review): the import assumes the current working directory is on
    ``sys.path``.

    Args:
        code: Generated program (converted with ``str``). It is either a full
            function starting with the formatted ``FUNCTION_HEAD`` or just the
            function body.
        image: First positional argument passed to ``execute_command``.
        input_type_: Substituted for ``{input_type}`` in both headers.
        output_type_: Substituted for ``{output_type}`` in ``FUNCTION_HEAD``.

    Returns:
        A ``(result, error, trace)`` tuple:
          * ``(None, CompileTimeError(), None)`` if the code does not parse;
          * ``(None, ProgramRuntimeError(), '')`` if the generated module cannot
            be imported;
          * otherwise ``result`` is the program's return value (``None`` if it
            raised), ``error`` is ``None`` or ``ProgramRuntimeError()``, and
            ``trace`` is the processed trace (at most 100000 raw characters are
            read).
    """
    import sys
    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno

    function_head = FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)
    execution_function_head = EXEC_FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)

    code = str(code)
    if code.startswith("\ndef"):
        code = code[1:]  # TODO: just a temporary fix
    if code.startswith('def'):
        if code.startswith(function_head):
            # Strip only the leading header; str.replace would also delete any
            # later occurrence of the same text inside the program.
            code = code[len(function_head):]
        else:
            print("--- Code with invalid format\n")
            print(code)
    code = execution_function_head + code
    try:
        code = ast.unparse(ast.parse(code))
    except Exception:  # e.g. SyntaxError; no longer swallows KeyboardInterrupt
        return None, CompileTimeError(), None

    global cnt
    cnt += 1
    name = f'x{cnt}'
    path = f'{name}.py'
    with open(path, 'w') as src_file:
        src_file.write(code)

    try:
        module = None
        for _ in range(20):
            # The import system caches directory listings. A module file created
            # after start-up may not be found until those caches are invalidated.
            importlib.invalidate_caches()
            try:
                module = importlib.import_module(name)
            except ModuleNotFoundError:
                print("Errrr, import error. Wait a bit while.")
                time.sleep(60)
            except Exception as e:
                print("Import has error:", e)
                break
            else:
                break
        if module is None:
            # Previously this surfaced only as a NameError swallowed during execution.
            return None, ProgramRuntimeError(), ''

        queues = [None, None]
        image_patch_partial = partial(ImagePatch, queues=queues)
        video_segment_partial = None
        llm_query_partial = partial(llm_query, queues=queues)
        # NOTE: no timeout is enforced. A signal.alarm-based timeout was tried and did not work.
        with io.StringIO() as buffer:
            with pysnooper.snoop(output=buffer, color=False, depth=2, max_variable_length=1000):
                result = None
                error = None
                try:
                    result = module.execute_command(image, None, '', image_patch_partial, video_segment_partial,
                                                    llm_query_partial, bool_to_yesno, distance, best_image_match)
                except KeyboardInterrupt:
                    raise  # let the user abort the whole run
                except BaseException:  # generated code may even call exit()
                    error = ProgramRuntimeError()
            buffer.seek(0)
            traced = buffer.read(100000)
    finally:
        # Each module name is used once; drop it so sys.modules does not grow
        # forever, and always delete the temporary source file.
        sys.modules.pop(name, None)
        try:
            os.remove(path)
        except OSError:
            pass

    traced_processed = process_trace(traced, function_head, execution_function_head)
    return result, error, traced_processed