import math

import torch
from torch import nn
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
import comfy.ops
ops = comfy.ops.manual_cast
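
# GLIGEN (grounded text-to-image generation) support: gated attention "fuser"
# modules that inject grounding tokens (bounding boxes + text embeddings) into
# the UNet's transformer blocks, plus a loader that builds them from a GLIGEN
# state dict.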

def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = ops.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            ops.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            ops.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


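# GLIGEN gated cross-attention block: the visual tokens cross-attend to the
# grounding tokens (objs), and both the attention and feed-forward residuals
# are gated by tanh(alpha) with alpha initialized to 0, so the block starts
# out as an identity mapping.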
class GatedCrossAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha), e.g. setting scale to 0 makes this block a no-op and
        # the model behaves exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        x = x + self.scale * \
            torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


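# GLIGEN gated self-attention block: grounding tokens are projected to the
# visual feature dimension, concatenated with the visual tokens, self-attended,
# and only the visual portion of the output is kept as a gated residual.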
class GatedSelfAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # a linear projection is needed because we concatenate the visual
        # features with the object (grounding) features
        self.linear = ops.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head,
            operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha), e.g. setting scale to 0 makes this block a no-op and
        # the model behaves exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        N_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
            self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


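# Variant of the gated self-attention block that instead keeps the grounding
# token outputs, reshapes them to a square grid, and bicubically resizes them
# to the visual token grid before adding them as the gated residual.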
class GatedSelfAttentionDense2(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # a linear projection is needed because we concatenate the visual
        # features with the object (grounding) features
        self.linear = ops.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, heads=n_heads,
            dim_head=d_head, operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha), e.g. setting scale to 0 makes this block a no-op and
        # the model behaves exactly like the original one
        self.scale = 1

    def forward(self, x, objs):

        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # sanity check: both token sets must form square spatial grids
        size_v = math.sqrt(N_visual)
        size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual tokens must form a square grid"
        assert int(size_g) == size_g, "Grounding tokens must form a square grid"
        size_v = int(size_v)
        size_g = int(size_g)

        # select grounding token and resize it to visual token size as residual
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # add residual to visual feature
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


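# Fourier feature embedder: encodes coordinates with sin/cos at num_freqs
# geometrically spaced frequencies.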
class FourierEmbedder:
    def __init__(self, num_freqs=64, temperature=100):

        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        "x: arbitrary shape of tensor. dim: cat dim"
        out = []
        for freq in self.freq_bands:
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, cat_dim)


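# PositionNet turns per-object grounding inputs (xyxy boxes and text
# embeddings) into grounding tokens; padded slots (mask == 0) fall back to
# learned null embeddings.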
class PositionNet(nn.Module):
    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin&cos, 4 is xyxy

        self.linears = nn.Sequential(
            ops.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            ops.Linear(512, 512),
            nn.SiLU(),
            ops.Linear(512, out_dim),
        )

        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)

        # embed box positions (the input may include padding as placeholders)
        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C

        # learnable null embedding
        positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
        xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)

        # replace padding with learnable null embedding
        positive_embeddings = positive_embeddings * \
            masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        objs = self.linears(
            torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs


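# Wrapper that holds one fuser module per transformer block. set_position
# builds the grounding tokens for a batch and returns a patch function that
# applies the fuser selected by extra_options["transformer_index"].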
class Gligen(nn.Module):
    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        self.max_objs = 30
        self.current_device = torch.device("cpu")

    def _set_position(self, boxes, masks, positive_embeddings):
        objs = self.position_net(boxes, masks, positive_embeddings)
        def func(x, extra_options):
            key = extra_options["transformer_index"]
            module = self.module_list[key]
            return module(x, objs.to(device=x.device, dtype=x.dtype))
        return func

    def set_position(self, latent_image_shape, position_params, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
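        # each p is assumed to be (cond embedding, height, width, y, x) in
        # latent-space coordinates, inferred from the indexing below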
        for p in position_params:
            x1 = (p[4]) / w
            y1 = (p[3]) / h
            x2 = (p[4] + p[2]) / w
            y2 = (p[3] + p[1]) / h
            masks[len(boxes)] = 1.0
            boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
            positive_embeddings += [p[0]]
        append_boxes = []
        append_conds = []
        if len(boxes) < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - len(boxes), 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - len(boxes), self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings +
                          append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))


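# Builds a Gligen wrapper from a GLIGEN state dict: scans the UNet's
# input/middle/output blocks for ".fuser." keys, constructs one
# GatedSelfAttentionDense per block, and loads the PositionNet weights.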
def load_gligen(sd):
    sd_k = sd.keys()
    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            k_temp = filter(lambda k: "{}.{}.".format(a, b)
                            in k and ".fuser." in k, sd_k)
            k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)

            n_sd = {}
            for k in k_temp:
                n_sd[k[1]] = sd[k[0]]
            if len(n_sd) > 0:
                query_dim = n_sd["linear.weight"].shape[0]
                key_dim = n_sd["linear.weight"].shape[1]

                if key_dim == 768:  # SD1.x
                    n_heads = 8
                    d_head = query_dim // n_heads
                else:
                    d_head = 64
                    n_heads = query_dim // d_head

                gated = GatedSelfAttentionDense(
                    query_dim, key_dim, n_heads, d_head)
                gated.load_state_dict(n_sd, strict=False)
                output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
        in_dim = sd["position_net.null_positive_feature"].shape[0]
        out_dim = sd["position_net.linears.4.weight"].shape[0]

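        # temporary wrapper so the "position_net.*" keys in the state dict map
        # onto w.position_net when loading with strict=False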
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
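

# Illustrative usage sketch (not part of this module; the surrounding variable
# names are assumptions). The returned patch function expects the caller to
# supply the current transformer block index:
#
#   gligen = load_gligen(state_dict)
#   patch = gligen.set_position(latent.shape, position_params, device)
#   # inside each transformer block i of the UNet:
#   #   x = patch(x, {"transformer_index": i})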