|
import os |
|
import math |
|
import json |
|
|
|
CHUNK_SIZE = 2 * 1024**3  # default maximum chunk size: 2 GiB

# File (written to the current working directory) recording the chunk paths.
CHUNK_PATHS_FILE = "chunk_paths.json"

# Copy buffer size: stream each chunk in 64 MiB pieces so a whole chunk
# never has to fit in memory at once.
_COPY_BUFFER_SIZE = 64 * 1024**2


def split(filepath, chunk_size=CHUNK_SIZE):
    """Split *filepath* into chunks of at most *chunk_size* bytes.

    Chunk files are written next to the source file, named
    ``<stem>-<i>-of-<n><ext>`` with zero-padded indices so the names sort
    lexicographically. The list of chunk paths is also dumped as JSON to
    ``CHUNK_PATHS_FILE`` in the current working directory.

    Args:
        filepath: Path of the file to split.
        chunk_size: Maximum number of bytes per chunk.

    Returns:
        List of chunk file paths, in order.
    """
    dirname = os.path.dirname(filepath)
    basename = os.path.basename(filepath)
    # os.path.splitext handles extension-less and multi-dot names correctly,
    # unlike basename.split("."), which raises IndexError when there is no
    # dot and truncates stems that contain several dots.
    stem, extension = os.path.splitext(basename)

    file_size = os.path.getsize(filepath)
    num_chunks = math.ceil(file_size / chunk_size)
    digit_count = len(str(num_chunks))

    chunk_paths = []
    # Open the source once and stream sequentially, instead of re-opening
    # the file for every chunk and reading chunk_size bytes in one call.
    with open(filepath, "rb") as f_in:
        for i in range(1, num_chunks + 1):
            chunk_filename = (
                f"{stem}-{str(i).zfill(digit_count)}"
                f"-of-{str(num_chunks).zfill(digit_count)}{extension}"
            )
            split_path = os.path.join(dirname, chunk_filename)

            # Bytes left for this chunk (the final chunk may be short).
            remaining = min(chunk_size, file_size - (i - 1) * chunk_size)
            with open(split_path, "wb") as f_out:
                while remaining > 0:
                    buf = f_in.read(min(_COPY_BUFFER_SIZE, remaining))
                    if not buf:
                        break  # source shrank mid-copy; stop gracefully
                    f_out.write(buf)
                    remaining -= len(buf)

            chunk_paths.append(split_path)

    with open(CHUNK_PATHS_FILE, "w") as f:
        json.dump(chunk_paths, f)

    return chunk_paths
|
|
|
if __name__ == "__main__":
    # Guard the script entry point so merely importing this module does not
    # immediately try to split a hard-coded file as an import side effect.
    main_filepath = "consolidated.safetensors"
    chunk_paths = split(main_filepath)