# machineai-compiler-optimizer / compiler_env.py
# Uploaded by sosonsong with huggingface_hub (commit 94047c1).
import gymnasium as gym
import numpy as np
import subprocess
import time
import os
from ir_feature_extractor import extract_features
class LoopUnrollEnv(gym.Env):
    """Gymnasium environment that chooses an LLVM ``opt`` pass pipeline for a C file.

    Each episode samples one source file and compiles it to LLVM bitcode with
    ``clang -O1 -emit-llvm``. The observation is ``extract_features`` of that
    bitcode. An action selects one of the pass pipelines in
    ``_PASS_PIPELINES``. The resulting bitcode is linked into an executable,
    and its median wall-clock run time is compared with a baseline measured
    once in ``__init__``.

    When ``arch == "arm64"`` the code is cross-compiled for
    ``aarch64-linux-gnu`` (statically linked) and run through
    ``qemu-aarch64-static``. Any other ``arch`` value builds and runs natively.

    Args:
        source_files: C source paths to sample from. Defaults to
            ``["test_loop.c"]``.
        repeat_runs: Number of timed runs per ``step``; the median is used.
        arch: ``"arm64"`` for cross-compile + qemu; anything else is native.
        clang_bin: clang executable. Defaults to ``"clang"``.
        opt_bin: LLVM ``opt`` executable. Defaults to ``"opt"``.
        run_timeout: Optional per-run time limit (seconds) for the timed
            executable. A run that exceeds it is discarded like a failed run.
            ``None`` (the default) means no limit.
    """

    # Action index -> new-pass-manager pipeline string ("" = no extra passes).
    _PASS_PIPELINES = {
        0: "",
        1: "loop-vectorize",
        2: "inline,loop-vectorize",
        3: "loop-unroll,loop-vectorize",
        4: "inline,loop-unroll,loop-vectorize",
        5: "loop-unroll",
    }

    def __init__(
        self,
        source_files=None,
        repeat_runs=5,
        arch: str = "x86",
        clang_bin: str | None = None,
        opt_bin: str | None = None,
        run_timeout: float | None = None,
    ):
        super().__init__()
        self.arch = arch
        self.clang_bin = clang_bin or "clang"
        self.opt_bin = opt_bin or "opt"
        self.source_files = source_files or ["test_loop.c"]
        self.repeat_runs = repeat_runs
        self.run_timeout = run_timeout
        self.action_space = gym.spaces.Discrete(len(self._PASS_PIPELINES))
        # NOTE(review): assumes extract_features returns 7 values already
        # normalized to [0, 1]; confirm against ir_feature_extractor.
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(7,), dtype=np.float32
        )
        # Source path -> baseline median run time in seconds. Only files whose
        # baseline compile and link succeeded get an entry.
        self.fixed_baselines = {}
        self._precompute_baselines()

    @staticmethod
    def _derive_path(path, suffix):
        """Return *path* with only its final extension replaced by *suffix*.

        ``str.replace`` would also rewrite matches inside directory names.
        For an input lacking the expected extension it could also return the
        input path itself, making a tool overwrite its own input.
        """
        return os.path.splitext(path)[0] + suffix

    def _run_subprocess(self, cmd, **kwargs):
        """Run *cmd* (argument list, no shell) and capture stdout/stderr."""
        return subprocess.run(cmd, capture_output=True, **kwargs)

    def _precompute_baselines(self):
        """Measure each source file's unoptimized run time once (median of 11).

        Files that fail to compile or link are skipped silently. For those,
        ``reset`` falls back to a default baseline.
        """
        print("베이스라인 사전 측정 중...")
        for src in self.source_files:
            bc = self._compile_to_bc(src)
            if not bc:
                continue
            exe = self._bc_to_exe(bc)
            if not exe:
                continue
            t = self._measure_time_robust(exe, n=11)
            self.fixed_baselines[src] = t
            print(f" {os.path.basename(src)}: {t*1000:.1f}ms")
        print("베이스라인 측정 완료")

    def _measure_time_robust(self, exe, n=11):
        """Return the median wall-clock time in seconds of *n* runs of *exe*.

        Runs that exit non-zero, or exceed ``run_timeout``, are discarded.
        If no run succeeds, the sentinel ``999.0`` is returned. Timings
        include process start-up, and qemu start-up on arm64.
        """
        # The command does not change between runs, so build it once.
        run_cmd = ["qemu-aarch64-static", exe] if self.arch == "arm64" else [exe]
        times = []
        for _ in range(n):
            t0 = time.perf_counter()
            try:
                r = self._run_subprocess(run_cmd, timeout=self.run_timeout)
            except subprocess.TimeoutExpired:
                # subprocess.run kills the child on timeout; count as a failed run.
                continue
            elapsed = time.perf_counter() - t0
            if r.returncode == 0:
                times.append(elapsed)
        return float(np.median(times)) if times else 999.0

    def _measure_time(self, exe):
        """Time *exe* with ``repeat_runs`` repetitions (see ``_measure_time_robust``)."""
        return self._measure_time_robust(exe, n=self.repeat_runs)

    def _compile_to_bc(self, src):
        """Compile C source *src* to LLVM bitcode next to it at ``-O1``.

        Returns the ``.bc`` path, or ``None`` if clang fails.
        """
        bc = self._derive_path(src, ".bc")
        target_flags = ["-target", "aarch64-linux-gnu"] if self.arch == "arm64" else []
        cmd = [
            self.clang_bin,
            "-O1",
            "-emit-llvm",
            "-c",
            *target_flags,
            src,
            "-o",
            bc,
        ]
        r = self._run_subprocess(cmd)
        return bc if r.returncode == 0 else None

    def _apply_action(self, bc_file, action):
        """Run ``opt`` with the pipeline selected by *action* on *bc_file*.

        The output is written to ``<stem>_act<action>.bc``. Returns that path.
        Returns *bc_file* unchanged for the empty pipeline (action 0) or when
        ``opt`` fails, so a failed pass is scored as "no optimization".
        Raises ``KeyError`` for an action outside ``_PASS_PIPELINES``.
        """
        out = self._derive_path(bc_file, f"_act{action}.bc")
        p = self._PASS_PIPELINES[int(action)]
        if not p:
            return bc_file
        cmd = [self.opt_bin, f"--passes={p}", bc_file, "-o", out]
        r = self._run_subprocess(cmd)
        return out if r.returncode == 0 else bc_file

    def _measure_code_size(self, bc_file):
        """For ARM64: object-file size (bytes) as a performance proxy instead of qemu.

        Always compiles for ``aarch64-linux-gnu`` at ``-O1``, regardless of
        ``self.arch``. Returns ``999999`` if compilation fails.
        """
        obj = self._derive_path(bc_file, ".o")
        cmd = [
            self.clang_bin,
            "-target", "aarch64-linux-gnu",
            "-O1", "-c",
            bc_file, "-o", obj,
        ]
        r = self._run_subprocess(cmd)
        if r.returncode != 0:
            return 999999
        return os.path.getsize(obj)

    def _bc_to_exe(self, bc_file):
        """Link *bc_file* (plus libm) into an executable at ``<stem>_exe``.

        The path is made absolute. On arm64 the executable is statically
        linked. Returns the path, or ``None`` if linking fails.
        """
        exe = os.path.abspath(self._derive_path(bc_file, "_exe"))
        target_flags = ["-target", "aarch64-linux-gnu", "-static"] if self.arch == "arm64" else []
        cmd = [
            self.clang_bin,
            "-O1",
            *target_flags,
            bc_file,
            "-o",
            exe,
            "-lm",
        ]
        r = self._run_subprocess(cmd)
        return exe if r.returncode == 0 else None

    def reset(self, seed=None, options=None):
        """Start an episode on a randomly chosen source file.

        Returns ``(observation, {})``. Raises ``RuntimeError`` if the chosen
        file cannot be compiled to bitcode; otherwise ``step`` would fail
        later on a ``None`` path.
        """
        super().reset(seed=seed)
        # Use the env's own RNG (seeded by super().reset) so that
        # reset(seed=...) makes file selection reproducible.
        idx = int(self.np_random.integers(len(self.source_files)))
        self.current_file = self.source_files[idx]
        self.bc_file = self._compile_to_bc(self.current_file)
        if self.bc_file is None:
            raise RuntimeError(
                f"failed to compile {self.current_file!r} to LLVM bitcode"
            )
        # Fallback when no baseline was measured for this file.
        # NOTE(review): the arm64 fallback of 1920 (seconds?) is unexplained;
        # confirm the intended value.
        self.base_time = self.fixed_baselines.get(
            self.current_file, 1920 if self.arch == "arm64" else 1.0
        )
        obs = np.array(extract_features(self.bc_file), dtype=np.float32)
        return obs, {}

    def step(self, action):
        """Apply pipeline *action*, time the result, and score it against the baseline.

        Rewards (improvement = relative time reduction vs. baseline):
        above +1% gives ``20 * improvement + 1``; below -1% gives ``-2``;
        otherwise ``-0.1``. The episode terminates on more than 70% speed-up
        or more than 50% slowdown. Returns Gymnasium's
        ``(obs, reward, terminated, truncated, info)``; ``truncated`` is
        always ``False``.
        """
        action = int(action)
        opt_bc = self._apply_action(self.bc_file, action)
        exe = self._bc_to_exe(opt_bc)
        # A link failure is scored as a 2x slowdown (improvement of about -100%).
        new_time = self._measure_time(exe) if exe else self.base_time * 2
        improvement = (self.base_time - new_time) / (self.base_time + 1e-9)
        if improvement > 0.01:
            reward = improvement * 20.0 + 1.0
        elif improvement < -0.01:
            reward = -2.0
        else:
            reward = -0.1
        done = improvement > 0.70 or improvement < -0.50
        info = {
            "speedup_pct": improvement * 100,
            "baseline_ms": self.base_time * 1000,
            "optimized_ms": new_time * 1000,
            "flags": action,
            "arch": self.arch,
            "clang_bin": self.clang_bin,
            "opt_bin": self.opt_bin,
        }
        # Features come from the unoptimized bitcode, so the observation does
        # not reflect the chosen action.
        obs = np.array(extract_features(self.bc_file), dtype=np.float32)
        return obs, reward, done, False, info