"""Quantize a BitNet b1.58 model checkpoint and save it as a standalone directory."""

import argparse

import torch

from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer

# Inference/quantization only — gradients are never needed.
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser(description="Quantize a BitNet b1.58 checkpoint.")
parser.add_argument(
    "--hf_path",
    default="1bitLLM/bitnet_b1_58-3B",
    type=str,
    help="HuggingFace Hub id or local path of the source model.",
)
parser.add_argument(
    "--output_path",
    default="./bitnet_b1_58-3B_quantized",
    type=str,
    help="Directory the quantized model and tokenizer are written to.",
)


def main(args):
    """Load the model, quantize its weights, and save model + tokenizer.

    Args:
        args: Parsed CLI namespace carrying ``hf_path`` and ``output_path``.
    """
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()  # NOTE(review): .half() is redundant with torch_dtype=float16; kept for safety.
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
    model.quantize()
    # Shard large checkpoints so no single file exceeds 5 GB.
    model.save_pretrained(args.output_path, max_shard_size="5GB")
    # Bug fix: the tokenizer was loaded but never saved, leaving the output
    # directory unusable as a standalone checkpoint. Save it alongside the model.
    tokenizer.save_pretrained(args.output_path)
    print("Quantized model saved to", args.output_path)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)