"""Run a generated ``execute_command`` program under pysnooper and clean up its trace."""

import ast
import importlib
import io
import os
import re
import string
import time
from functools import partial
from typing import List

import pysnooper

# Signature shown to the reader of the trace.
FUNCTION_HEAD = "def execute_command({input_type}) -> {output_type}:"
# Signature actually executed: helper callables are injected as extra parameters.
EXEC_FUNCTION_HEAD = 'def execute_command({input_type}, possible_answers, query, ImagePatch, VideoSegment,' \
                     ' llm_query, bool_to_yesno, distance, best_image_match):'


class CompileTimeError:
    """Marker returned when the generated program cannot be parsed."""
    pass


class ProgramRuntimeError:
    """Marker returned when the generated program fails while executing."""
    pass


def process_trace(text, function_head, execution_function_head):
    """Reduce a raw trace to the part that belongs to the generated program.

    Args:
        text: Raw trace text, one event per line.
        function_head: Signature to show in the cleaned trace.
        execution_function_head: Signature that was actually executed; its last
            occurrence in ``text`` marks where the cleaned trace begins.

    Returns:
        The cleaned trace joined with newlines, or ``''`` when
        ``execution_function_head`` does not occur in ``text``.
    """

    def remove_indent(lines: List[str]) -> List[str]:
        """Strip the first line's leading-space count from every indented line."""
        first = lines[0]
        n_space = len(first) - len(first.lstrip(' '))
        # startswith() instead of line[0] so an empty line cannot raise IndexError.
        return [line[n_space:] if line.startswith(' ') else line for line in lines]

    def remove_pre_context(lines: List[str]) -> List[str]:
        """Keep everything from the last line containing the executed signature."""
        # Scan backwards: earlier lines may also contain the signature text.
        for i in range(len(lines) - 1, -1, -1):
            line = lines[i]
            if execution_function_head not in line:
                continue
            # TODO: further double-check that this is a "call" event line?
            content = [line.replace(execution_function_head, function_head)] + lines[i + 1:]
            if line.startswith(' '):
                return remove_indent(content)
            return content
        return []

    def remove_post_context(lines: List[str]) -> List[str]:
        """Cut the trace where it returns to this file or reports elapsed time."""
        for i, line in enumerate(lines):
            back_in_this_file = line.startswith("Source path:") and line.endswith(__file__)
            if back_in_this_file or line.startswith("Elapsed time"):
                return lines[:i]
        return lines

    def remove_timestamp(lines: List[str]) -> List[str]:
        """Drop the 16-character prefix from lines that start with a digit."""
        # Assumes those 16 characters are the tracer's timestamp column -- TODO confirm
        # against the pysnooper version in use.
        return [line[16:] if line and line[0] in string.digits else line for line in lines]

    def remove_tensor(line: str) -> str:
        """Collapse the contents of a triple-nested ``tensor([[[...]]])`` repr."""
        return re.sub(r"tensor\(\[\[\[.*?\]\]\]\)", "tensor([[[...]]])", line)

    lines = text.splitlines()
    lines = remove_pre_context(lines)
    lines = remove_post_context(lines)
    lines = remove_timestamp(lines)
    lines = [remove_tensor(line) for line in lines]

    return '\n'.join(lines)


# Number of programs run so far; used to give each generated module a unique name.
cnt = 0


def run_program_with_trace(code, image, input_type_, output_type_):
    """Execute generated code as ``execute_command`` and return its cleaned trace.

    The code is wrapped in the executable signature, written to ``x<N>.py`` in
    the current working directory, imported, and run under ``pysnooper``.
    The temporary file is removed afterwards.

    Args:
        code: Program text; either a bare function body or a full definition
            starting with the ``FUNCTION_HEAD`` signature.
        image: Value passed as the first argument of ``execute_command``.
        input_type_: Text substituted for ``{input_type}`` in both signatures.
        output_type_: Text substituted for ``{output_type}`` in ``FUNCTION_HEAD``.

    Returns:
        ``(result, error, trace)``. On a parse failure this is
        ``(None, CompileTimeError(), None)``. Otherwise ``error`` is ``None`` or
        a ``ProgramRuntimeError`` instance and ``trace`` is the processed trace.
    """
    # Imported here rather than at module level, as in the original code.
    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno

    function_head = FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)
    execution_function_head = EXEC_FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)

    code = str(code)
    if code.startswith("\ndef"):
        code = code[1:]  # TODO: just a temporary fix
    if code.startswith('def'):
        if code.startswith(function_head):
            code = code.replace(function_head, '')
        else:
            print("--- Code with invalid format\n")
            print(code)
    code = execution_function_head + code

    try:
        # Round-trip through the AST to validate and normalise the source.
        code = ast.unparse(ast.parse(code))
    except Exception:
        return None, CompileTimeError(), None

    global cnt
    cnt += 1
    name = f'x{cnt}'
    # Python source defaults to UTF-8, so write it that way on every platform.
    with open(f'{name}.py', 'w', encoding='utf-8') as f:
        f.write(code)

    x = None  # stays None if the module cannot be imported
    for _ in range(20):
        # The file was created after interpreter start-up; the import system's
        # cached directory listings must be invalidated for it to be found.
        importlib.invalidate_caches()
        try:
            x = importlib.import_module(name)
        except ModuleNotFoundError as e:
            if e.name != name:
                # A different module is missing; waiting will not help.
                print("Import has error:", e)
                break
            print("Errrr, import error. Wait a bit while.")
            time.sleep(60)  # fallback in case the file is not visible yet
        except Exception as e:
            print("Import has error:", e)
            break
        else:
            break

    queues = [None, None]

    image_patch_partial = partial(ImagePatch, queues=queues)
    video_segment_partial = None
    llm_query_partial = partial(llm_query, queues=queues)

    # NOTE: a SIGALRM-based timeout was tried here and did not work, so the
    # generated program runs without a time limit.
    result = None
    error = None
    try:
        with io.StringIO() as f:
            with pysnooper.snoop(output=f, color=False, depth=2, max_variable_length=1000):
                try:
                    result = x.execute_command(image, None, '', image_patch_partial, video_segment_partial,
                                               llm_query_partial, bool_to_yesno, distance, best_image_match)
                except (Exception, SystemExit):
                    # Generated code may fail in any way, including calling exit();
                    # KeyboardInterrupt is deliberately left to propagate.
                    error = ProgramRuntimeError()
            f.seek(0)
            traced = f.read(100000)  # cap the amount of trace text kept
    finally:
        # Always remove the temporary module file, even if tracing itself failed.
        os.remove(f'{name}.py')

    traced_processed = process_trace(traced, function_head, execution_function_head)

    return result, error, traced_processed