File size: 2,258 Bytes
6a20eb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import json
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, LlamaTokenizer
from gptwm import GPTWatermarkDetector


def main(args):
    with open(args.input_file, 'r') as f:
        data = [json.loads(x) for x in f.read().strip().split("\n")]
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)

    vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size

    detector = GPTWatermarkDetector(fraction=args.fraction,
                                    strength=args.strength,
                                    vocab_size=vocab_size,
                                    watermark_key=args.wm_key)

    z_score_list = []
    for idx, cur_data in tqdm(enumerate(data), total=len(data)):
        gen_tokens = tokenizer(cur_data['gen_completion'][0], add_special_tokens=False)["input_ids"]
        if len(gen_tokens) >= args.test_min_tokens:
            z_score_list.append(detector.detect(gen_tokens))
        else:
            print(f"Warning: sequence {idx} is too short to test.")

    save_dict = {
        'z_score': z_score_list,
        'wm_pred': [1 if z > args.threshold else 0 for z in z_score_list]
    }

    print(save_dict)
    with open(args.input_file.replace('.jsonl', '_z.jsonl'), 'w') as f:
        json.dump(save_dict, f)

    print('Finished!')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--threshold", type=float, default=6.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--input_file", type=str, default="./data/example_output.jsonl")
    parser.add_argument("--test_min_tokens", type=int, default=200)

    args = parser.parse_args()

    main(args)