|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
|
|
def convert(checkpoint: str, outdir: str, suffix: str = "base"):
    """Split a combined checkpoint into generator and detector checkpoints.

    The input checkpoint is expected to be a mapping with two entries:
    ``"xp.cfg"`` (the experiment configuration) and ``"model"`` (a mapping
    from layer name to weights).  Two files are written to ``outdir``:
    ``checkpoint_generator_{suffix}.pth`` and
    ``checkpoint_detector_{suffix}.pth``.  Each holds a reduced ``"xp.cfg"``
    (only the ``seanet``, ``channels``, ``dtype`` and ``sample_rate``
    entries) and the subset of ``"model"`` belonging to that component.

    Layer routing:
      * names starting with ``detector`` go to the detector checkpoint with
        the leading ``"detector."`` stripped;
      * ``msg_processor.msg_processor.0.weight`` goes to the generator
        checkpoint under ``msg_processor.msg_processor.weight``;
      * names starting with ``generator`` go to the generator checkpoint
        with the leading ``"generator."`` stripped.

    Args:
        checkpoint: Path of the combined checkpoint to read.
        outdir: Directory that receives the two output files.  It is
            created (including parents) when it does not exist yet.
        suffix: Tag inserted into the output file names.

    Raises:
        AssertionError: If a layer name matches none of the routing rules.
            Raised explicitly so the check is not stripped by ``python -O``.
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    outdir_path = Path(outdir)

    # NOTE(security): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    ckpt = torch.load(checkpoint)

    # Keep only the configuration entries that the output checkpoints carry.
    full_cfg = ckpt["xp.cfg"]
    infer_cfg = {
        key: full_cfg[key]
        for key in ("seanet", "channels", "dtype", "sample_rate")
    }

    # Both outputs reference the same reduced configuration object.
    generator_ckpt = {"xp.cfg": infer_cfg, "model": {}}
    detector_ckpt = {"xp.cfg": infer_cfg, "model": {}}

    # Number of leading characters dropped from routed layer names
    # (component name plus the separating dot).
    detector_prefix_len = len("detector.")
    generator_prefix_len = len("generator.")

    for layer, weights in ckpt["model"].items():
        if layer.startswith("detector"):
            detector_ckpt["model"][layer[detector_prefix_len:]] = weights
        elif layer == "msg_processor.msg_processor.0.weight":
            # This one layer is renamed (the ".0" index is dropped) and
            # stored with the generator.
            generator_ckpt["model"]["msg_processor.msg_processor.weight"] = weights
        elif layer.startswith("generator"):
            generator_ckpt["model"][layer[generator_prefix_len:]] = weights
        else:
            # Explicit raise instead of `assert` so that unknown layers are
            # rejected even when Python runs with optimizations enabled.
            raise AssertionError(f"Invalid layer: {layer}")

    # torch.save does not create missing parent directories.
    outdir_path.mkdir(parents=True, exist_ok=True)

    torch.save(generator_ckpt, outdir_path / f"checkpoint_generator_{suffix}.pth")
    torch.save(detector_ckpt, outdir_path / f"checkpoint_detector_{suffix}.pth")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand `convert` to python-fire so that its
    # parameters (checkpoint, outdir, suffix) are taken from the command line.
    # The import is kept local so `fire` is only needed when run as a script.
    from fire import Fire

    Fire(convert)
|
|