File size: 2,097 Bytes
828458d ff5f71b 78588de 828458d 8c4485d ff5f71b 828458d 575d1cf 828458d 6d7ff83 575d1cf b1e6575 0658988 b1e6575 6d7ff83 828458d c4c7f48 828458d 575d1cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from pathlib import Path
import json
from tqdm import tqdm
from transformers import AutoModelForCausalLM
def get_num_parameters(model_name: str) -> int:
    """Return the parameter count of a causal-LM checkpoint.

    The full model is instantiated with ``AutoModelForCausalLM.from_pretrained``
    (which may download weights if they are not cached locally) and its
    ``num_parameters()`` result is returned.

    :param model_name: model identifier or path accepted by ``from_pretrained``.
    :return: total number of parameters reported by the loaded model.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model.num_parameters()
# Values that explicitly switch a boolean model argument off.
_FALSE_STRINGS = frozenset({"false", "0", "no", ""})


def _parse_model_args(raw_model_args) -> dict:
    """Parse lm-eval ``model_args`` into a ``{key: value}`` dict of strings.

    A string of the form ``"key=value,key=value"`` is split on commas; only the
    first ``=`` in each item separates key from value (so values may contain
    ``=``), surrounding whitespace is stripped, and items without ``=``
    (e.g. from a trailing comma or an empty string) are ignored. A dict is
    accepted as well, with keys and values converted to strings.
    """
    if isinstance(raw_model_args, dict):
        return {str(key): str(value) for key, value in raw_model_args.items()}

    parsed = {}
    for item in str(raw_model_args).split(","):
        key, sep, value = item.partition("=")
        if sep:
            parsed[key.strip()] = value.strip()
    return parsed


def _flag_enabled(model_args: dict, key: str) -> bool:
    """Return True if ``key`` is present and not explicitly set to a false-like value."""
    if key not in model_args:
        return False
    # Presence alone used to count as "enabled"; an explicit "False"/"0" must not.
    return model_args[key].strip().lower() not in _FALSE_STRINGS


def main():
    """Create or refresh ``evals/models.json`` next to this script.

    Every ``*.json`` file found recursively under ``evals/`` (except files whose
    stem is ``models``) is read. Its ``config.model_args`` is parsed, and an
    entry keyed by the lower-cased part of the file stem after the second
    underscore is written to the overview. Files that do not match this layout
    (short stem, missing ``config``/``model_args``/``pretrained``) are skipped.

    ``num_parameters``, ``model_type`` and ``dutch_coverage`` are carried over
    from an existing overview entry when present. This means the expensive
    ``get_num_parameters`` call only runs for new entries.
    """
    evals_dir = Path(__file__).parent.joinpath("evals")
    pf_overview = evals_dir.joinpath("models.json")
    results = json.loads(pf_overview.read_text(encoding="utf-8")) if pf_overview.exists() else {}

    for pfin in tqdm(list(evals_dir.rglob("*.json")), desc="Generating overview JSON"):
        if pfin.stem == "models":
            continue

        # The short name is everything after the second underscore of the stem;
        # skip files that do not follow that naming scheme instead of crashing.
        stem_parts = pfin.stem.split("_", 2)
        if len(stem_parts) < 3:
            continue
        short_name = stem_parts[2].lower()

        data = json.loads(pfin.read_text(encoding="utf-8"))
        config = data.get("config") if isinstance(data, dict) else None
        if not isinstance(config, dict) or "model_args" not in config:
            continue

        model_args = _parse_model_args(config["model_args"])
        if "pretrained" not in model_args:
            continue

        if _flag_enabled(model_args, "load_in_8bit"):
            quantization = "8-bit"
        elif _flag_enabled(model_args, "load_in_4bit"):
            quantization = "4-bit"
        else:
            quantization = None

        previous = results.get(short_name, {})
        results[short_name] = {
            "model_name": model_args["pretrained"],
            "compute_dtype": model_args.get("dtype", None),
            "quantization": quantization,
            # Only load the model to count parameters when no cached value exists.
            "num_parameters": previous["num_parameters"]
            if "num_parameters" in previous
            else get_num_parameters(model_args["pretrained"]),
            "model_type": previous.get("model_type", "not-given"),
            "dutch_coverage": previous.get("dutch_coverage", "not-given"),
        }

    pf_overview.write_text(json.dumps(results, indent=4, sort_keys=True), encoding="utf-8")
# Generate the overview only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|