# Unigram-Watermark / run_generate.py
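"""Generate watermarked completions for a JSONL prompt file.

Each input line is expected to contain a "prefix" field plus either a
"gold_completion" string or a "targets" list; results are appended to a JSONL
output file as records holding the prompt, the reference answer, and the
watermarked generation.
"""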
import argparse
from tqdm import tqdm
import json
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LogitsProcessorList
from gptwm import GPTWatermarkLogitsWarper
def read_file(filename):
    """Read a JSONL file into a list of dicts."""
    with open(filename, "r") as f:
        return [json.loads(line) for line in f.read().strip().split("\n")]


def write_file(filename, data):
    """Append a list of JSON strings to a file, one record per line."""
    with open(filename, "a") as f:
        f.write("\n".join(data) + "\n")
def main(args):
    output_file = f"{args.output_dir}/{args.model_name.replace('/', '-')}_strength_{args.strength}_frac_{args.fraction}_len_{args.max_new_tokens}_num_{args.num_test}.jsonl"

    # Llama checkpoints need the dedicated tokenizer class; everything else can use AutoTokenizer.
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)

    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map='auto')
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
                                                                        strength=args.strength,
                                                                        vocab_size=model.config.vocab_size,
                                                                        watermark_key=args.wm_key)])
    data = read_file(args.prompt_file)
    # Resume support: if the output file already exists, skip prompts generated in a previous run.
    num_cur_outputs = len(read_file(output_file)) if os.path.exists(output_file) else 0

    outputs = []
    for idx, cur_data in tqdm(enumerate(data), total=min(len(data), args.num_test)):
        if idx < num_cur_outputs or len(outputs) >= args.num_test:
            continue
if "gold_completion" not in cur_data and 'targets' not in cur_data:
continue
elif "gold_completion" in cur_data:
prefix = cur_data['prefix']
gold_completion = cur_data['gold_completion']
else:
prefix = cur_data['prefix']
gold_completion = cur_data['targets'][0]
        batch = tokenizer(prefix, truncation=True, return_tensors="pt").to(device)
        num_tokens = len(batch['input_ids'][0])  # prompt length, used to strip the prefix from the output

        with torch.inference_mode():
            generate_args = {
                **batch,
                'logits_processor': watermark_processor,
                'output_scores': True,
                'return_dict_in_generate': True,
                'max_new_tokens': args.max_new_tokens,
            }
            # Beam search if --beam_size is given, otherwise top-k / nucleus sampling.
            if args.beam_size is not None:
                generate_args['num_beams'] = args.beam_size
            else:
                generate_args['do_sample'] = True
                generate_args['top_k'] = args.top_k
                generate_args['top_p'] = args.top_p

            generation = model.generate(**generate_args)
            # Decode only the newly generated tokens (drop the prompt).
            gen_text = tokenizer.batch_decode(generation['sequences'][:, num_tokens:], skip_special_tokens=True)

        outputs.append(json.dumps({
            "prefix": prefix,
            "gold_completion": gold_completion,
            "gen_completion": gen_text
        }))

        # Flush partial results every 10 prompts so interrupted runs can be resumed.
        if (idx + 1) % 10 == 0:
            write_file(output_file, outputs)
            outputs = []

        # Stop once the requested number of prompts has been processed.
        if idx + 1 >= args.num_test:
            break

    write_file(output_file, outputs)
    print("Finished!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="facebookopt-125m")
# parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
parser.add_argument("--fraction", type=float, default=0.5)
parser.add_argument("--strength", type=float, default=2.0)
parser.add_argument("--wm_key", type=int, default=0)
parser.add_argument("--prompt_file", type=str, default="./data/LFQA/inputs.jsonl")
parser.add_argument("--output_dir", type=str, default="./data/LFQA/")
parser.add_argument("--max_new_tokens", type=int, default=300)
parser.add_argument("--num_test", type=int, default=500)
parser.add_argument("--beam_size", type=int, default=None)
parser.add_argument("--top_k", type=int, default=None)
parser.add_argument("--top_p", type=float, default=0.9)
args = parser.parse_args()
main(args)
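# Example invocation (model name and paths are placeholders; adjust to your setup):
#   python run_generate.py --model_name facebook/opt-1.3b --fraction 0.5 --strength 2.0 \
#       --prompt_file ./data/LFQA/inputs.jsonl --output_dir ./data/LFQA/ \
#       --max_new_tokens 300 --num_test 500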