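"""Re-shard a Hugging Face causal LM checkpoint.

Loads the model given by --model_name_or_path and saves it back to the same
directory with max_shard_size="4.7GB", so that no single weight shard exceeds
roughly 4.7 GB.

Example invocation (the script filename and path are placeholders):

    python reshard_checkpoint.py --model_name_or_path /path/to/checkpoint
"""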
import argparse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="/scratch/project_465000144/dasamuel/normistral/normistral-11b-masked-post-hf-60000",
        help="Path to the pre-trained model",
    )
    args = parser.parse_args()

    return args


def load_model(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=".",
        token="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",
    )
    model = (
        AutoModelForCausalLM.from_pretrained(
            model_path,
            cache_dir=".",
            token="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",
            torch_dtype=torch.bfloat16,
        )
        .cuda()
        .eval()
    )

    # All token ids whose decoded text contains a newline; these can serve as
    # end-of-sequence ids for line-delimited generation.
    eos_token_ids = [
        token_id
        for token_id in range(tokenizer.vocab_size)
        if "\n" in tokenizer.decode([token_id])
    ]

    # Read the maximum context length from whichever attribute this
    # architecture's config exposes, falling back to 4096.
    if hasattr(model.config, "n_positions"):
        max_length = model.config.n_positions
    elif hasattr(model.config, "max_position_embeddings"):
        max_length = model.config.max_position_embeddings
    elif hasattr(model.config, "max_length"):
        max_length = model.config.max_length
    elif hasattr(model.config, "n_ctx"):
        max_length = model.config.n_ctx
    else:
        max_length = 4096

    return {
        "name": model_path.split("/")[-1],
        "tokenizer": tokenizer,
        "model": model,
        "eos_token_ids": eos_token_ids,
        "max_length": max_length,
    }


def main():
    args = parse_args()

    model = load_model(args.model_name_or_path)

    # Re-save the checkpoint to the same directory, re-sharding the weights
    # so that no single shard exceeds roughly 4.7 GB.
    model["model"].save_pretrained(
        args.model_name_or_path,
        max_shard_size="4.7GB",
    )


if __name__ == "__main__":
    main()