import argparse
import json
import os

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LogitsProcessorList

from gptwm import GPTWatermarkLogitsWarper
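
# Generate watermarked completions for a file of JSONL prompts using the
# GPTWatermarkLogitsWarper from the accompanying gptwm module, appending the
# results to a JSONL output file named after the run's hyperparameters.
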
def read_file(filename):
    with open(filename, "r") as f:
        return [json.loads(line) for line in f.read().strip().split("\n")]

def write_file(filename, data):
    with open(filename, "a") as f:
        f.write("\n".join(data) + "\n")
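
# For intuition only: a minimal sketch of a green-list watermark warper, assuming
# the scheme fixes a keyed pseudorandom "green" subset of the vocabulary and adds
# a constant bias to those logits at every decoding step. The real logic lives in
# gptwm.GPTWatermarkLogitsWarper; this hypothetical class is not used below.
class _SketchGreenListWarper:
    def __init__(self, fraction, strength, vocab_size, watermark_key):
        g = torch.Generator().manual_seed(watermark_key)
        # Choose a fixed `fraction` of token ids as the green list.
        num_green = int(fraction * vocab_size)
        self.green_ids = torch.randperm(vocab_size, generator=g)[:num_green]
        self.strength = strength

    def __call__(self, input_ids, scores):
        # Bias green-list logits so sampling favors watermarked tokens.
        scores[:, self.green_ids] += self.strength
        return scores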

def main(args):
    output_file = f"{args.output_dir}/{args.model_name.replace('/', '-')}_strength_{args.strength}_frac_{args.fraction}_len_{args.max_new_tokens}_num_{args.num_test}.jsonl"

    # LLaMA checkpoints need the dedicated tokenizer class; everything else can
    # go through AutoTokenizer.
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map='auto')
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Wrap the watermark warper in a LogitsProcessorList so model.generate()
    # applies it at every decoding step.
    watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
                                                                        strength=args.strength,
                                                                        vocab_size=model.config.vocab_size,
                                                                        watermark_key=args.wm_key)])
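
    # Load prompts; if the output file already exists, count its lines so a
    # restarted run resumes where the previous one stopped.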
    data = read_file(args.prompt_file)
    num_cur_outputs = len(read_file(output_file)) if os.path.exists(output_file) else 0

    outputs = []
    for idx, cur_data in tqdm(enumerate(data), total=min(len(data), args.num_test)):
        if idx < num_cur_outputs or len(outputs) >= args.num_test:
            continue

        # The reference completion is stored under either "gold_completion" or
        # "targets"; skip records that have neither.
        if "gold_completion" not in cur_data and 'targets' not in cur_data:
            continue
        elif "gold_completion" in cur_data:
            prefix = cur_data['prefix']
            gold_completion = cur_data['gold_completion']
        else:
            prefix = cur_data['prefix']
            gold_completion = cur_data['targets'][0]

        batch = tokenizer(prefix, truncation=True, return_tensors="pt").to(device)
        num_tokens = len(batch['input_ids'][0])
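
        # Assemble generation arguments: beam search when --beam_size is given,
        # otherwise top-k / top-p sampling; the watermark is applied through
        # logits_processor in both cases.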
        with torch.inference_mode():
            generate_args = {
                **batch,
                'logits_processor': watermark_processor,
                'output_scores': True,
                'return_dict_in_generate': True,
                'max_new_tokens': args.max_new_tokens,
            }
            if args.beam_size is not None:
                generate_args['num_beams'] = args.beam_size
            else:
                generate_args['do_sample'] = True
                generate_args['top_k'] = args.top_k
                generate_args['top_p'] = args.top_p

            generation = model.generate(**generate_args)
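            # Decode only the newly generated tokens, skipping the prompt prefix.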
            gen_text = tokenizer.batch_decode(generation['sequences'][:, num_tokens:], skip_special_tokens=True)

        outputs.append(json.dumps({
            "prefix": prefix,
            "gold_completion": gold_completion,
            "gen_completion": gen_text
        }))

        # Flush every 10 prompts so partial progress survives interruption; the
        # resume check above skips anything already written.
        if (idx + 1) % 10 == 0:
            write_file(output_file, outputs)
            outputs = []

    write_file(output_file, outputs)
    print("Finished!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    # parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--prompt_file", type=str, default="./data/LFQA/inputs.jsonl")
    parser.add_argument("--output_dir", type=str, default="./data/LFQA/")
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--num_test", type=int, default=500)
    parser.add_argument("--beam_size", type=int, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=0.9)
    args = parser.parse_args()
    main(args)
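
# Example invocation (the filename generate.py is an assumption; use whatever
# this script is saved as, with gptwm.py importable from the same directory):
#   python generate.py --model_name facebook/opt-125m \
#       --prompt_file ./data/LFQA/inputs.jsonl --output_dir ./data/LFQA/ \
#       --strength 2.0 --fraction 0.5 --max_new_tokens 300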