Unigram-Watermark / run_detect.py
import argparse
import json

from tqdm import tqdm
import torch
from transformers import AutoTokenizer, LlamaTokenizer

from gptwm import GPTWatermarkDetector


def main(args):
    # The input is JSONL: one JSON record per line.
    with open(args.input_file, 'r') as f:
        data = [json.loads(x) for x in f.read().strip().split("\n")]

    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)

    # OPT tokenizers report a vocab_size smaller than the model's actual
    # embedding size, so the true size (50272) is hard-coded for them.
    vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size
    detector = GPTWatermarkDetector(fraction=args.fraction,
                                    strength=args.strength,
                                    vocab_size=vocab_size,
                                    watermark_key=args.wm_key)

    # Score every completion that is long enough to give a reliable test.
    z_score_list = []
    for idx, cur_data in tqdm(enumerate(data), total=len(data)):
        gen_tokens = tokenizer(cur_data['gen_completion'][0], add_special_tokens=False)["input_ids"]
        if len(gen_tokens) >= args.test_min_tokens:
            z_score_list.append(detector.detect(gen_tokens))
        else:
            print(f"Warning: sequence {idx} is too short to test.")

    # Predict "watermarked" (1) whenever the z-score exceeds the threshold.
    save_dict = {
        'z_score': z_score_list,
        'wm_pred': [1 if z > args.threshold else 0 for z in z_score_list]
    }
    print(save_dict)

    # Write results next to the input, e.g. example_output.jsonl -> example_output_z.jsonl.
    with open(args.input_file.replace('.jsonl', '_z.jsonl'), 'w') as f:
        json.dump(save_dict, f)

    print('Finished!')
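
# For intuition: GPTWatermarkDetector.detect measures how over-represented the
# secret "green list" tokens are in a sequence. A minimal sketch of that
# z-statistic, assuming a membership test `in_green_list` (hypothetical here;
# the real detector derives the green list from watermark_key and fraction):
#
#     import math
#
#     def z_score(tokens, fraction, in_green_list):
#         n = len(tokens)
#         green = sum(1 for t in tokens if in_green_list(t))
#         # Under the null hypothesis (no watermark), green ~ Binomial(n, fraction).
#         return (green - fraction * n) / math.sqrt(n * fraction * (1 - fraction))
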
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
parser.add_argument("--fraction", type=float, default=0.5)
parser.add_argument("--strength", type=float, default=2.0)
parser.add_argument("--threshold", type=float, default=6.0)
parser.add_argument("--wm_key", type=int, default=0)
parser.add_argument("--input_file", type=str, default="./data/example_output.jsonl")
parser.add_argument("--test_min_tokens", type=int, default=200)
args = parser.parse_args()
main(args)
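
For reference, a minimal sketch of preparing an input file in the format this script expects and then running the detector. The 'gen_completion' field and the default paths come from the script above; the completion text itself is a hypothetical placeholder.

import json

# One JSON record per line; the detector reads the first element of each
# record's 'gen_completion' list. Replace the placeholder with real model output.
records = [
    {"gen_completion": ["<model output suspected to carry the watermark>"]},
]
with open("./data/example_output.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

# Then, from the shell:
#   python run_detect.py --input_file ./data/example_output.jsonl
# The z-scores and 0/1 predictions are written to ./data/example_output_z.jsonl.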