# normistral-11b-warm / convert_to_safetensors.py
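"""Re-save a NorMistral checkpoint so that its weights are written as sharded
safetensors files.

The script loads the model given by --model_name_or_path and calls
save_pretrained() on the same directory; recent versions of transformers write
sharded *.safetensors files by default, which is what gives the script its
name. Only the model (weights and config) is re-saved; the tokenizer files are
left untouched.
"""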
import argparse
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="/scratch/project_465000144/dasamuel/normistral/normistral-11b-masked-post-hf-60000",
        help="Path to the pre-trained model",
    )
    args = parser.parse_args()
    return args

def load_model(model_path: str):
    # Load the pre-trained tokenizer and model; the weights are kept in bfloat16
    # and moved to the GPU. The Hugging Face access token, if one is needed, is
    # read from the HF_TOKEN environment variable instead of being hardcoded.
    hf_token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".", token=hf_token)
    model = (
        AutoModelForCausalLM.from_pretrained(
            model_path, cache_dir=".", token=hf_token, torch_dtype=torch.bfloat16
        )
        .cuda()
        .eval()
    )

    # Collect every token whose decoded form contains a newline, to be used as
    # end-of-sequence tokens (not needed for the conversion itself, but kept for
    # completeness).
    eos_token_ids = [
        token_id
        for token_id in range(tokenizer.vocab_size)
        if "\n" in tokenizer.decode([token_id])
    ]

    # Read the maximum context length from whichever config attribute is present.
    if hasattr(model.config, "n_positions"):
        max_length = model.config.n_positions
    elif hasattr(model.config, "max_position_embeddings"):
        max_length = model.config.max_position_embeddings
    elif hasattr(model.config, "max_length"):
        max_length = model.config.max_length
    elif hasattr(model.config, "n_ctx"):
        max_length = model.config.n_ctx
    else:
        max_length = 4096  # Default value

    return {
        "name": model_path.split("/")[-1],
        "tokenizer": tokenizer,
        "model": model,
        "eos_token_ids": eos_token_ids,
        "max_length": max_length,
    }

def main():
    args = parse_args()
    model = load_model(args.model_name_or_path)

    # Re-save the weights in place; save_pretrained shards them into files of at
    # most ~4.7 GB and, with recent transformers versions, uses safetensors.
    model["model"].save_pretrained(
        args.model_name_or_path,
        max_shard_size="4.7GB",
    )


if __name__ == "__main__":
    main()
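
# Example invocation (hypothetical path; adjust to your own checkpoint directory):
#
#   python convert_to_safetensors.py \
#       --model_name_or_path /path/to/normistral-11b-checkpoint
#
# Afterwards the directory should contain model-0000*-of-0000*.safetensors shards
# plus a model.safetensors.index.json, and the converted checkpoint can be loaded
# again with AutoModelForCausalLM.from_pretrained() on the same path.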