Spaces:
Runtime error
Runtime error
import math | |
from typing import Tuple | |
import torch | |
from torch import Tensor | |
# Public merge methods exported by this module.
__all__ = [
    "weighted_sum", "weighted_subtraction", "tensor_sum",
    "add_difference", "sum_twice", "triple_sum",
    "euclidean_add_difference", "multiply_difference",
    "top_k_tensor_sum", "similarity_add_difference",
    "distribution_crossover", "ties_add_difference",
]

# Small positive tolerance. weighted_subtraction uses it to keep its
# denominator away from zero, and distribution_crossover treats any beta at or
# below it as a hard (step) frequency split.
EPSILON = 1e-10
def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Basic Merge: linear interpolation between Primary and Secondary.

    alpha 0 returns Primary Model
    alpha 1 returns Secondary Model

    Args:
        a: Primary tensor.
        b: Secondary tensor.
        alpha: weight given to Secondary; Primary receives ``1 - alpha``.

    Returns:
        ``(1 - alpha) * a + alpha * b``.
    """
    primary_weight = 1 - alpha
    return primary_weight * a + alpha * b
def weighted_subtraction(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    The inverse of a Weighted Sum Merge.

    Computes ``(a - alpha * beta * b) / (1 - alpha * beta)``.
    Returns Primary Model when alpha*beta = 0.
    High values of alpha*beta are likely to break the merged model.

    Args:
        a: Primary tensor.
        b: Secondary tensor.
        alpha, beta: their product is the weight being "un-merged".

    Returns:
        The extrapolated tensor. When ``alpha * beta`` is exactly 1 (for any
        combination, e.g. 1 and 1 or 2 and 0.5) the product is nudged to
        ``1 - EPSILON`` so the result stays finite.
    """
    alpha_beta = alpha * beta
    # The formula is singular whenever the *product* equals 1, not only when
    # alpha == beta == 1; nudge it just below 1 to avoid dividing by zero.
    # For alpha == beta == 1 this yields exactly the previous value.
    if alpha_beta == 1.0:
        alpha_beta = 1.0 - EPSILON
    return (a - alpha_beta * b) / (1 - alpha_beta)
def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Takes a slice of Secondary Model and pastes it into Primary Model.

    The slice is taken along the first dimension. Alpha sets the width of the
    slice and Beta sets its start point, both as fractions of ``a.shape[0]``.
    When ``alpha + beta > 1`` the slice wraps around: the result starts from
    Secondary and the complementary (non-wrapped) band is copied from Primary.
    ie Alpha = 0.5 Beta = 0.25 is (ABBA) Alpha = 0.25 Beta = 0 is (BAAA)

    Returns:
        A new tensor; neither input is modified.
    """
    rows = a.shape[0]
    fits = alpha + beta <= 1
    if fits:
        base, donor = a, b
        lo = int(rows * beta)
        hi = int(rows * (alpha + beta))
    else:
        base, donor = b, a
        lo = int(rows * (alpha + beta - 1))
        hi = int(rows * beta)
    merged = base.clone()
    merged[lo:hi] = donor[lo:hi].clone()
    return merged
def add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Classic Add Difference Merge.

    Adds ``alpha`` times the difference ``b - c`` (Secondary minus Tertiary)
    to Primary.
    """
    delta = b - c
    return a + alpha * delta
