csaybar committed on
Commit d19ad5c · verified · 1 Parent(s): 28f6513

Upload 5 files

.gitattributes CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 NonReference_RGBN_x4/example_data.safetensor filter=lfs diff=lfs merge=lfs -text
 NonReference_RGBN_x4/hard_constraint.safetensor filter=lfs diff=lfs merge=lfs -text
 NonReference_RGBN_x4/model.safetensor filter=lfs diff=lfs merge=lfs -text
+Reference_RSWIR_x2/example_data.safetensor filter=lfs diff=lfs merge=lfs -text
+Reference_RSWIR_x2/model.safetensor filter=lfs diff=lfs merge=lfs -text
Reference_RSWIR_x2/example_data.safetensor ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34d594d36a1220c526a55ce086af5f3f6edfd73f4dd966464bfc8b08220854ac
+size 655440
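The three lines above are a Git LFS pointer: the actual tensor file lives in LFS storage and is identified by its SHA-256 and byte size. A minimal sketch, assuming the resolved example_data.safetensor has already been downloaded locally, to check a copy against those values:

# Sketch only: verify a downloaded copy against the LFS pointer above.
import hashlib
import pathlib

f = pathlib.Path("example_data.safetensor")  # assumed local download
print(f.stat().st_size)                                    # expected: 655440
print(hashlib.sha256(f.read_bytes()).hexdigest())          # expected: 34d594d36a1220c526a55ce086af5f3f6edfd73f4dd966464bfc8b08220854ac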
Reference_RSWIR_x2/hard_constraint.safetensor ADDED
Binary file (65.6 kB).
 
Reference_RSWIR_x2/load.py ADDED
@@ -0,0 +1,127 @@
+import pathlib
+import safetensors.torch
+import matplotlib.pyplot as plt
+from sen2sr.models.opensr_baseline.swin import Swin2SR
+from sen2sr.models.tricks import HardConstraint
+from sen2sr.referencex2 import srmodel
+
+# MLSTAC API -----------------------------------------------------------------------
+def example_data(path: pathlib.Path, *args, **kwargs):
+    data_f = path / "example_data.safetensor"
+    sample = safetensors.torch.load_file(data_f)
+    return sample["lr"]
+
+def trainable_model(path, device: str = "cpu", *args, **kwargs):
+    trainable_f = path / "model.safetensor"
+
+    # Load model parameters
+    sr_model_weights = safetensors.torch.load_file(trainable_f)
+    params = {
+        "img_size": (64, 64),
+        "in_channels": 10,
+        "out_channels": 6,
+        "embed_dim": 192,
+        "depths": [8] * 8,
+        "num_heads": [8] * 8,
+        "window_size": 4,
+        "mlp_ratio": 4.0,
+        "upscale": 1,
+        "resi_connection": "1conv",
+        "upsampler": "pixelshuffle",
+    }
+    sr_model = Swin2SR(**params)
+    sr_model.load_state_dict(sr_model_weights)
+    sr_model.to(device)
+
+    # Load HardConstraint
+    hard_constraint_weights = safetensors.torch.load_file(path / "hard_constraint.safetensor")
+    hard_constraint = HardConstraint(
+        low_pass_mask=hard_constraint_weights["weights"].to(device),
+        bands=[0, 1, 2, 3, 4, 5],
+        device=device
+    )
+
+    return srmodel(sr_model=sr_model, hard_constraint=hard_constraint, device=device)
+
+
+def compiled_model(path, device: str = "cpu", *args, **kwargs):
+    trainable_f = path / "model.safetensor"
+
+    # Load model parameters
+    sr_model_weights = safetensors.torch.load_file(trainable_f)
+    params = {
+        "img_size": (64, 64),
+        "in_channels": 10,
+        "out_channels": 6,
+        "embed_dim": 192,
+        "depths": [8] * 8,
+        "num_heads": [8] * 8,
+        "window_size": 4,
+        "mlp_ratio": 4.0,
+        "upscale": 1,
+        "resi_connection": "1conv",
+        "upsampler": "pixelshuffle",
+    }
+    sr_model = Swin2SR(**params)
+    sr_model.load_state_dict(sr_model_weights)
+    sr_model = sr_model.to(device)
+    sr_model = sr_model.eval()
+    for param in sr_model.parameters():
+        param.requires_grad = False
+
+    # Load HardConstraint
+    hard_constraint_weights = safetensors.torch.load_file(path / "hard_constraint.safetensor")
+    hard_constraint = HardConstraint(
+        low_pass_mask=hard_constraint_weights["weights"].to(device),
+        bands=[0, 1, 2, 3, 4, 5],
+        device=device
+    )
+    hard_constraint = hard_constraint.eval()
+    for param in hard_constraint.parameters():
+        param.requires_grad = False
+
+    return srmodel(sr_model=sr_model, hard_constraint=hard_constraint, device=device)
+
+
+def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
+    # Load model
+    model = compiled_model(path, device)
+
+    # Load data
+    lr = example_data(path)
+
+    # Run model
+    sr = model(lr.to(device))
+
+    # Build RGB, SWIR, and red-edge composites for display
+    lr_rgb = lr[0, [2, 1, 0]].cpu().numpy().transpose(1, 2, 0)
+    sr_rgb = sr[0, [2, 1, 0]].cpu().numpy().transpose(1, 2, 0)
+
+    lr_swirs = lr[0, [9, 8, 7]].cpu().numpy().transpose(1, 2, 0)
+    sr_swirs = sr[0, [9, 8, 7]].cpu().numpy().transpose(1, 2, 0)
+
+    lr_reds = lr[0, [6, 5, 4]].cpu().numpy().transpose(1, 2, 0)
+    sr_reds = sr[0, [6, 5, 4]].cpu().numpy().transpose(1, 2, 0)
+
+
+    # Display results (the output grid matches the input grid since upscale=1, so the same window is used for LR and SR)
+    lr_slice = slice(16, 32 + 80)
+    hr_slice = slice(lr_slice.start * 1, lr_slice.stop * 1)
+    fig, ax = plt.subplots(3, 2, figsize=(8, 12))
+    ax = ax.flatten()
+    ax[0].imshow(lr_rgb[lr_slice] * 2)
+    ax[0].set_title("LR RGB")
+    ax[1].imshow(sr_rgb[hr_slice] * 2)
+    ax[1].set_title("SR RGB")
+    ax[2].imshow(lr_swirs[lr_slice] * 2)
+    ax[2].set_title("LR SWIR")
+    ax[3].imshow(sr_swirs[hr_slice] * 2)
+    ax[3].set_title("SR SWIR")
+    ax[4].imshow(lr_reds[lr_slice] * 2)
+    ax[4].set_title("LR RED")
+    ax[5].imshow(sr_reds[hr_slice] * 2)
+    ax[5].set_title("SR RED")
+    for a in ax:
+        a.axis("off")
+    fig.tight_layout()
+    return fig
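For context, a minimal usage sketch of the loaders above. It assumes the four Reference_RSWIR_x2 files have been downloaded into a local folder next to this script; the folder name and output filename are illustrative, not part of the repository.

# Sketch only: exercise the MLSTAC-style API defined in load.py.
import pathlib
import torch
from load import example_data, compiled_model, display_results  # assumes load.py is importable

model_dir = pathlib.Path("Reference_RSWIR_x2")   # hypothetical local download folder

lr = example_data(model_dir)                     # batched 10-band Sentinel-2 tensor
model = compiled_model(model_dir, device="cpu")  # frozen, eval-mode model
with torch.no_grad():
    sr = model(lr)                               # super-resolved output tensor

fig = display_results(model_dir, device="cpu")
fig.savefig("sen2sr_rswir_x2_preview.png")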
Reference_RSWIR_x2/mlm.json ADDED
@@ -0,0 +1,187 @@
+{
+    "type": "Feature",
+    "stac_version": "1.1.0",
+    "stac_extensions": [
+        "https://stac-extensions.github.io/mlm/v1.4.0/schema.json"
+    ],
+    "id": "SPAN_model",
+    "geometry": {
+        "type": "Polygon",
+        "coordinates": [
+            [
+                [
+                    -180.0,
+                    -90.0
+                ],
+                [
+                    -180.0,
+                    90.0
+                ],
+                [
+                    180.0,
+                    90.0
+                ],
+                [
+                    180.0,
+                    -90.0
+                ],
+                [
+                    -180.0,
+                    -90.0
+                ]
+            ]
+        ]
+    },
+    "bbox": [
+        -180,
+        -90,
+        180,
+        90
+    ],
+    "properties": {
+        "start_datetime": "1900-01-01T00:00:00Z",
+        "end_datetime": "9999-01-01T00:00:00Z",
+        "description": "A Swift Parameter-free Attention Network (SPAN) trained on the SEN2NAIPv2 dataset to enhance the RSWIR Sentinel-2 bands, improving spatial resolution from 20 meters to 10 meters.",
+        "forward_backward_pass": {
+            "32": 24.957696,
+            "64": 97.858304,
+            "128": 387.576576,
+            "256": 1542.681344,
+            "512": 6155.563776
+        },
+        "dependencies": [
+            "torch",
+            "safetensors.torch",
+            "sen2sr"
+        ],
+        "mlm:framework": "pytorch",
+        "mlm:framework_version": "2.1.2+cu121",
+        "file:size": 1889984,
+        "mlm:memory_size": 1,
+        "mlm:accelerator": "cuda",
+        "mlm:accelerator_constrained": false,
+        "mlm:accelerator_summary": "Unknown",
+        "mlm:name": "CNN_Light_SR",
+        "mlm:architecture": "SPAN",
+        "mlm:tasks": [
+            "super-resolution"
+        ],
+        "mlm:input": [
+            {
+                "name": "10 Band Sentinel-2",
+                "bands": [
+                    "B04",
+                    "B03",
+                    "B02",
+                    "B08",
+                    "B05",
+                    "B06",
+                    "B07",
+                    "B8A",
+                    "B11",
+                    "B12"
+                ],
+                "input": {
+                    "shape": [
+                        -1,
+                        10,
+                        128,
+                        128
+                    ],
+                    "dim_order": [
+                        "batch",
+                        "channel",
+                        "height",
+                        "width"
+                    ],
+                    "data_type": "float16"
+                },
+                "pre_processing_function": null
+            }
+        ],
+        "mlm:output": [
+            {
+                "name": "super-resolution",
+                "bands": [
+                    "B05",
+                    "B06",
+                    "B07",
+                    "B8A",
+                    "B11",
+                    "B12"
+                ],
+                "tasks": [
+                    "super-resolution"
+                ],
+                "result": {
+                    "shape": [
+                        -1,
+                        6,
+                        128,
+                        128
+                    ],
+                    "dim_order": [
+                        "batch",
+                        "channel",
+                        "height",
+                        "width"
+                    ],
+                    "data_type": "float16"
+                },
+                "classification:classes": [],
+                "post_processing_function": null
+            }
+        ],
+        "mlm:total_parameters": 472496,
+        "mlm:pretrained": true,
+        "datetime": null
+    },
+    "links": [],
+    "assets": {
+        "trainable": {
+            "href": "https://huggingface.co/tacofoundation/SEN2SR/resolve/main/SEN2SR/Reference_RSWIR_x2/model.safetensor",
+            "type": "application/octet-stream; application=safetensor",
+            "title": "Pytorch model weights checkpoint",
+            "description": "The weights of the model in safetensor format.",
+            "mlm:artifact_type": "safetensor.torch.save_file",
+            "roles": [
+                "mlm:model",
+                "mlm:weights",
+                "data"
+            ]
+        },
+        "hardconstraint": {
+            "href": "https://huggingface.co/tacofoundation/SEN2SR/resolve/main/SEN2SR/Reference_RSWIR_x2/hard_constraint.safetensor",
+            "type": "application/octet-stream; application=safetensor",
+            "title": "Pytorch hard constraint weights checkpoint, used to load the hard constraint module faster.",
+            "description": "The weights of the hard constraint module in safetensor format.",
+            "mlm:artifact_type": "safetensor.torch.save_file",
+            "roles": [
+                "mlm:model",
+                "mlm:weights",
+                "data"
+            ]
+        },
+        "source_code": {
+            "href": "https://huggingface.co/tacofoundation/SEN2SR/resolve/main/SEN2SR/Reference_RSWIR_x2/load.py",
+            "type": "text/x-python",
+            "title": "Model load script",
+            "description": "Python script to load the model.",
+            "roles": [
+                "mlm:source_code",
+                "code"
+            ]
+        },
+        "example_data": {
+            "href": "https://huggingface.co/tacofoundation/SEN2SR/resolve/main/SEN2SR/Reference_RSWIR_x2/example_data.safetensor",
+            "type": "application/octet-stream; application=safetensors",
+            "title": "Example Sentinel-2 image",
+            "description": "Example Sentinel-2 image for model inference.",
+            "roles": [
+                "mlm:example_data",
+                "data"
+            ]
+        }
+    },
+    "collection": "ml-model"
+}
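The asset hrefs above point at the raw files on the Hub. A minimal sketch, assuming plain HTTP downloads are acceptable, that fetches them into the local layout load.py expects; the output folder name is illustrative:

# Sketch only: download the assets listed in mlm.json next to each other.
import pathlib
import urllib.request

base = "https://huggingface.co/tacofoundation/SEN2SR/resolve/main/SEN2SR/Reference_RSWIR_x2"
out = pathlib.Path("Reference_RSWIR_x2")
out.mkdir(exist_ok=True)

for name in ("model.safetensor", "hard_constraint.safetensor",
             "example_data.safetensor", "load.py"):
    urllib.request.urlretrieve(f"{base}/{name}", str(out / name))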
Reference_RSWIR_x2/model.safetensor ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d1b93248ca0bb24cc20a6c2e83b2acd70517357fa322f9051af8f378d54eecc
+size 137790600