import gradio as gr
import torch
from torch import nn
from einops import rearrange
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import requests
import os
import sys
import warnings
# Silence the timm deprecation warning seen on HF Spaces
warnings.filterwarnings(
    "ignore",
    message="Importing from timm.models.layers is deprecated, please import via timm.layers",
    category=FutureWarning,
)

# Make sure the local `surya` package can be imported even if the CWD differs
sys.path.append(os.path.dirname(__file__))
# ================================
# 1. Download the Surya-1.0 weights
# ================================
MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
# Prefer a local checkpoint if one exists
MODEL_CANDIDATES = [
    os.path.join(os.path.dirname(__file__), "surya_model.pt"),
    os.path.join(os.path.dirname(__file__), "surya.366m.v1.pt"),
]

def _pick_model_file():
    for p in MODEL_CANDIDATES:
        if os.path.exists(p):
            return p
    return MODEL_CANDIDATES[-1]
MODEL_FILE = _pick_model_file()
def download_model():
    if not os.path.exists(MODEL_FILE):
        print("Downloading Surya-1.0 weights...")
        # Stream to disk so the whole checkpoint is never held in memory
        r = requests.get(MODEL_URL, stream=True)
        r.raise_for_status()
        with open(MODEL_FILE, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
        print("Download complete!")

download_model()
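# Usage note: when running locally you can avoid the download by placing a
# checkpoint named surya_model.pt (or surya.366m.v1.pt) next to this file, or
# set NO_WEIGHTS=1 to start the demo with randomly initialized weights (handy
# for smoke-testing the UI without the full checkpoint).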
# ================================
# 2. Import the HelioSpectFormer class
# ================================
# The model class comes from the local `surya` package (the repo folder must
# be present next to this file; see the sys.path.append above)
from surya.models.helio_spectformer import HelioSpectFormer
# ================================
# 3. Instantiate the model with default parameters
# ================================
model = HelioSpectFormer(
    img_size=224,
    patch_size=16,
    in_chans=1,
    embed_dim=368,
    time_embedding={"type": "linear", "time_dim": 1},
    depth=8,
    n_spectral_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    drop_rate=0.0,
    window_size=7,
    dp_rank=1,
    learned_flow=False,
    finetune=True,
)
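# With img_size=224 and patch_size=16 the encoder works on a 14x14 patch grid,
# i.e. 196 tokens of dimension embed_dim=368; the heatmap code below relies on
# that 14x14 layout when it reshapes the per-token energies.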
# Load the weights resiliently (strict=False) and log any differences
def _try_load_weights(m: nn.Module, path: str) -> None:
    if os.environ.get("NO_WEIGHTS", "").lower() in {"1", "true", "yes"}:
        print("NO_WEIGHTS=1 -> skipping weight loading")
        return
    try:
        raw_sd = torch.load(path, map_location=torch.device("cpu"))
        model_sd = m.state_dict()
        filtered = {}
        dropped = []
        for k, v in raw_sd.items():
            if k in model_sd and model_sd[k].shape == v.shape:
                filtered[k] = v
            else:
                dropped.append(
                    (
                        k,
                        tuple(v.shape) if hasattr(v, "shape") else None,
                        tuple(model_sd[k].shape) if k in model_sd else None,
                    )
                )
        missing, unexpected = m.load_state_dict(filtered, strict=False)
        print(
            f"Weights partially loaded. Ok={len(filtered)} Missing={len(missing)} "
            f"Unexpected={len(unexpected)} Dropped={len(dropped)}"
        )
        if dropped:
            print("Some keys were dropped due to shape mismatch (e.g.):", dropped[:5])
        if missing:
            print("Examples of missing keys:", missing[:10])
        if unexpected:
            print("Examples of unexpected keys:", unexpected[:10])
    except Exception as e:
        print(f"Failed to load weights from {path}: {e}")
_try_load_weights(model, MODEL_FILE)
model.eval()
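# eval() disables dropout, so the token activations (and the heatmap derived
# from them) are deterministic for a given input image.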
# ================================
# 4. Inference function producing the heatmap
# ================================
def infer_solar_image_heatmap(img):
    # Image preprocessing: grayscale, resized to the model's input resolution
    img_gray = img.convert("L").resize((224, 224))
    img_np = np.array(img_gray)
    ts_tensor = (
        torch.tensor(img_np, dtype=torch.float32)
        .unsqueeze(0)
        .unsqueeze(0)
        .unsqueeze(2)
        / 255.0
    )  # [B=1, C=1, T=1, H=224, W=224]
    batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}
    # Inference (returns tokens [1, L, D] when finetune=True)
    with torch.no_grad():
        tokens = model(batch).squeeze(0).cpu()  # [L, D]
    # Subtract the static positional component so the map does not look
    # identical across different input images
    try:
        pos = model.embedding.pos_embed.squeeze(0).to(tokens.dtype).cpu()  # [L, D]
        if pos.shape == tokens.shape:
            tokens = tokens - pos
    except Exception:
        pass
    # Aggregate per-patch energy (RMS over the embedding dim) and reshape to 14x14
    L, D = tokens.shape
    side = int(L ** 0.5)  # 14 for 224/16
    heat_vec = torch.sqrt((tokens**2).mean(dim=1))  # [L]
    heat = heat_vec.reshape(side, side).numpy()
    # Normalize and upsample to 224x224 (np.kron replicates each patch value
    # into a block, i.e. nearest-neighbor upsampling)
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
    heat224 = np.kron(heat, np.ones((224 // side, 224 // side)))
    # Overlay on the original image
    plt.figure(figsize=(5, 5))
    plt.imshow(img_np, cmap="gray")
    plt.imshow(heat224, cmap="inferno", alpha=0.5, vmin=0.0, vmax=1.0)
    plt.axis("off")
    plt.tight_layout()
    return plt.gcf()
# ================================
# 5. Gradio interface
# ================================
interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=gr.Image(type="pil"),
    outputs=gr.Plot(label="Surya embedding heatmap"),
    title="Surya-1.0 Playground with Heatmap",
    description="Upload a solar image → visualize the heatmap produced by Surya-1.0",
)
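# On Hugging Face Spaces the default launch() is sufficient; when running the
# script locally, interface.launch(share=True) would additionally expose a
# temporary public URL.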
interface.launch()