File size: 3,798 Bytes
79c5088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch 
import torch.nn.functional as F
import io
import cv2
import numpy as np
from PIL import Image


def normalize(
    z_t,
    i,
    max_norm_zs,
):
    """Clip the norm of ``z_t`` to the ``i``-th limit in ``max_norm_zs``.

    Args:
        z_t: Tensor whose overall norm (``torch.norm``) is checked.
        i: Index used to pick the limit out of ``max_norm_zs``.
        max_norm_zs: Indexable collection of norm limits; a negative entry
            turns clipping off for that index.

    Returns:
        A ``(tensor, coeff)`` pair. When no clipping happens the input is
        returned as-is together with the plain integer ``1``; otherwise the
        tensor is multiplied by ``coeff = limit / norm`` and that coefficient
        is returned alongside it.
    """
    limit = max_norm_zs[i]
    untouched = (z_t, 1)

    # A negative limit means "do not clip".
    if limit < 0:
        return untouched

    magnitude = torch.norm(z_t)
    if magnitude < limit:
        return untouched

    # Shrink the tensor so that its norm equals the limit.
    ratio = limit / magnitude
    return z_t * ratio, ratio

def normalize2(x, dim):
    """Standardize ``x`` along ``dim``: subtract the mean, divide by the std.

    Both statistics are reduced with ``keepdim=True`` so they broadcast back
    against ``x``. ``Tensor.std`` is used with its default settings and no
    epsilon is added to the divisor, so a slice with zero spread produces
    non-finite values.
    """
    centered = x - x.mean(dim=dim, keepdim=True)
    spread = x.std(dim=dim, keepdim=True)
    return centered / spread

def find_lambda_via_newton_batched(Qp, K_source, K_target, max_iter=50, tol=1e-7):
    """Iterate a Newton-style update on a scalar lambda and return it as a float.

    Args:
        Qp: Tensor indexed as ``[b, c, d]`` by the einsum below.
        K_source: Tensor indexed as ``[b, m, d]``.
        K_target: Tensor indexed as ``[b, m, d]``.
        max_iter: Maximum number of update steps.
        tol: The loop stops once the absolute residual is no longer above
            this value.

    Returns:
        A Python float. The iteration starts from 0.7; if an updated lambda
        falls below 0.4 the function returns 0.1 immediately, and if it rises
        above 0.9 it returns 0.65 immediately. Otherwise the last lambda is
        returned.
    """
    # Contract the feature axis: [b, c, d] x [b, m, d] -> [b, c, m].
    source_logits = torch.einsum("bcd,bmd->bcm", Qp, K_source)
    target_logits = torch.einsum("bcd,bmd->bcm", Qp, K_target)
    source_weights = torch.exp(source_logits)

    lmbd = torch.zeros([1], device=Qp.device, dtype=Qp.dtype) + 0.7
    for _ in range(max_iter):
        target_weights = torch.exp(lmbd * target_logits)

        # Joint normaliser over the last axis for source and target weights.
        partition = (source_weights + target_weights).sum(dim=2, keepdim=True)
        source_share = source_weights / partition
        target_share = target_weights / partition

        # Residual: total source share minus total target share.
        residual = (source_share.sum(dim=(1, 2)) - target_share.sum(dim=(1, 2))).sum()
        # Quantity used as the denominator of the Newton step.
        slope = -(target_logits * target_share).sum()

        # Stop when the residual is small enough or the slope is too flat.
        # Written as negated ">" tests so NaN values also end the loop.
        if not residual.abs() > tol:
            break
        if not slope.abs() > 1e-12:
            break

        lmbd = lmbd - residual / slope

        # Leaving the [0.4, 0.9] band short-circuits to fixed fallback values.
        current = lmbd.item()
        if current < 0.4:
            return 0.1
        if current > 0.9:
            return 0.65

    return lmbd.item()

