|
import os
import shutil
import sys
from collections import defaultdict

import torch
|
|
|
|
|
def parse_proxy(fname, scale):
    """Collect per-layer proxy errors from a log file.

    Every line containing ``'proxy error'`` is parsed as follows:

    - the (right-stripped) line is cut at the first occurrence of ``'layer'``;
    - the 2nd and 3rd space-separated words become the layer name
      (e.g. ``'0 gate'``);
    - everything after the first ``':'`` is read as the float error.

    Lines that mention ``'proxy error'`` but contain no ``'layer'`` token are
    skipped.

    Args:
        fname: Path of the log file to read.
        scale: Label stored as the inner key for every error found.

    Returns:
        dict mapping layer name -> ``{scale: proxy_error}``. If a layer is
        reported more than once, the last occurrence wins.

    Raises:
        ValueError: if a matching line's text after ``':'`` is not a float.
    """
    layer_dict = {}
    # `with` guarantees the file handle is closed once parsing finishes.
    with open(fname, 'r') as f:
        for line in f:
            if 'proxy error' not in line:
                continue
            line = line.rstrip()
            start = line.find('layer')
            if start == -1:
                # Without this guard, find() returning -1 would make the slice
                # keep only the last character, recorded under an empty name.
                continue
            line = line[start:]
            proxy_error = float(line[line.find(':') + 1:])
            layer_name = ' '.join(line.split(' ')[1:3])
            layer_dict[layer_name] = {scale: proxy_error}
    return layer_dict
|
|
|
# Merge all logs into one table: layer name -> {suffix: proxy error}.
# Starting from {} (instead of None) keeps `total` a dict even if `files` is empty.
total = {}

# Suffixes of the e8p_s<suffix>.log runs to compare; each suffix is also the
# label parse_proxy stores the errors under.
files = ['075', '080', '085', '090', '095', '100', '103', '105']

for suffix in files:
    res = parse_proxy(f'/work/albert/two_bit_quant/slurm_out/e8p_s{suffix}.log', suffix)
    for layer_name, errors in res.items():
        # setdefault: a layer absent from an earlier log no longer raises KeyError.
        total.setdefault(layer_name, {}).update(errors)
|
|
|
# For every layer, choose the label with the strictly lowest proxy error
# (the first one seen wins ties) and count how many layers each label wins.
hist = defaultdict(int)
best_layer = {}
for layer, errors_by_scale in total.items():
    lowest = float('inf')
    winner = None
    for scale, err in errors_by_scale.items():
        if err < lowest:
            lowest, winner = err, scale
    best_layer[layer] = winner
    hist[winner] += 1
|
|
|
print(hist)

# NOTE(review): deliberate early stop. While this exit stays, everything
# below (assembling the e8p_best_scale checkpoint directory) is unreachable.
# sys.exit is used instead of the site-module `exit()` helper, which is meant
# for the interactive shell and is not defined under `python -S`.
sys.exit()
|
|
|
# Assemble a checkpoint directory that takes each layer's file from the
# e8p_s<label> run chosen for it in `best_layer`.
ckpt_path = '/work/albert/two_bit_quant/checkpoints'
out_path = os.path.join(ckpt_path, 'e8p_best_scale')

# Start from an empty output directory. shutil/os calls replace the former
# unquoted `rm -rf` / `mkdir` shell strings and raise on failure instead of
# being silently ignored.
if os.path.exists(out_path):
    shutil.rmtree(out_path)
os.makedirs(out_path)

# config.pt is taken from the first run in `files`.
shutil.copy(os.path.join(ckpt_path, f'e8p_s{files[0]}', 'config.pt'), out_path)

for layer, label in best_layer.items():
    # Layer names like '0 gate' map to files like '0_gate.pt'.
    src = os.path.join(ckpt_path, f'e8p_s{label}', '{}.pt'.format(layer.replace(' ', '_')))
    shutil.copy(src, out_path)
|
|
|
|