File size: 5,972 Bytes
c99a3a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
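"""Implement additional PRF functions (they differ only in how a single hash value is derived from the tokens in the context).

These can be hooked into the existing WatermarkLogitsProcessor as a modified WatermarkBase base class; see
the implementation in extended_watermark_processor.py.
"""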



import torch
from itertools import combinations
from functools import cache

# Key properties of a hashing scheme, each mapped to the Python type of its value.
props = dict(
    # String name of the underlying PRF, which maps several token ids to a random seed.
    prf_type=str,
    # This is h in the paper: how many previous tokens each PRF should consider.
    context_width=int,
    # Whether the token itself is used to generate the seed (and may reject its own list),
    # following the rules of the robust-watermarking scheme.
    self_salt=bool,
    # Integer, a large prime, used to move the seed away from low-entropy bit
    # sequences in the PRF chosen above.
    hash_key=int,
)


def seeding_scheme_lookup(seeding_scheme: str):
    """Resolve a seeding-scheme name into its PRF configuration.

    Args:
        seeding_scheme: Either a named preset ("simple_1"/"lefthash",
            "algorithm-3"/"selfhash", "minhash", "skipgram") or a freeform
            spec starting with "ff" of the form
            ``ff-<prf_type>-<context_width>-<self_salt>[-<hash_key>]``,
            e.g. ``ff-additive_prf-4-True-hash`` or ``ff-additive_prf-5-True``.

    Returns:
        The tuple ``(prf_type, context_width, self_salt, hash_key)``.

    Raises:
        ValueError: If ``seeding_scheme`` is not a string, names no known
            scheme, is a freeform spec with fewer than four fields or with
            non-integer numeric fields, or selects a PRF that is not
            registered in ``prf_lookup``.
    """
    if not isinstance(seeding_scheme, str):
        raise ValueError("Seeding scheme should be a string summarizing the procedure.")

    default_hash_key = 15485863
    # Named presets: name -> (prf_type, context_width, self_salt).
    presets = {
        # Default, simple bigram hash - alias for ff-additive_prf-1-False-15485863
        "simple_1": ("additive_prf", 1, False),
        "lefthash": ("additive_prf", 1, False),
        "algorithm-3": ("anchored_minhash_prf", 4, True),
        "selfhash": ("anchored_minhash_prf", 4, True),
        "minhash": ("minhash_prf", 4, False),
        "skipgram": ("skipgram_prf", 5, False),
    }

    if seeding_scheme in presets:
        prf_type, context_width, self_salt = presets[seeding_scheme]
        hash_key = default_hash_key
    elif seeding_scheme.startswith("ff"):
        # Freeform seeding-scheme API - for experimental use only.
        split_scheme = seeding_scheme.split("-")
        if len(split_scheme) < 4:
            # Without this check a short spec surfaces as a bare IndexError.
            raise ValueError(
                f"Malformed freeform seeding scheme {seeding_scheme} given. "
                "Expected 'ff-<prf_type>-<context_width>-<self_salt>[-<hash_key>]'."
            )
        prf_type = str(split_scheme[1])
        context_width = int(split_scheme[2])
        self_salt = split_scheme[3] == "True"
        # The hash key is optional; it is only read when exactly five fields are given.
        if len(split_scheme) == 5:
            hash_key = int(split_scheme[4])
        else:
            hash_key = default_hash_key
    else:
        raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try  'simple_1'?")

    # Explicit raise instead of assert: asserts are removed under `python -O`,
    # which would let an unknown PRF name slip through to the caller.
    if prf_type not in prf_lookup:
        raise ValueError(f"Unknown PRF type {prf_type} requested by seeding scheme {seeding_scheme}.")
    return prf_type, context_width, self_salt, hash_key


def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return ``salt_key`` times the product of all entries of ``input_ids``."""
    token_product = input_ids.prod().item()
    return salt_key * token_product


def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return ``salt_key`` times the sum of all entries of ``input_ids``."""
    token_sum = input_ids.sum().item()
    return salt_key * token_sum


def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return ``salt_key`` times the smallest entry of ``input_ids``.

    Not a good idea for non-random input ids such as text.
    """
    smallest_id = input_ids.min().item()
    return salt_key * smallest_id


def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
    """Hash every k-th token id and return the product of the hashes.

    Args:
        input_ids: Token ids of the context.
        salt_key: Multiplier applied to the selected ids before hashing.
        k: Skip distance; ids at positions 0, k, 2k, ... are used.

    Returns:
        The product of ``hashint`` over the salted, strided ids.
    """
    strided_ids = input_ids[::k]
    hashed = hashint(salt_key * strided_ids)
    return hashed.prod().item()


def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return the hash of the salted first token id in the context.

    Only ``input_ids[0]`` is used, i.e. the maximum-distance skipgram within
    the context (presumably the token furthest from the one being generated;
    verify against the caller).
    """
    first_id = input_ids[0]
    return hashint(salt_key * first_id).item()


def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    """Combine the hash of the first token id with the hash of an anchor token id.

    Args:
        input_ids: Token ids of the context.
        salt_key: Multiplier applied to both ids before hashing.
        anchor: Index into ``input_ids`` of the anchor token (default: last).

    Returns:
        The product of the two hashes (maximum-distance skipgram within the
        context, tied to the anchor).
    """
    first_hash = hashint(salt_key * input_ids[0])
    anchor_hash = hashint(salt_key * input_ids[anchor])
    return (first_hash * anchor_hash).item()


def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return the smallest hash among the salted token ids of the context."""
    hashed_ids = hashint(salt_key * input_ids)
    return hashed_ids.min().item()


def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    """Return the minimum over token hashes, each paired with an anchor token's hash.

    Args:
        input_ids: Token ids of the context.
        salt_key: Multiplier applied to every (token hash * anchor hash) product.
        anchor: Index into ``input_ids`` of the anchor token (default: last).

    Returns:
        ``min(salt_key * hashint(id) * hashint(anchor_id))`` over all ids.
    """
    # Note: unlike the other hashed PRFs, the salt is applied after hashing here.
    token_hashes = hashint(input_ids)
    anchor_hash = hashint(input_ids[anchor])
    salted_pairs = salt_key * token_hashes * anchor_hash
    return salted_pairs.min().item()


def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Return the minimum product over all k-element skipgrams of the context.

    Args:
        input_ids: Token ids of the context.
        salt_key: Multiplier applied to the ids before hashing.
        k: Number of tokens per skipgram; ``k=2`` means all pairs.

    Returns:
        The smallest product of hashes over every k-combination of tokens.

    Raises:
        ValueError: If the context holds fewer than ``k`` tokens, so that no
            skipgram can be formed.
    """
    if len(input_ids) < k:
        # Fail early with a clear message; an empty combination list would
        # otherwise fail later inside ``prod(dim=1)``.
        raise ValueError(f"minskipgram_prf needs at least {k} tokens of context, got {len(input_ids)}.")
    # Work on plain Python ints so the combinations convert cleanly to one 2-D tensor.
    hashed_ids = hashint(salt_key * input_ids).tolist()
    # Honor ``k`` here; the combination size used to be hard-coded to 2.
    skipgrams = torch.as_tensor(list(combinations(hashed_ids, k)))
    return skipgrams.prod(dim=1).min().item()


def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Fold the context into a running key, one token at a time.

    Each step depends on the key produced by the previous step, so the token
    order influences the result.

    Args:
        input_ids: Token ids of the context, visited in order.
        salt_key: Initial value of the running key.
        k: Unused by this PRF.

    Returns:
        The final running key, reduced modulo 2**32.
    """
    modulus = 2**32
    running_key = torch.as_tensor(salt_key, dtype=torch.long)
    for token_id in input_ids:
        # Mix the current key into the hash input, then keep the key within 32 bits.
        running_key *= hashint(running_key * token_id)
        running_key %= modulus
    return running_key.item()


def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Return the salted sum of token ids weighted by their 1-based position.

    Args:
        input_ids: Token ids of the context.
        salt_key: Multiplier applied to every id.
        k: Unused by this PRF.

    Returns:
        ``sum(salt_key * input_ids[i] * (i + 1))`` over all positions ``i``.
    """
    positions = torch.arange(1, len(input_ids) + 1, device=input_ids.device)
    weighted_ids = salt_key * input_ids * positions
    return weighted_ids.sum().item()


# Registry of the available PRFs; each one is registered under its own function name.
prf_lookup = {
    prf.__name__: prf
    for prf in (
        multiplicative_prf,
        additive_prf,
        minfunc_prf,
        simple_skip_prf,
        skipgram_prf,
        anchored_skipgram_prf,
        minhash_prf,
        anchored_minhash_prf,
        minskipgram_prf,
        noncomm_prf,
        position_prf,
    )
}

# Build the global permutation table once, at import time.
table_size = 1_000_003
rng = torch.Generator(device=torch.device("cpu"))
rng.manual_seed(2971215073)  # fixed seed, so the table is identical on every run
fixed_table = torch.randperm(table_size, device=torch.device("cpu"), generator=rng)  # this is fast


def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
    """Map integers to pseudo-random values via the fixed permutation table.

    Each input is reduced modulo ``table_size`` and looked up in
    ``fixed_table``; one is added to the looked-up value.

    Args:
        integer_tensor: Integer tensor on any device.

    Returns:
        A tensor of the same shape. The input is moved to the CPU for the
        lookup, so the result always lives on the CPU.
    """
    table_index = integer_tensor.cpu() % table_size
    return fixed_table[table_index] + 1


def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
    
    i = integer_tensor.to(torch.int32).clone()  # or torch.int16?
    i -= i << 6
    i ^= i >> 17
    i -= i << 9
    i ^= i << 4
    i -= i << 3
    i ^= i << 10
    i ^= i >> 15
    return i.to(torch.long)


@cache
def _hashint_avalanche_int(integer: int):
    i = integer % (2**32)
    i -= i << 6
    i ^= i >> 17
    i -= i << 9
    i ^= i << 4
    i -= i << 3
    i ^= i << 10
    i ^= i >> 15
    return i