Nafaille committed on
Commit
10c3727
1 Parent(s): 36ef162

Upload xor_codec.py

Browse files
Files changed (1) hide show
  1. xor_codec.py +92 -0
xor_codec.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Adapted from OpenAssistant's original xor_codec.py:
3
+ https://huggingface.co/OpenAssistant/oasst-sft-6-llama-30b-xor/raw/main/xor_codec.py
4
+ '''
5
+ import os
6
+ import sys
7
+ import shutil
8
+ import gzip
9
+ import numpy
10
+ from pathlib import Path
11
+
12
def xor_uncompressed(dst, src_payload, src_base, block_size=4096):
    '''
    XOR the file `src_payload` against the file `src_base` and write the raw
    (uncompressed) result to `dst`.

    Both inputs are read in chunks of `block_size` bytes. The output always
    has the same length as the payload: where the base is shorter it is
    treated as zero-padded (so those payload bytes pass through unchanged),
    and where it is longer the excess base bytes are ignored.

    Raises ValueError if `block_size` is less than 1.
    '''
    if block_size < 1:
        # read(0) always returns b'' and read(-1) reads everything at once;
        # in both cases the "short read" exit test below can never be true.
        raise ValueError("block_size must be a positive integer")
    # The sources are opened before the destination so that a missing input
    # does not leave an empty output file behind. `with` guarantees all three
    # handles are closed even if a read or write fails part-way through.
    with open(src_payload, 'rb') as fp_payload, \
            open(src_base, 'rb') as fp_base, \
            open(dst, 'wb') as fp:
        while True:
            buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
            buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
            padding = len(buf1) - len(buf2)
            if padding > 0:
                # Base ran out first: extend it with zeros (x ^ 0 == x).
                buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
            elif padding < 0:
                # Payload ran out first: drop the surplus base bytes.
                buf2 = buf2[:len(buf1)]
            fp.write(numpy.bitwise_xor(buf1, buf2).tobytes())
            # A short (possibly empty) payload read means end of file.
            if len(buf1) < block_size:
                break
27
+
28
def xor_encode(dst, src_payload, src_base, block_size=4096):
    '''
    XOR the file `src_payload` against the file `src_base` and write the
    result to `dst` as a gzip-compressed stream.

    Both inputs are read as plain binary files in chunks of `block_size`
    bytes. The (uncompressed) output has the same length as the payload:
    where the base is shorter it is treated as zero-padded, and where it is
    longer the excess base bytes are ignored.

    Raises ValueError if `block_size` is less than 1.
    '''
    if block_size < 1:
        # read(0) always returns b'' and read(-1) reads everything at once;
        # in both cases the "short read" exit test below can never be true.
        raise ValueError("block_size must be a positive integer")
    # The sources are opened before the destination so that a missing input
    # does not leave an empty archive behind. `with` guarantees all three
    # handles are closed even if a read or write fails part-way through.
    with open(src_payload, 'rb') as fp_payload, \
            open(src_base, 'rb') as fp_base, \
            gzip.open(dst, 'wb') as fp:
        while True:
            buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
            buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
            padding = len(buf1) - len(buf2)
            if padding > 0:
                # Base ran out first: extend it with zeros (x ^ 0 == x).
                buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
            elif padding < 0:
                # Payload ran out first: drop the surplus base bytes.
                buf2 = buf2[:len(buf1)]
            fp.write(numpy.bitwise_xor(buf1, buf2).tobytes())
            # A short (possibly empty) payload read means end of file.
            if len(buf1) < block_size:
                break
43
+
44
def xor_decode(dst, src_payload, src_base, block_size=4096):
    '''
    Decompress the gzip file `src_payload`, XOR it against the plain file
    `src_base`, and write the raw result to `dst`.

    Data is processed in chunks of `block_size` (decompressed) bytes. The
    output has the same length as the decompressed payload: where the base is
    shorter it is treated as zero-padded, and where it is longer the excess
    base bytes are ignored.

    Raises ValueError if `block_size` is less than 1.
    '''
    if block_size < 1:
        # read(0) always returns b'' and read(-1) reads everything at once;
        # in both cases the "short read" exit test below can never be true.
        raise ValueError("block_size must be a positive integer")
    # The sources are opened before the destination so that a missing input
    # does not leave an empty output file behind. `with` guarantees all three
    # handles are closed even if a read or write fails part-way through.
    with gzip.open(src_payload, 'rb') as fp_payload, \
            open(src_base, 'rb') as fp_base, \
            open(dst, 'wb') as fp:
        while True:
            buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
            buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
            padding = len(buf1) - len(buf2)
            if padding > 0:
                # Base ran out first: extend it with zeros (x ^ 0 == x).
                buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
            elif padding < 0:
                # Payload ran out first: drop the surplus base bytes.
                buf2 = buf2[:len(buf1)]
            fp.write(numpy.bitwise_xor(buf1, buf2).tobytes())
            # A short (possibly empty) payload read means end of stream.
            if len(buf1) < block_size:
                break
59
+
60
def xor_dir(dst, src_payload, src_base, decode=True, compress=True):
    '''
    Apply the XOR codec to every entry of directory `src_payload`, pairing
    each one with the same-named file in `src_base` and writing the result
    under the same name in `dst` (created if it does not exist).

    With `compress` true, `decode` selects between xor_decode (gzip payload
    in, raw file out) and xor_encode (raw payload in, gzip file out). With
    `compress` false, xor_uncompressed is used and `decode` has no effect.

    Entries whose name contains ".json" are copied over verbatim. A failure
    on one entry is reported and processing continues with the next entry.
    '''
    if compress:
        xor = xor_decode if decode else xor_encode
    else:
        xor = xor_uncompressed
    Path(dst).mkdir(parents=True, exist_ok=True)
    for path in os.listdir(src_payload):
        payload_path = os.path.join(src_payload, path)
        dst_path = os.path.join(dst, path)

        # Don't care about uncopyrightable text files, just copy over.
        if ".json" in path:
            print("[*] Copying '%s'" % path)
            shutil.copy(payload_path, dst_path)
            continue

        print("[*] Processing '%s'" % path)
        try:
            xor(dst_path, payload_path, os.path.join(src_base, path))
        except Exception as e:
            # Deliberately best-effort: keep going with the remaining files,
            # but say what went wrong instead of hiding the cause.
            print("Exception when processing '%s': %s: %s" % (path, type(e).__name__, e))
78
+
79
if __name__ == "__main__":
    # Command-line entry point. Three positional arguments are required;
    # any further arguments are scanned for the optional flags below and
    # anything unrecognised is ignored.
    if len(sys.argv) < 4:
        print("Usage: xor.py <DESTINATION> <PAYLOAD SOURCE> <LLAMA SOURCE> [--encode] [--compress]")
        # sys.exit() instead of the site-provided exit() helper, and a
        # non-zero status so calling scripts can detect the usage error.
        sys.exit(1)
    dst = sys.argv[1]
    src_payload = sys.argv[2]
    src_base = sys.argv[3]
    flags = sys.argv[4:]
    # Decoding without compression is the default mode.
    decode = "--encode" not in flags
    compress = "--compress" in flags
    xor_dir(dst, src_payload, src_base, decode=decode, compress=compress)