Upload xor_codec.py
Browse files- xor_codec.py +92 -0
xor_codec.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Adapted from OpenAssistant's original xor_codec.py:
|
3 |
+
https://huggingface.co/OpenAssistant/oasst-sft-6-llama-30b-xor/raw/main/xor_codec.py
|
4 |
+
'''
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import shutil
|
8 |
+
import gzip
|
9 |
+
import numpy
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
def xor_uncompressed(dst, src_payload, src_base, block_size=4096):
|
13 |
+
fp_payload = open(src_payload, 'rb')
|
14 |
+
fp_base = open(src_base, 'rb')
|
15 |
+
with open(dst, 'wb') as fp:
|
16 |
+
while True:
|
17 |
+
buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
|
18 |
+
buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
|
19 |
+
padding = len(buf1) - len(buf2)
|
20 |
+
if padding > 0: buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
|
21 |
+
if padding < 0: buf2 = buf2[:len(buf1)]
|
22 |
+
buf = numpy.bitwise_xor(buf1, buf2)
|
23 |
+
fp.write(buf)
|
24 |
+
if len(buf1) < block_size: break
|
25 |
+
fp_payload.close()
|
26 |
+
fp_base.close()
|
27 |
+
|
28 |
+
def xor_encode(dst, src_payload, src_base, block_size=4096):
|
29 |
+
fp_payload = open(src_payload, 'rb')
|
30 |
+
fp_base = open(src_base, 'rb')
|
31 |
+
with gzip.open(dst, 'wb') as fp:
|
32 |
+
while True:
|
33 |
+
buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
|
34 |
+
buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
|
35 |
+
padding = len(buf1) - len(buf2)
|
36 |
+
if padding > 0: buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
|
37 |
+
if padding < 0: buf2 = buf2[:len(buf1)]
|
38 |
+
buf = numpy.bitwise_xor(buf1, buf2)
|
39 |
+
fp.write(buf)
|
40 |
+
if len(buf1) < block_size: break
|
41 |
+
fp_payload.close()
|
42 |
+
fp_base.close()
|
43 |
+
|
44 |
+
def xor_decode(dst, src_payload, src_base, block_size=4096):
|
45 |
+
fp_payload = gzip.open(src_payload, 'rb')
|
46 |
+
fp_base = open(src_base, 'rb')
|
47 |
+
with open(dst, 'wb') as fp:
|
48 |
+
while True:
|
49 |
+
buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
|
50 |
+
buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
|
51 |
+
padding = len(buf1) - len(buf2)
|
52 |
+
if padding > 0: buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
|
53 |
+
if padding < 0: buf2 = buf2[:len(buf1)]
|
54 |
+
buf = numpy.bitwise_xor(buf1, buf2)
|
55 |
+
fp.write(buf)
|
56 |
+
if len(buf1) < block_size: break
|
57 |
+
fp_payload.close()
|
58 |
+
fp_base.close()
|
59 |
+
|
60 |
+
def xor_dir(dst, src_payload, src_base, decode=True, compress=True):
|
61 |
+
if compress:
|
62 |
+
xor = xor_decode if decode else xor_encode
|
63 |
+
else:
|
64 |
+
xor = xor_uncompressed
|
65 |
+
Path(dst).mkdir(parents=True, exist_ok=True)
|
66 |
+
for path in os.listdir(src_payload):
|
67 |
+
# Don't care about uncopyrightable text files, just copy over.
|
68 |
+
if ".json" in path:
|
69 |
+
print("[*] Copying '%s'" % path)
|
70 |
+
shutil.copy(f"{src_payload}/{path}", f"{dst}/{path}")
|
71 |
+
continue
|
72 |
+
|
73 |
+
print("[*] Processing '%s'" % path)
|
74 |
+
try:
|
75 |
+
xor("%s/%s" % (dst, path), "%s/%s" % (src_payload, path), "%s/%s" % (src_base, path))
|
76 |
+
except Exception as e:
|
77 |
+
print("Exception when processing '%s'" % path)
|
78 |
+
|
79 |
+
if __name__ == "__main__":
|
80 |
+
if len(sys.argv) < 4:
|
81 |
+
print("Usage: xor.py <DESTINATION> <PAYLOAD SOURCE> <LLAMA SOURCE> [--encode] [--compress]")
|
82 |
+
exit()
|
83 |
+
dst = sys.argv[1]
|
84 |
+
src_payload = sys.argv[2]
|
85 |
+
src_base = sys.argv[3]
|
86 |
+
decode = True
|
87 |
+
compress = False
|
88 |
+
if len(sys.argv) > 4:
|
89 |
+
for arg in sys.argv[4:]:
|
90 |
+
if arg == "--encode": decode = False
|
91 |
+
if arg == "--compress": compress = True
|
92 |
+
xor_dir(dst, src_payload, src_base, decode=decode, compress=compress)
|