File size: 198 Bytes
8c9c9c7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8


def load_x_from_safetensor(checkpoint, key):
    """Extract the sub-state-dict stored under ``key`` in a flat checkpoint.

    Entries whose name starts with ``key + '.'`` are returned with that
    leading prefix removed; every other entry is ignored. For example, with
    ``key='generator'`` the entry ``'generator.conv.weight'`` is returned as
    ``'conv.weight'``, while ``'generator_ema.conv'`` and
    ``'kp.generator.w'`` are skipped.

    Args:
        checkpoint: Mapping from dotted parameter names to values. Values are
            passed through unchanged (presumably tensors loaded from a
            safetensors file -- not checked here).
        key: Name of the sub-module whose entries should be extracted.

    Returns:
        A new dict mapping the prefix-stripped names to the original values.
        Empty if no entry starts with ``key + '.'``.
    """
    # Match on the full "<key>." prefix so that sibling modules sharing a
    # name fragment (e.g. "generator_ema") or occurrences of ``key`` deeper in
    # the name are not picked up, and strip only the leading prefix instead of
    # every occurrence of it (str.replace would mangle "a.key.b"-style names).
    prefix = key + '.'
    return {
        name[len(prefix):]: value
        for name, value in checkpoint.items()
        if name.startswith(prefix)
    }