File size: 5,719 Bytes
d3fc090
dbb647b
d3fc090
 
dbb647b
 
 
d3fc090
 
edbf5dc
 
 
 
 
 
 
 
 
 
 
 
c984c16
 
d3fc090
c984c16
d3fc090
edbf5dc
 
 
 
 
 
 
 
 
 
 
 
 
 
c984c16
d3fc090
 
 
 
 
 
 
39f30cb
d3fc090
39f30cb
c4ccf03
d3fc090
c4ccf03
d3fc090
 
6b63ec4
 
c4ccf03
 
d3fc090
c4ccf03
c984c16
 
 
 
d0f5598
c984c16
 
 
 
 
 
 
 
 
 
 
 
edbf5dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbb647b
 
c4ccf03
d3fc090
c4ccf03
dbb647b
c984c16
d3fc090
9014040
 
 
 
 
 
 
 
 
 
 
dbb647b
9014040
dbb647b
9014040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbb647b
39f30cb
dbb647b
c4ccf03
d3fc090
c4ccf03
dbb647b
 
 
 
d3fc090
dbb647b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import gradio as gr
import torch
from torch import nn
from einops import rearrange
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import requests
import os
import sys
import warnings

# Silence the timm deprecation warning seen on HF Spaces.
warnings.filterwarnings(
    "ignore",
    message="Importing from timm.models.layers is deprecated, please import via timm.layers",
    category=FutureWarning,
)

# Make sure the local `surya` package can be imported even when the current
# working directory differs from this file's directory.
sys.path.append(os.path.dirname(__file__))

# ================================
# 1. Download the Surya-1.0 weights
# ================================
MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"

# Prefer a local checkpoint if one exists. Paths are checked in order; the
# last entry is also the fallback path when none of them exists yet.
MODEL_CANDIDATES = [
    os.path.join(os.path.dirname(__file__), "surya_model.pt"),
    os.path.join(os.path.dirname(__file__), "surya.366m.v1.pt"),
]

def _pick_model_file():
    """Return the checkpoint path to use.

    Walks ``MODEL_CANDIDATES`` in order and returns the first path that
    exists on disk. When none exists, the last candidate is returned so the
    caller has a concrete path to work with.
    """
    existing = (path for path in MODEL_CANDIDATES if os.path.exists(path))
    return next(existing, MODEL_CANDIDATES[-1])

# Resolved once at import time; this path is both the download target and the
# file the weights are later loaded from.
MODEL_FILE = _pick_model_file()

def download_model():
    """Download the Surya-1.0 checkpoint to ``MODEL_FILE`` if it is missing.

    The response is streamed to a temporary ``<MODEL_FILE>.part`` file and
    only renamed to ``MODEL_FILE`` after the transfer finished, so an HTTP
    error page or an interrupted download never ends up on disk looking like
    a valid checkpoint (which would also suppress every later re-download,
    because the existence check would pass).

    Raises:
        requests.RequestException: on connection problems, timeouts, or a
            non-success HTTP status (via ``raise_for_status``).
        OSError: if the file cannot be written or renamed.
    """
    if os.path.exists(MODEL_FILE):
        return

    print("Baixando pesos do Surya-1.0...")
    partial_path = MODEL_FILE + ".part"
    try:
        # stream=True avoids holding the whole checkpoint in memory;
        # timeout=(connect, read) keeps a stalled connection from hanging
        # the app start-up forever.
        with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic on the same filesystem: MODEL_FILE is either absent or complete.
        os.replace(partial_path, MODEL_FILE)
    finally:
        # After a successful rename the partial file no longer exists; on any
        # failure this removes the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download concluído!")

# Fetch the checkpoint at import time so the file is on disk before the
# weights are loaded further down.
download_model()

# ================================
# 2. HelioSpectFormer model class
# ================================
# This section was originally a placeholder for pasting the full
# HelioSpectFormer class from the repository. The import below is used
# instead and works when the `surya` folder is available locally; otherwise
# replace it with the real class definition from the repo.
from surya.models.helio_spectformer import HelioSpectFormer

# ================================
# 3. Instantiate the model with default parameters
# ================================
# img_size=224 and patch_size=16 line up with the 224x224 resize and the
# 14x14 token grid used by infer_solar_image_heatmap.
# NOTE(review): these hyperparameters may not match the downloaded
# checkpoint; the weight loader drops shape-mismatched entries instead of
# failing — confirm against the published Surya-1.0 configuration.
model = HelioSpectFormer(
    img_size=224,
    patch_size=16,
    in_chans=1,
    embed_dim=368,
    time_embedding={"type": "linear", "time_dim": 1},
    depth=8,
    n_spectral_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    drop_rate=0.0,
    window_size=7,
    dp_rank=1,
    learned_flow=False,
    finetune=True
)

