turn-the-cam-anonymous's picture
adding CLIP taming
1ed7deb
raw
history blame
549 Bytes
import torch
import sys
def extract_submodel(state_dict, submodel="cond_stage_model"):
    """Return the entries of *state_dict* that belong to *submodel*.

    An entry is selected when its key starts with ``"<submodel>."``; that
    prefix is removed from the key in the returned mapping, so the result
    uses key names relative to the sub-model.

    Args:
        state_dict: Mapping from dotted parameter names to values.
        submodel: Dotted name of the sub-model whose entries are kept.

    Returns:
        A new dict with the matching entries, keyed by the name that follows
        the ``"<submodel>."`` prefix. Empty when nothing matches.
    """
    # Match on the dotted prefix so that a sibling whose name merely shares
    # the same leading characters is not picked up by accident.
    prefix = submodel + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    # Usage: <script> INPATH OUTPATH [SUBMODEL]
    if len(sys.argv) < 3:
        sys.exit("usage: {} INPATH OUTPATH [SUBMODEL]".format(sys.argv[0]))

    inpath = sys.argv[1]
    outpath = sys.argv[2]
    # The optional third argument overrides the default sub-model name.
    submodel = sys.argv[3] if len(sys.argv) > 3 else "cond_stage_model"

    print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))

    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints that come from a trusted source.
    sd = torch.load(inpath, map_location="cpu")
    # Filter by the requested sub-model; previously the prefix was hard-coded
    # and the command-line override had no effect on what was extracted.
    new_sd = {"state_dict": extract_submodel(sd["state_dict"], submodel)}
    torch.save(new_sd, outpath)