| import torch |
| import copy |
| import numpy as np |
| from scipy.stats import pearsonr |
| from t_cube import evaluate_model |
|
|
def evaluate_slerp(clip_pt, sd_pt, sd_ft, dataloader, args, alpha=0.5):
    """
    SLERP (spherical linear interpolation) between pretrained (pt) and
    fine-tuned (ft) weights, then evaluate the merged model.

    Each tensor is flattened and treated as one high-dimensional vector;
    interpolation follows the great circle between the two vectors.
    alpha=0 -> pt only; alpha=1 -> ft only.

    Args:
        clip_pt: model instance; deep-copied and loaded with merged weights.
        sd_pt: state dict of the pretrained weights.
        sd_ft: state dict of the fine-tuned weights (assumed to share keys
            with sd_pt -- a missing key raises KeyError).
        dataloader: evaluation data, forwarded to evaluate_model.
        args: evaluation options, forwarded to evaluate_model.
        alpha: interpolation coefficient in [0, 1].

    Returns:
        Whatever evaluate_model returns for the merged model.
    """
    model = copy.deepcopy(clip_pt)
    merged_sd = {}

    for k in sd_pt.keys():
        # Do the trig math in float32 regardless of the stored dtype.
        w1 = sd_pt[k].flatten().float()
        w2 = sd_ft[k].flatten().float()

        # Angle between the two flattened weight vectors. The eps guards
        # against a zero-norm tensor; the clamp keeps acos away from its
        # infinite-slope endpoints at +/-1.
        cos_val = torch.dot(w1, w2) / (w1.norm() * w2.norm() + 1e-8)
        omega = torch.acos(torch.clamp(cos_val, -1 + 1e-6, 1 - 1e-6))
        sin_omega = torch.sin(omega)
        if sin_omega < 1e-6:
            # Nearly parallel vectors: SLERP is ill-conditioned, fall back
            # to plain linear interpolation.
            w_interp = (1 - alpha) * w1 + alpha * w2
        else:
            w_interp = (torch.sin((1 - alpha) * omega) / sin_omega) * w1 + \
                       (torch.sin(alpha * omega) / sin_omega) * w2
        # Restore the original shape AND dtype: the .float() cast above must
        # not leak float32 into fp16 weights or integer buffers.
        merged_sd[k] = w_interp.view_as(sd_pt[k]).to(sd_pt[k].dtype)
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)
|
|
|
|
def evaluate_m3(clip_pt, sd_pt, sd_ft, dataloader, args, lam=None):
    """
    M^3 (Mixup Model Merge): linearly interpolate pt and ft weights with a
    mixing coefficient lambda, then evaluate the merged model.

    Args:
        clip_pt: model instance; deep-copied and loaded with merged weights.
        sd_pt: state dict of the pretrained weights.
        sd_ft: state dict of the fine-tuned weights (assumed to share keys
            with sd_pt).
        dataloader: evaluation data, forwarded to evaluate_model.
        args: evaluation options, forwarded to evaluate_model.
        lam: mixing coefficient in [0, 1]; lam=0 -> pt only, lam=1 -> ft
            only. If None (default, the original behavior) a value is
            sampled ~ Uniform(0, 1), so pass an explicit lam to make a
            run reproducible.

    Returns:
        Whatever evaluate_model returns for the merged model.
    """
    model = copy.deepcopy(clip_pt)
    if lam is None:
        # Original M^3 behavior: draw lambda from the global NumPy RNG.
        lam = np.random.rand()
    merged_sd = {k: lam * sd_ft[k] + (1 - lam) * sd_pt[k]
                 for k in sd_pt.keys()}
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)
|
|
|
|
def evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, dataloader, args):
    """
    Task Arithmetic merge: move the weights past fine-tuning along the task
    vector (ft - pt), i.e. pt + 2*(ft - pt) == 2*ft - pt, then evaluate.
    """
    model = copy.deepcopy(clip_pt)
    merged_sd = {}
    for name in sd_pt:
        # ft plus one extra step of the task vector (equivalent to 2*ft - pt).
        merged_sd[name] = sd_ft[name] + (sd_ft[name] - sd_pt[name])
    model.load_state_dict(merged_sd)
    return evaluate_model(model, dataloader, args)
|
|