def _try_load_weights(m: nn.Module, path: str) -> None:
    """Best-effort load of a checkpoint into ``m``, tolerating mismatches.

    Only checkpoint entries whose key exists in ``m.state_dict()`` with an
    identical shape are loaded (``strict=False``). Every other entry is
    dropped and summarized on stdout. Any failure (missing file, unreadable
    checkpoint, ...) is printed and swallowed on purpose so the caller keeps
    running with whatever weights the model already has.

    Setting the environment variable ``NO_WEIGHTS`` to ``1``, ``true`` or
    ``yes`` (case-insensitive) skips loading entirely.

    Args:
        m: Module that receives the weights (modified in place).
        path: Path of the checkpoint file handed to ``torch.load``.
    """
    if os.environ.get("NO_WEIGHTS", "").lower() in {"1", "true", "yes"}:
        print("NO_WEIGHTS=1 -> pulando carregamento de pesos")
        return
    try:
        # NOTE(security): torch.load unpickles the file. The checkpoint comes
        # from a download, so consider weights_only=True once it is confirmed
        # that the file contains only tensors and the torch version supports it.
        raw_sd = torch.load(path, map_location=torch.device('cpu'))
        model_sd = m.state_dict()
        filtered = {}
        dropped = []
        for k, v in raw_sd.items():
            # Checkpoints may carry non-tensor entries; look the shape up
            # defensively so one such entry cannot abort the whole load.
            ckpt_shape = getattr(v, "shape", None)
            target = model_sd.get(k)
            if target is not None and ckpt_shape is not None and target.shape == ckpt_shape:
                filtered[k] = v
            else:
                dropped.append((
                    k,
                    tuple(ckpt_shape) if ckpt_shape is not None else None,
                    tuple(target.shape) if target is not None else None,
                ))
        missing, unexpected = m.load_state_dict(filtered, strict=False)
        print(f"Pesos carregados parcialmente. Ok={len(filtered)} Missing={len(missing)} Unexpected={len(unexpected)} Dropped={len(dropped)}")
        if dropped:
            print("Algumas chaves foram descartadas por mismatch (ex.:)", dropped[:5])
        if missing:
            print("Exemplos de missing:", missing[:10])
        if unexpected:
            print("Exemplos de unexpected:", unexpected[:10])
    except Exception as e:
        # Deliberate best-effort: report and continue without the checkpoint.
        print(f"Falha ao carregar pesos de {path}: {e}")

# Load whatever checkpoint entries fit the model, then put it in eval mode.
_try_load_weights(model, MODEL_FILE)
model.eval()

# ================================
# 4. Inference function for the heatmap
# ================================
def infer_solar_image_heatmap(img):
    """Render a per-patch activation heatmap on top of a solar image.

    The image is converted to grayscale, resized to 224x224, scaled to
    [0, 1] and passed through the module-level ``model``. The
    root-mean-square of every output token is arranged on a square grid,
    min-max normalized, upsampled by block repetition and drawn
    semi-transparently over the grayscale input.

    Args:
        img: PIL image to analyse. ``None`` (nothing uploaded) is accepted.

    Returns:
        A matplotlib ``Figure`` with the overlay, or ``None`` when ``img``
        is ``None``.
    """
    if img is None:
        # Submitting without an image would otherwise fail on img.convert.
        return None

    image_side = 224

    # Pre-process the image.
    img_gray = img.convert("L").resize((image_side, image_side))
    img_np = np.array(img_gray)
    ts_tensor = (
        torch.tensor(img_np, dtype=torch.float32)
        .unsqueeze(0)
        .unsqueeze(0)
        .unsqueeze(2)
        / 255.0
    )  # [B=1, C=1, T=1, H=224, W=224]
    batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}

    # Inference; presumably yields tokens [1, L, D] with finetune=True — TODO confirm.
    with torch.no_grad():
        tokens = model(batch).squeeze(0).cpu()  # [L, D]

    # Remove the static positional component so the map does not look the
    # same for every image. Best-effort: skipped when the model exposes no
    # tensor at embedding.pos_embed or when its shape does not match.
    pos = getattr(getattr(model, "embedding", None), "pos_embed", None)
    if torch.is_tensor(pos):
        # detach(): this runs outside no_grad, and a trainable parameter would
        # otherwise make the result require grad and break .numpy() below.
        pos = pos.detach().squeeze(0).to(tokens.dtype).cpu()  # [L, D]
        if pos.shape == tokens.shape:
            tokens = tokens - pos

    # Aggregate per-token energy (RMS over the feature axis) and lay the
    # tokens out on a square grid (14x14 for 224/16).
    n_tokens = tokens.shape[0]
    side = int(n_tokens ** 0.5)
    heat_vec = torch.sqrt((tokens ** 2).mean(dim=1))  # [L]
    heat = heat_vec.reshape(side, side).numpy()

    # Normalize to [0, 1] and upsample by block repetition (nearest neighbour).
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    scale = image_side // side
    heat224 = np.kron(heat, np.ones((scale, scale)))

    # Overlay on the pre-processed image.
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(img_np, cmap="gray")
    ax.imshow(heat224, cmap="inferno", alpha=0.5, vmin=0.0, vmax=1.0)
    ax.axis("off")
    fig.tight_layout()
    # Unregister the figure from pyplot's global registry so a long-running
    # server does not accumulate one open figure per request; the Figure
    # object itself can still be rendered by the caller.
    plt.close(fig)
    return fig

# ================================
# 5. Gradio interface
# ================================
# One PIL image in, one plot out, wired to infer_solar_image_heatmap.
interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=gr.Image(type="pil"),
    outputs=gr.Plot(label="Heatmap do embedding Surya"),
    title="Playground Surya-1.0 com Heatmap",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0"
)

# NOTE(review): launch() runs at import time (there is no
# `if __name__ == "__main__"` guard), so importing this module starts the app.
interface.launch()