| import torch |
| from typing import Tuple, Callable |
| def hacer_nada(x: torch.Tensor, modo: str = None): |
| return x |
| def brujeria_mps(entrada, dim, indice): |
| if entrada.shape[-1] == 1: |
| return torch.gather(entrada.unsqueeze(-1), dim - 1 if dim < 0 else dim, indice.unsqueeze(-1)).squeeze(-1) |
| else: |
| return torch.gather(entrada, dim, indice) |
| def emparejamiento_suave_aleatorio_2d( |
| metrica: torch.Tensor, |
| ancho: int, |
| alto: int, |
| paso_x: int, |
| paso_y: int, |
| radio: int, |
| sin_aleatoriedad: bool = False, |
| generador: torch.Generator = None |
| ) -> Tuple[Callable, Callable]: |
| lote, num_nodos, _ = metrica.shape |
| if radio <= 0: |
| return hacer_nada, hacer_nada |
| recopilar = brujeria_mps if metrica.device.type == "mps" else torch.gather |
| with torch.no_grad(): |
| alto_paso_y, ancho_paso_x = alto // paso_y, ancho // paso_x |
| if sin_aleatoriedad: |
| indice_aleatorio = torch.zeros(alto_paso_y, ancho_paso_x, 1, device=metrica.device, dtype=torch.int64) |
| else: |
| indice_aleatorio = torch.randint(paso_y * paso_x, size=(alto_paso_y, ancho_paso_x, 1), device=generador.device, generator=generador).to(metrica.device) |
| vista_buffer_indice = torch.zeros(alto_paso_y, ancho_paso_x, paso_y * paso_x, device=metrica.device, dtype=torch.int64) |
| vista_buffer_indice.scatter_(dim=2, index=indice_aleatorio, src=-torch.ones_like(indice_aleatorio, dtype=indice_aleatorio.dtype)) |
| vista_buffer_indice = vista_buffer_indice.view(alto_paso_y, ancho_paso_x, paso_y, paso_x).transpose(1, 2).reshape(alto_paso_y * paso_y, ancho_paso_x * paso_x) |
| if (alto_paso_y * paso_y) < alto or (ancho_paso_x * paso_x) < ancho: |
| buffer_indice = torch.zeros(alto, ancho, device=metrica.device, dtype=torch.int64) |
| buffer_indice[:(alto_paso_y * paso_y), :(ancho_paso_x * paso_x)] = vista_buffer_indice |
| else: |
| buffer_indice = vista_buffer_indice |
| indice_aleatorio = buffer_indice.reshape(1, -1, 1).argsort(dim=1) |
| del buffer_indice, vista_buffer_indice |
| num_destino = alto_paso_y * ancho_paso_x |
| indices_a = indice_aleatorio[:, num_destino:, :] |
| indices_b = indice_aleatorio[:, :num_destino, :] |
| def dividir(x): |
| canales = x.shape[-1] |
| origen = recopilar(x, dim=1, index=indices_a.expand(lote, num_nodos - num_destino, canales)) |
| destino = recopilar(x, dim=1, index=indices_b.expand(lote, num_destino, canales)) |
| return origen, destino |
| metrica = metrica / metrica.norm(dim=-1, keepdim=True) |
| a, b = dividir(metrica) |
| puntuaciones = a @ b.transpose(-1, -2) |
| radio = min(a.shape[1], radio) |
| nodo_max, nodo_indice = puntuaciones.max(dim=-1) |
| indice_borde = nodo_max.argsort(dim=-1, descending=True)[..., None] |
| indice_no_emparejado = indice_borde[..., radio:, :] |
| indice_origen = indice_borde[..., :radio, :] |
| indice_destino = recopilar(nodo_indice[..., None], dim=-2, index=indice_origen) |
| def fusionar(x: torch.Tensor, modo="mean") -> torch.Tensor: |
| origen, destino = dividir(x) |
| n, t1, c = origen.shape |
| no_emparejado = recopilar(origen, dim=-2, index=indice_no_emparejado.expand(n, t1 - radio, c)) |
| origen = recopilar(origen, dim=-2, index=indice_origen.expand(n, radio, c)) |
| destino = destino.scatter_reduce(-2, indice_destino.expand(n, radio, c), origen, reduce=modo) |
| return torch.cat([no_emparejado, destino], dim=1) |
| def desfusionar(x: torch.Tensor) -> torch.Tensor: |
| longitud_no_emparejado = indice_no_emparejado.shape[1] |
| no_emparejado, destino = x[..., :longitud_no_emparejado, :], x[..., longitud_no_emparejado:, :] |
| _, _, c = no_emparejado.shape |
| origen = recopilar(destino, dim=-2, index=indice_destino.expand(lote, radio, c)) |
| salida = torch.zeros(lote, num_nodos, c, device=x.device, dtype=x.dtype) |
| salida.scatter_(dim=-2, index=indices_b.expand(lote, num_destino, c), src=destino) |
| salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_no_emparejado).expand(lote, longitud_no_emparejado, c), src=no_emparejado) |
| salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_origen).expand(lote, radio, c), src=origen) |
| return salida |
| return fusionar, desfusionar |