Spaces:
Running
Running
import argparse | |
import json | |
from tqdm import tqdm | |
import torch | |
from transformers import AutoTokenizer, LlamaTokenizer | |
from gptwm import GPTWatermarkDetector | |
def main(args): | |
with open(args.input_file, 'r') as f: | |
data = [json.loads(x) for x in f.read().strip().split("\n")] | |
if 'llama' in args.model_name: | |
tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16) | |
else: | |
tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16) | |
vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size | |
detector = GPTWatermarkDetector(fraction=args.fraction, | |
strength=args.strength, | |
vocab_size=vocab_size, | |
watermark_key=args.wm_key) | |
z_score_list = [] | |
for idx, cur_data in tqdm(enumerate(data), total=len(data)): | |
gen_tokens = tokenizer(cur_data['gen_completion'][0], add_special_tokens=False)["input_ids"] | |
if len(gen_tokens) >= args.test_min_tokens: | |
z_score_list.append(detector.detect(gen_tokens)) | |
else: | |
print(f"Warning: sequence {idx} is too short to test.") | |
save_dict = { | |
'z_score': z_score_list, | |
'wm_pred': [1 if z > args.threshold else 0 for z in z_score_list] | |
} | |
print(save_dict) | |
with open(args.input_file.replace('.jsonl', '_z.jsonl'), 'w') as f: | |
json.dump(save_dict, f) | |
print('Finished!') | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
# parser.add_argument("--model_name", type=str, default="facebook/opt-125m") | |
parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf") | |
parser.add_argument("--fraction", type=float, default=0.5) | |
parser.add_argument("--strength", type=float, default=2.0) | |
parser.add_argument("--threshold", type=float, default=6.0) | |
parser.add_argument("--wm_key", type=int, default=0) | |
parser.add_argument("--input_file", type=str, default="./data/example_output.jsonl") | |
parser.add_argument("--test_min_tokens", type=int, default=200) | |
args = parser.parse_args() | |
main(args) | |