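"""Rough similarity check between voice-conversion model checkpoints.

A fixed random probe is fed through the query/key/value projection weights of
each checkpoint's encoder attention layers (``enc_p.encoder.attn_layers``); the
reference checkpoints are then scored by the cosine similarity of the resulting
activations against those of the query model.
"""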
import hashlib
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def cal_cross_attn(to_q, to_k, to_v, rand_input):
    # The projection weights are square here (in_features == out_features), so
    # they can be loaded directly into bias-free Linear layers.
    hidden_dim, embed_dim = to_q.shape
    attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
    attn_to_k = nn.Linear(hidden_dim, embed_dim, bias=False)
    attn_to_v = nn.Linear(hidden_dim, embed_dim, bias=False)
    attn_to_q.load_state_dict({"weight": to_q})
    attn_to_k.load_state_dict({"weight": to_k})
    attn_to_v.load_state_dict({"weight": to_v})

    # Q·Kᵀ scores over the random probe, softmaxed into an attention map, then
    # combined with the value projection (each score is scaled by the matching
    # column sum of V).
    return torch.einsum(
        "ik, jk -> ik",
        F.softmax(
            torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)),
            dim=-1,
        ),
        attn_to_v(rand_input),
    )


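# Purely illustrative sanity check, not part of the tool itself: with hypothetical
# square weights (192x192 is the typical size of these projections), the probe
# keeps its shape through cal_cross_attn.
def _demo_cal_cross_attn(dim=192):
    w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))
    probe = torch.randn(dim, dim)
    out = cal_cross_attn(w_q, w_k, w_v, probe)
    assert out.shape == (dim, dim)

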
def model_hash(filename):
    # Short fingerprint: sha256 of the 64 KiB block starting 1 MiB into the
    # file, truncated to 8 hex characters.
    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return "NOFILE"


def eval(model, n, input):
    # conv_q/conv_k/conv_v are 1x1 convolutions, so dropping the kernel axis
    # with [:, :, 0] leaves plain (out, in) weight matrices.
    qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight"
    uk = f"enc_p.encoder.attn_layers.{n}.conv_k.weight"
    vk = f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
    atoq, atok, atov = model[qk][:, :, 0], model[uk][:, :, 0], model[vk][:, :, 0]

    attn = cal_cross_attn(atoq, atok, atov, input)
    return attn


def main(path, root):
    # Fixed seed so every checkpoint is probed with the same random inputs.
    torch.manual_seed(114514)
    model_a = torch.load(path, map_location="cpu")["weight"]

    logger.info("Query:\t\t%s\t%s" % (path, model_hash(path)))

    # Cache the query model's attention outputs (and the probes that produced
    # them) for each of the six encoder attention layers.
    map_attn_a = {}
    map_rand_input = {}
    for n in range(6):
        hidden_dim, embed_dim, _ = model_a[
            f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
        ].shape
        rand_input = torch.randn([embed_dim, hidden_dim])

        map_attn_a[n] = eval(model_a, n, rand_input)
        map_rand_input[n] = rand_input

    # Free the query model before loading the reference checkpoints.
    del model_a

    for name in sorted(list(os.listdir(root))):
        path = "%s/%s" % (root, name)
        model_b = torch.load(path, map_location="cpu")["weight"]

        # Per-layer mean cosine similarity against the cached query outputs.
        sims = []
        for n in range(6):
            attn_a = map_attn_a[n]
            attn_b = eval(model_b, n, map_rand_input[n])

            sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
            sims.append(sim)

        # Average across layers and report as a percentage.
        logger.info(
            "Reference:\t%s\t%s\t%s"
            % (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")
        )


if __name__ == "__main__":
    # Configure logging so the logger.info lines above are actually emitted
    # when the script is run directly.
    logging.basicConfig(level=logging.INFO)

    query_path = r"assets\weights\mi v3.pth"
    reference_root = r"assets\weights"
    main(query_path, reference_root)