def find_lambda_via_super_halley(Qp, K_source, K_target, max_iter=50, tol=1e-7):
    """Iterate a second-order update on a scalar lambda and return it as a tensor.

    Args:
        Qp: Tensor indexed as ``[b, c, d]`` by the einsum below.
        K_source: Tensor indexed as ``[b, m, d]``.
        K_target: Tensor indexed as ``[b, m, d]``.
        max_iter: Maximum number of update steps.
        tol: The loop stops once the absolute residual is no longer above
            this value.

    Returns:
        A 0-dim tensor holding the final lambda (the iteration starts at 0.8).
        Unlike ``.item()``-style helpers, the value is not converted to float.

    Side effects:
        Prints one progress line per completed update step.
    """
    # Contract the feature axis: [b, c, d] x [b, m, d] -> [b, c, m].
    src_scores = torch.einsum("bcd,bmd->bcm", Qp, K_source)
    tgt_scores = torch.einsum("bcd,bmd->bcm", Qp, K_target)
    src_exp = torch.exp(src_scores)

    lmbd = torch.zeros([], device=Qp.device, dtype=Qp.dtype) + 0.8

    for it in range(max_iter):
        tgt_exp = torch.exp(lmbd * tgt_scores)

        # Joint normaliser over the last axis for source and target weights.
        total = (src_exp + tgt_exp).sum(dim=2, keepdim=True)
        src_share = src_exp / total
        tgt_share = tgt_exp / total

        # Residual: total source share minus total target share.
        val = (src_share.sum(dim=(1, 2)) - tgt_share.sum(dim=(1, 2))).sum()
        # First-order term of the update.
        grad = -(tgt_scores * tgt_share).sum()

        # Stop when the residual is small enough or the slope is too flat.
        # Written as negated ">" tests so NaN values also end the loop.
        if not val.abs() > tol:
            break
        if not grad.abs() > 1e-12:
            break

        # Second-order term of the update.
        curvature = -(tgt_scores**2 * tgt_share).sum()

        denom = grad**2 - val * curvature
        if denom.abs() < 1e-20:
            # Avoid dividing by a vanishing denominator.
            break

        lmbd = lmbd - (val * grad) / denom

        print(f"iter={it}, λ={lmbd.item():.6f}, val={val.item():.6e}, grad={grad.item():.6e}")

    return lmbd

def find_smallest_key_with_suffix(features_dict: dict, suffix: str = "_1") -> str:
    """Return the key ending in ``suffix`` whose leading number is smallest.

    The leading number is the text before the first underscore, parsed with
    ``int``. Keys whose leading part is not an integer are skipped. When
    several keys share the smallest number, the one encountered first wins.

    Args:
        features_dict: Mapping whose keys are inspected (values are ignored).
        suffix: Required ending of a candidate key.

    Returns:
        The matching key, or ``None`` when no key qualifies.
    """
    candidates = []
    for name in features_dict:
        if not name.endswith(suffix):
            continue
        leading = name.split('_')[0]
        try:
            candidates.append((int(leading), name))
        except ValueError:
            # Leading part is not an integer; ignore this key.
            continue

    if not candidates:
        return None
    # ``min`` keeps the first of equal minima, i.e. the earliest key.
    return min(candidates, key=lambda pair: pair[0])[1]

def extract_mask(masks, original_width, original_height, *, canvas_size=512, output_size=64):
    """Rasterise rectangular regions into a square binary mask.

    Each rectangle is scaled from the original image size onto a
    ``canvas_size`` x ``canvas_size`` canvas, the rectangles are unioned, and
    the result is downsampled to ``output_size`` x ``output_size`` with
    nearest-neighbour interpolation.

    Args:
        masks: Iterable of dicts, each with ``"start_point"`` and
            ``"end_point"`` entries that unpack to ``(x, y)``. The two corners
            may be given in any order.
        original_width: Width of the coordinate space the points live in.
        original_height: Height of the coordinate space the points live in.
        canvas_size: Side length of the intermediate canvas (default 512).
        output_size: Side length of the returned mask (default 64).

    Returns:
        A float32 tensor of shape ``(output_size, output_size)`` containing
        0.0 / 1.0, or ``None`` when ``masks`` is empty or ``None``.
    """
    if not masks:
        return None

    canvas = torch.zeros(canvas_size, canvas_size, dtype=torch.float32)
    scale_x = canvas_size / original_width
    scale_y = canvas_size / original_height

    def to_canvas(value, scale):
        # Clamp to the canvas. Without this, a point outside the image on the
        # top/left gives a negative index, which slicing would interpret as
        # "counted from the end" and mark the wrong (or an empty) region.
        return max(0, min(canvas_size, int(value * scale)))

    for mask in masks:
        x_a, y_a = mask["start_point"]
        x_b, y_b = mask["end_point"]

        # Order the corners so the slice bounds are always ascending.
        left, right = min(x_a, x_b), max(x_a, x_b)
        top, bottom = min(y_a, y_b), max(y_a, y_b)

        rows = slice(to_canvas(top, scale_y), to_canvas(bottom, scale_y))
        cols = slice(to_canvas(left, scale_x), to_canvas(right, scale_x))
        # Union of rectangles: any covered pixel is simply set to one.
        canvas[rows, cols] = 1.0

    # interpolate expects (N, C, H, W); add and then strip the two leading axes.
    return F.interpolate(
        canvas[None, None, :, :],
        size=(output_size, output_size),
        mode="nearest",
    )[0, 0]