Unigram-Watermark / run_generate.py
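"""Generate watermarked completions with the Unigram-Watermark logits warper.

Reads prompts from a JSONL file (each line carries a 'prefix' plus either a
'gold_completion' string or a 'targets' list), generates continuations with the
watermark applied at decoding time, and appends the results to a JSONL output
file named after the model and watermark settings.
"""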
import argparse
from tqdm import tqdm
import json
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LogitsProcessorList
from gptwm import GPTWatermarkLogitsWarper

def read_file(filename):
    with open(filename, "r") as f:
        return [json.loads(line) for line in f.read().strip().split("\n")]


def write_file(filename, data):
    # Append rather than overwrite so an interrupted run can be resumed later.
    with open(filename, "a") as f:
        f.write("\n".join(data) + "\n")

def main(args):
    output_file = f"{args.output_dir}/{args.model_name.replace('/', '-')}_strength_{args.strength}_frac_{args.fraction}_len_{args.max_new_tokens}_num_{args.num_test}.jsonl"

    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)

    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map='auto')
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
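    # GPTWatermarkLogitsWarper (from gptwm) applies the Unigram-Watermark scheme:
    # a fixed "green list" covering roughly `fraction` of the vocabulary is seeded
    # from `wm_key`, and the logits of green-list tokens are boosted by `strength`
    # at every decoding step so the generated text can later be detected.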
    watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
                                                                        strength=args.strength,
                                                                        vocab_size=model.config.vocab_size,
                                                                        watermark_key=args.wm_key)])

    data = read_file(args.prompt_file)
    # Resume support: count the lines already written so previously processed prompts are skipped.
    num_cur_outputs = len(read_file(output_file)) if os.path.exists(output_file) else 0
    outputs = []

    for idx, cur_data in tqdm(enumerate(data), total=min(len(data), args.num_test)):
        if idx < num_cur_outputs or len(outputs) >= args.num_test:
            continue

        # Each prompt must supply a 'prefix' and either a 'gold_completion' or a 'targets' list.
        if "gold_completion" not in cur_data and 'targets' not in cur_data:
            continue
        elif "gold_completion" in cur_data:
            prefix = cur_data['prefix']
            gold_completion = cur_data['gold_completion']
        else:
            prefix = cur_data['prefix']
            gold_completion = cur_data['targets'][0]

        batch = tokenizer(prefix, truncation=True, return_tensors="pt").to(device)
        num_tokens = len(batch['input_ids'][0])
        with torch.inference_mode():
            generate_args = {
                **batch,
                'logits_processor': watermark_processor,
                'output_scores': True,
                'return_dict_in_generate': True,
                'max_new_tokens': args.max_new_tokens,
            }

            # Use beam search if a beam size is given, otherwise top-k / nucleus sampling.
            if args.beam_size is not None:
                generate_args['num_beams'] = args.beam_size
            else:
                generate_args['do_sample'] = True
                generate_args['top_k'] = args.top_k
                generate_args['top_p'] = args.top_p

            generation = model.generate(**generate_args)
            # Decode only the newly generated tokens, dropping the prompt.
            gen_text = tokenizer.batch_decode(generation['sequences'][:, num_tokens:], skip_special_tokens=True)
        outputs.append(json.dumps({
            "prefix": prefix,
            "gold_completion": gold_completion,
            "gen_completion": gen_text
        }))

        # Flush accumulated outputs to disk every 10 prompts so progress is saved incrementally.
        if (idx + 1) % 10 == 0:
            write_file(output_file, outputs)
            outputs = []

    write_file(output_file, outputs)
    print("Finished!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    # parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--prompt_file", type=str, default="./data/LFQA/inputs.jsonl")
    parser.add_argument("--output_dir", type=str, default="./data/LFQA/")
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--num_test", type=int, default=500)
    parser.add_argument("--beam_size", type=int, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=0.9)

    args = parser.parse_args()
    main(args)
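
# Example invocation (values shown are the script defaults; adjust paths and model name as needed):
#   python run_generate.py --model_name facebook/opt-125m --fraction 0.5 --strength 2.0 \
#       --prompt_file ./data/LFQA/inputs.jsonl --output_dir ./data/LFQA/ --max_new_tokens 300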