| | import json |
| | import os |
| | import ast |
| | import subprocess |
| | from random import randint |
| | from tqdm import tqdm |
| | import signal |
| | from multiprocessing import Pool, Lock, Value |
| | from dataloaders.ProblemState import ProblemState |
| | from dataloaders.TB_eval.utils import code_call_exec_success_allclose, code_kernel_profiling |
| | import re |
| | from tb_eval.evaluators.interface import get_evaluators |
| |
|
class TritonBench:
    """Loader and evaluation front-end for a set of TritonBench problems.

    Problems are held in ``self.problem_states`` as ``ProblemState`` objects.
    They are built either from the benchmark files (instruction JSON, statistics
    JSON and the folder of reference ``.py`` files) or, when ``result_path`` is
    given, from a JSON-lines file previously produced by :meth:`write_file`.
    """

    # Reference files are split on this separator line; the text after the
    # last occurrence is used as the problem's test code.
    _TEST_SEPARATOR = "#" * 146

    # One reference kernel is recognised by a fixed fragment of its source
    # instead of a substring match against the statistics file. The fragment
    # is compared with ``golden[50:220]``, so it must be exactly 170 characters
    # long (indentation and the double space before ``#`` included).
    _CONTEXT_ATTN_FILE = "context_attn_nopad.py"
    _CONTEXT_ATTN_PROBE = (
        'as tl\n\nif triton.__version__ >= "2.1.0":\n'
        '    @triton.jit\n'
        '    def _fwd_kernel(\n'
        '        Q, K, V, sm_scale, B_Start_Loc, B_Seqlen,  # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录'
    )

    def __init__(self,
                 statis_path,
                 py_folder,
                 instruction_path,
                 golden_metrics,
                 py_interpreter,
                 perf_ref_folder,
                 perf_G_path,
                 result_path=None
                 ):
        """Store the benchmark locations and load the problem states.

        Args:
            statis_path: JSON file holding a list of records with ``"file"``
                and ``"output"`` keys.
            py_folder: Folder containing the reference ``.py`` files.
            instruction_path: JSON file holding a list of records with
                ``"instruction"`` and ``"output"`` keys.
            golden_metrics: Stored as ``self.golden_metrics_folder``.
            py_interpreter: Stored as-is.
            perf_ref_folder: Stored as-is.
            perf_G_path: Stored as-is.
            result_path: Optional JSON-lines file to load problem states from
                instead of rebuilding them from the benchmark files.
        """
        self.statis_path = statis_path
        self.py_folder = py_folder
        self.instruction_path = instruction_path
        self.golden_metrics_folder = golden_metrics
        self.py_interpreter = py_interpreter
        self.perf_ref_folder = perf_ref_folder
        self.perf_G_path = perf_G_path
        self.result_path = result_path

        self.problem_states = self.load_ps(result_path)
        self.evaluator = get_evaluators['tbg']()

    def load_ps(self, path):
        """Return the list of problem states.

        Args:
            path: ``None`` to build the states from the benchmark files, or the
                path of a JSON-lines file with one saved problem per line.

        Returns:
            A list of ``ProblemState`` objects.

        Raises:
            ValueError: If an instruction cannot be matched to a reference file.
            AssertionError: If a matched reference file is missing or contains
                no ``def test_`` after the separator line.
        """
        if path is None:
            return self._load_from_benchmark()
        return self._load_from_results(path)

    @classmethod
    def _resolve_source_file(cls, golden, statis_data):
        """Return the reference filename for ``golden``, or ``None`` if unknown.

        The first entry of ``statis_data`` whose ``"output"`` contains
        ``golden`` is removed from the list so that it cannot be matched again.
        """
        for idx, item in enumerate(statis_data):
            if golden in item["output"]:
                filename = item["file"]
                del statis_data[idx]
                return filename
        if golden[50:220] == cls._CONTEXT_ATTN_PROBE:
            return cls._CONTEXT_ATTN_FILE
        return None

    def _load_from_benchmark(self):
        """Build problem states from the instruction, statistics and .py files."""
        with open(self.instruction_path, "r", encoding='utf-8') as fh:
            instructions = json.load(fh)
        with open(self.statis_path, 'r', encoding='utf-8') as fh:
            statis_data = json.load(fh)

        problem_states = []
        for entry in instructions:
            instruction = entry["instruction"]
            label = entry["output"]

            # Strip end-of-text markers before comparing with the references.
            golden = label.replace("<|im_end|>", "").replace("<|EOT|>", "")
            filename = self._resolve_source_file(golden, statis_data)
            if filename is None:
                # Fail loudly: silently carrying over the filename of the
                # previous problem would pair this instruction with the wrong
                # test code.
                raise ValueError(
                    "no reference file matches the instruction starting with "
                    f"{instruction[:80]!r}"
                )

            source_path = os.path.join(self.py_folder, filename)
            assert os.path.exists(source_path), f"{filename} not exist!"
            with open(source_path, "r", encoding="utf-8") as fh:
                test_code = fh.read().split(self._TEST_SEPARATOR)[-1]
            assert "def test_" in test_code, ""

            problem_states.append(
                ProblemState(instruction=instruction,
                             label=label,
                             test_code=test_code,
                             filename=filename,
                             )
            )
        return problem_states

    def _load_from_results(self, path):
        """Rebuild problem states from a JSON-lines file of saved problems."""
        problem_states = []
        with open(path, 'r', encoding='utf-8') as fh:
            for raw_line in fh:
                if not raw_line.strip():
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                content = json.loads(raw_line)
                problem_state = ProblemState(instruction=content["instruction"],
                                             label=content["label"],
                                             filename=content["filename"],
                                             )
                if "test_code" in content:
                    problem_state.test_code = content["test_code"]
                if "predict" in content:
                    problem_state.solution = content["predict"]
                problem_states.append(problem_state)
        return problem_states

    def __len__(self):
        """Return the number of loaded problem states."""
        return len(self.problem_states)

    def write_file(self, file_path):
        """Write every problem state to ``file_path`` as one JSON object per line.

        Missing parent directories are created. ``"test_code"`` and
        ``"predict"`` are only written when the corresponding attribute is
        truthy.
        """
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            # A bare filename has no directory part; os.makedirs("") would raise.
            os.makedirs(parent_dir, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            for ps in self.problem_states:
                output = {
                    "instruction": ps.instruction,
                    "label": ps.label,
                    "filename": ps.filename,
                }
                if ps.test_code:
                    output["test_code"] = ps.test_code
                if ps.solution:
                    output["predict"] = ps.solution
                    # NOTE(review): speedup is emitted only together with a
                    # prediction - confirm against downstream consumers.
                    output["speedup"] = ps.speedup

                f.write(json.dumps(output) + "\n")

    def test_opt_correctness(self, code, filename, tmp_dir, exe_dir="pass_exe", gpu_id=0):
        """Evaluate ``code`` with the configured evaluator.

        The call is forwarded to ``self.evaluator`` with ``atol=1e-3``,
        ``rtol=1e-3`` and no custom tests path. ``gpu_id`` is accepted for
        interface compatibility but is not forwarded to the evaluator.

        Returns:
            The evaluator's 5-tuple
            ``(pass_call, pass_exe, speedup, call_stdout, call_stderr)``.
        """
        pass_call, pass_exe, speedup, call_stdout, call_stderr = self.evaluator(
            code, tmp_dir, exe_dir, filename, atol=1e-3, rtol=1e-3, custom_tests_path=None
        )

        return pass_call, pass_exe, speedup, call_stdout, call_stderr

    def test_kernel_profiling(self, code, filename, tmp_dir, save_scripts=True, exe_dir="pass_exe", target_gpu=None, timeout=20*60):
        """Profile ``code`` through ``code_kernel_profiling``.

        ``exe_dir`` is created if missing; ``save_scripts`` is currently unused.

        Returns:
            ``(pass_profiler, stdout_profile, stderr_profile, stdout_analyze)``
            where ``pass_profiler`` is ``True`` when the text ``"True"`` occurs
            in the string form of the reported profile status.
        """
        os.makedirs(exe_dir, exist_ok=True)
        profile_status, stdout_profile, stderr_profile, stdout_analyze = code_kernel_profiling(
            code=code,
            fname=filename,
            temp_root=tmp_dir,
            py_folder=self.py_folder,
            target_gpu=target_gpu,
            timeout=timeout,
        )
        pass_profiler = "True" in str(profile_status)

        return pass_profiler, stdout_profile, stderr_profile, stdout_analyze
| | |
| | |
| |
|