def sum_twice(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Stacked Basic Merge.

    Equivalent to merging Primary and Secondary @ alpha, then merging that
    result with Tertiary @ beta.
    """
    first_pass = (1 - alpha) * a + alpha * b
    return (1 - beta) * first_pass + beta * c
def triple_sum(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Weights Secondary and Tertiary at alpha and beta respectively and fills
    in the rest with Primary.

    Expect odd results if alpha + beta > 1, as Primary will then be merged
    with a negative ratio.
    """
    primary_weight = 1 - alpha - beta
    return primary_weight * a + alpha * b + beta * c
def euclidean_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Add Difference where the per-element magnitude is a Euclidean blend.

    Subtracts Tertiary from Primary and from Secondary and normalizes each
    difference to unit norm. The per-element magnitude is
    ``sqrt((1 - alpha) * a_diff**2 + alpha * b_diff**2)``, and its sign comes
    from the weighted-sum difference ``(1 - alpha) * a + alpha * b - c``.
    The blended direction is rescaled to the norm of that weighted-sum
    difference and added to Tertiary.
    Note: Slow

    Returns:
        The merged tensor. The computation runs in float32, so half-precision
        inputs come back promoted. If all differences are zero (e.g. a, b and c
        are identical), the result is ``c`` instead of NaN.
    """
    c_float = c.float()
    a_diff = a.float() - c_float
    b_diff = b.float() - c_float
    # nan_to_num covers the 0/0 case when a difference is entirely zero.
    a_diff = torch.nan_to_num(a_diff / torch.linalg.norm(a_diff))
    b_diff = torch.nan_to_num(b_diff / torch.linalg.norm(b_diff))

    distance = torch.sqrt((1 - alpha) * a_diff**2 + alpha * b_diff**2)

    # Same arithmetic as weighted_sum(a.float(), b.float(), alpha) - c.
    sum_diff = (1 - alpha) * a.float() + alpha * b.float() - c_float
    distance = torch.copysign(distance, sum_diff)
    target_norm = torch.linalg.norm(sum_diff)

    # Guard the final normalization as well: a zero `distance` used to give
    # 0/0 = NaN and poison the whole merged tensor.
    direction = torch.nan_to_num(distance / torch.linalg.norm(distance))
    return c + direction * target_norm
def multiply_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Similar to Add Difference but with a geometric mean instead of an
    arithmetic mean.

    The magnitude added to Tertiary is ``|a - c|**(1 - alpha) * |b - c|**alpha``
    (computed in float32). Its sign is taken from ``(1 - beta) * a + beta * b - c``.
    The difference is cast to ``c.dtype`` before being added to ``c``.
    """
    magnitude_a = (a.float() - c).abs().pow(1 - alpha)
    magnitude_b = (b.float() - c).abs().pow(alpha)
    # Inlined weighted_sum(a, b, beta) - c: only its sign is used.
    direction = (1 - beta) * a + beta * b - c
    signed = torch.copysign(magnitude_a * magnitude_b, direction)
    return c + signed.to(c.dtype)
def top_k_tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Redistributes Primary's values following Secondary's ordering.

    Primary's values are sorted, then re-arranged so that the element at
    position j receives the sorted Primary value whose rank equals Secondary's
    rank at j. This substitution is only applied where that value's magnitude
    falls inside the region picked by ``ratio_to_region(alpha, beta, numel)``
    (alpha = width, beta = offset, both as ratios). Everywhere else, Primary's
    original value is kept.

    Returns:
        A tensor shaped like ``a``.
    """
    flat_a = a.flatten()
    sorted_a = torch.msort(flat_a)
    # Rank of every element of Secondary (argsort of argsort); the outer sort
    # runs over a permutation, so its stability does not matter.
    b_rank = torch.argsort(torch.argsort(b.flatten(), stable=True))

    lo, hi, inverted = ratio_to_region(alpha, beta, a.numel())
    lo_value = kth_abs_value(sorted_a, lo)
    hi_value = kth_abs_value(sorted_a, hi)

    magnitudes = sorted_a.abs()
    in_region = (lo_value < magnitudes) & (magnitudes <= hi_value)
    if inverted:
        in_region = ~in_region

    # Move both the mask and the sorted values into Secondary's rank order.
    mask = in_region.float().gather(0, b_rank)
    redistributed = sorted_a.gather(0, b_rank)
    blended = (1 - mask) * flat_a + mask * redistributed
    return blended.reshape_as(a)
def kth_abs_value(a: Tensor, k: int) -> Tensor:
    """
    Return the k-th smallest absolute value of ``a`` (1-based, via
    ``torch.kthvalue`` on the float32 magnitudes).

    For ``k <= 0`` a ``tensor(-1)`` sentinel is returned, so that every
    magnitude compares strictly greater than it.
    Presumably ``a`` is 1-D (the caller in this file passes a flattened
    tensor); for higher ranks kthvalue works along the last dimension.
    """
    if k <= 0:
        return torch.tensor(-1, device=a.device)
    magnitudes = a.float().abs()
    return torch.kthvalue(magnitudes, k).values
def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
    """
    Convert a (width, offset) ratio pair into rounded index bounds over ``n`` items.

    A negative width is turned into a positive one by moving the offset back
    by the same amount. The width is capped at 1 and the offset is wrapped into
    [0, 1). If the region ``[offset, offset + width)`` stays within [0, 1], the
    result is ``(start, end, False)``. If it runs past 1, the bounds of the
    complementary band are returned, with ``inverted=True``.
    """
    if width < 0:
        offset, width = offset + width, -width
    width = min(width, 1)

    if offset < 0:
        offset = 1 + offset - int(offset)
    offset = math.fmod(offset, 1.0)

    end_ratio = width + offset
    inverted = not (end_ratio <= 1)
    if inverted:
        lo, hi = (end_ratio - 1) * n, offset * n
    else:
        lo, hi = offset * n, end_ratio * n
    return round(lo), round(hi), inverted
def similarity_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    Weighted Sum where A and B are similar and Add Difference where A and B
    are dissimilar.

    Per element, similarity is ``(a * b / max(|a|, |b|)**2 + 1) / 2``, scaled
    by ``beta``. Elements where that is NaN (e.g. a == b == 0) get ``beta``.
    The result blends ``a + alpha * (b - c)`` (weight ``1 - similarity``) with
    ``(1 - alpha/2) * a + (alpha/2) * b`` (weight ``similarity``).
    """
    peak = torch.maximum(a.abs(), b.abs())
    agreement = (a * b / peak**2 + 1) / 2
    weight = torch.nan_to_num(agreement * beta, nan=beta)

    half_alpha = alpha / 2
    diff_merge = a + alpha * (b - c)
    sum_merge = (1 - half_alpha) * a + half_alpha * b
    return (1 - weight) * diff_merge + weight * sum_merge
def distribution_crossover(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs): # pylint: disable=unused-argument
    """
    From the creator:
    It's Primary high-passed + Secondary low-passed. Takes the Fourier transform of the weights of
    Primary and Secondary when ordered with respect to Tertiary. Split the frequency domain
    using a linear function. Alpha is the split frequency and Beta is the inclination of the line.
    add everything under the line as the contribution of Primary and everything over the line as the contribution of Secondary

    Implementation notes:
      * The values of a and b are flattened and permuted into the ascending
        order of c, then transformed with rfft (in float32).
      * Bin i (normalized to i / num_bins) gets Secondary weight
        ``clamp((f - alpha) / beta + 0.5, 0, 1)``, or a hard step ``f >= alpha``
        when beta <= EPSILON. Primary gets the rest (``1 - weight``).
      * NOTE(review): the code gives Primary the bins below the split, which
        reads as Primary low-passed + Secondary high-passed. That is the
        reverse of the quote above; confirm the intended wording.
      * 0-dim inputs have no spectrum, so they return ``alpha * a + (1 - alpha) * b``.
    """
    if a.shape == ():
        return alpha * a + (1 - alpha) * b

    order = torch.argsort(c.flatten())
    a_sorted = a.flatten().gather(0, order)
    b_sorted = b.flatten().gather(0, order)

    a_spectrum = torch.fft.rfft(a_sorted.float())
    b_spectrum = torch.fft.rfft(b_sorted.float())

    bins = a_spectrum.numel()
    freq = torch.arange(0, bins, device=a_spectrum.device).float() / bins
    if beta > EPSILON:
        mix = torch.clamp((freq - alpha) / beta + 1 / 2, 0.0, 1.0)
    else:
        mix = (freq >= alpha).float()

    merged_spectrum = (1 - mix) * a_spectrum + mix * b_spectrum
    merged_sorted = torch.fft.irfft(merged_spectrum, a_sorted.shape[0])
    # argsort of the permutation is its inverse: undo the c-ordering.
    restored = merged_sorted.gather(0, torch.argsort(order))
    return restored.reshape_as(a)
def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, *, vote_sgn: bool = False, **kwargs) -> Tensor: # pylint: disable=unused-argument
    """
    An implementation of TIES-Merging (arXiv:2306.01708) on top of Add Difference.

    Steps:
      1. Trim: the deltas ``a - c`` and ``b - c`` each keep only their
         largest-magnitude ``beta`` fraction (see ``filter_top_k``).
      2. Elect: pick one sign per parameter.
      3. Disjoint merge: average only the deltas that agree with the elected
         sign, and add the average to Tertiary scaled by ``alpha``.

    Args:
        a: Primary tensor.
        b: Secondary tensor.
        c: Tertiary (base) tensor the deltas are taken against.
        alpha: multiplier applied to the merged delta.
        beta: fraction of each delta to keep (1 keeps everything).
        vote_sgn: if False (default), the elected sign is the sign of the
            summed trimmed deltas, as in the paper, so the larger-magnitude
            delta wins a disagreement. If True, the sign is elected by a
            majority vote of the deltas' signs (this function's earlier
            behaviour): disagreeing deltas cancel and are both dropped.

    Returns:
        ``c + alpha * mean(agreeing deltas)``. Parameters where no delta agrees
        with the elected sign keep the value of ``c``. The result is promoted to
        float32 by the float mask arithmetic.
    """
    deltas = torch.stack([filter_top_k(m - c, beta) for m in (a, b)], dim=0)
    signs = torch.sign(deltas)

    if vote_sgn:
        final_sign = torch.sign(torch.sum(signs, dim=0))
    else:
        # Paper's sign election: sign of the total (trimmed) mass.
        final_sign = torch.sign(torch.sum(deltas, dim=0))

    delta_filters = (signs == final_sign).float()
    res = torch.sum(delta_filters * deltas, dim=0)
    param_count = torch.sum(delta_filters, dim=0)
    # param_count == 0 gives 0/0 -> NaN; map it to "no change".
    return c + alpha * torch.nan_to_num(res / param_count)


def filter_top_k(a: Tensor, k: float) -> Tensor:
    """
    Zero out all but the largest-magnitude entries of ``a``.

    Args:
        a: tensor to trim.
        k: fraction of entries to keep (1 keeps everything). The threshold is
            the ``int((1 - k) * numel)``-th smallest absolute value, clamped to
            [1, numel]. Entries whose magnitude is at least the threshold are
            kept, so ties at the threshold may keep slightly more than ``k``.

    Returns:
        ``a`` multiplied by a float 0/1 mask (so half-precision input is
        promoted to float32).
    """
    numel = torch.numel(a)
    # Clamp to a valid kthvalue index; k < 0 would otherwise exceed numel.
    kth = min(max(int((1 - k) * numel), 1), numel)
    k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), kth)
    top_k_filter = (torch.abs(a) >= k_value).float()
    return a * top_k_filter