linked-liszt committed
Commit 8a10305 · verified · 1 Parent(s): e4000bb

Upload folder using huggingface_hub

Files changed (7):
  1. LICENSE +13 -0
  2. README.md +155 -0
  3. config.json +55 -0
  4. example_inference.py +44 -0
  5. maxsub.json +232 -0
  6. model.py +408 -0
  7. model.safetensors +3 -0
LICENSE ADDED
@@ -0,0 +1,13 @@
+ BSD 3-Clause License
+
+ Copyright 2026 UChicago Argonne, LLC. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md ADDED
@@ -0,0 +1,155 @@
+ ---
+ license: bsd-3-clause
+ language:
+ - en
+ tags:
+ - pytorch
+ - materials-science
+ - crystallography
+ - x-ray-diffraction
+ - pxrd
+ - convnext
+ - arxiv:2603.23367
+ datasets:
+ - materials-project
+ metrics:
+ - accuracy
+ - mae
+ pipeline_tag: other
+ ---
+
+ # AlphaDiffract — Open Weights
+
+ [arXiv](https://arxiv.org/abs/2603.23367) | [GitHub](https://github.com/AdvancedPhotonSource/AlphaDiffract)
+
+ **Automated crystallographic analysis of powder X-ray diffraction data.**
+
+ AlphaDiffract is a multi-task 1D ConvNeXt model that takes a powder X-ray diffraction (PXRD) pattern and simultaneously predicts:
+
+ | Output | Description |
+ |---|---|
+ | **Crystal system** | 7-class classification (Triclinic → Cubic) |
+ | **Space group** | 230-class classification |
+ | **Lattice parameters** | 6 values: a, b, c (Å), α, β, γ (°) |
+
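For reference, space-group numbers map onto the seven crystal systems by fixed ranges in the International Tables convention. A small helper like the following (illustrative only, not part of this release; the model predicts both heads independently) can relate the two outputs:

```python
# Standard space-group-number ranges for the seven crystal systems
# (International Tables convention). Illustrative helper, not part of
# the AlphaDiffract release.
CRYSTAL_SYSTEM_RANGES = [
    ("Triclinic", 1, 2),
    ("Monoclinic", 3, 15),
    ("Orthorhombic", 16, 74),
    ("Tetragonal", 75, 142),
    ("Trigonal", 143, 167),
    ("Hexagonal", 168, 194),
    ("Cubic", 195, 230),
]


def crystal_system_of(sg_number: int) -> str:
    """Map a space-group number (1-230) to its crystal system."""
    for name, lo, hi in CRYSTAL_SYSTEM_RANGES:
        if lo <= sg_number <= hi:
            return name
    raise ValueError(f"space group number out of range: {sg_number}")
```

This is useful for checking whether the two classification heads agree on a given prediction.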
+ This release contains a **single model** trained exclusively on
+ [Materials Project](https://next-gen.materialsproject.org/) structures
+ (publicly available data). It is *not* the 10-model ensemble reported in
+ the paper — see [Performance](#performance) for details.
+
+ ## Quick Start
+
+ ```bash
+ pip install torch safetensors numpy
+ ```
+
+ ```python
+ from model import AlphaDiffract
+ import torch, numpy as np
+
+ model = AlphaDiffract.from_pretrained(".", device="cpu")
+
+ # 8192-point intensity pattern, normalized to [0, 100]
+ pattern = np.load("my_pattern.npy").astype(np.float32)
+ pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-10) * 100.0
+ x = torch.from_numpy(pattern).unsqueeze(0)
+
+ with torch.no_grad():
+     out = model(x)
+
+ cs_probs = torch.softmax(out["cs_logits"], dim=-1)
+ sg_probs = torch.softmax(out["sg_logits"], dim=-1)
+ lp = out["lp"]  # [a, b, c, alpha, beta, gamma]
+
+ print("Crystal system:", AlphaDiffract.CRYSTAL_SYSTEMS[cs_probs.argmax().item()])
+ print("Space group: #", sg_probs.argmax().item() + 1)
+ print("Lattice params:", lp[0].tolist())
+ ```
+
+ See `example_inference.py` for a complete runnable example.
+
+ ## Files
+
+ | File | Description |
+ |---|---|
+ | `model.safetensors` | Model weights (safetensors format, ~35 MB) |
+ | `model.py` | Standalone model definition (pure PyTorch, no Lightning) |
+ | `config.json` | Architecture and training hyperparameters |
+ | `maxsub.json` | Space-group subgroup graph (230×230, used as a registered buffer) |
+ | `example_inference.py` | End-to-end inference example |
+ | `LICENSE` | BSD 3-Clause |
+
+ ## Input Format
+
+ - **Length:** 8192 equally spaced intensity values
+ - **2θ range:** 5–20° (monochromatic, 20 keV)
+ - **Preprocessing:** floor negatives at zero, then rescale to [0, 100]
+ - **Shape:** `(batch, 8192)` or `(batch, 1, 8192)`
+
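The preprocessing bullets above can be sketched as follows (a minimal example; resampling the measured pattern onto the 8192-point grid is assumed to happen upstream):

```python
import numpy as np


def preprocess(pattern) -> np.ndarray:
    """Floor negative intensities at zero, then rescale to [0, 100].

    Assumes `pattern` is already an 8192-point array covering the
    5-20 degree 2-theta range.
    """
    p = np.asarray(pattern, dtype=np.float32)
    p = np.clip(p, 0.0, None)  # floor negatives at zero
    p = (p - p.min()) / (p.max() - p.min() + 1e-10) * 100.0
    return p
```

The `1e-10` guard mirrors the Quick Start snippet and avoids division by zero for a flat pattern.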
+ ## Architecture
+
+ 1D ConvNeXt backbone adapted from [Liu et al. (2022)](https://arxiv.org/abs/2201.03545):
+
+ ```
+ Input (8192) → [ConvNeXt Block × 3 with AvgPool] → Flatten (560-d)
+   ├─ CS head: MLP 560→2300→1150→7   (crystal system)
+   ├─ SG head: MLP 560→2300→1150→230 (space group)
+   └─ LP head: MLP 560→512→256→6    (lattice parameters, sigmoid-bounded)
+ ```
+
+ - **Parameters:** 8,734,989
+ - **Activation:** GELU
+ - **Stochastic depth:** 0.3
+ - **Head dropout:** 0.5
+
+ ## Performance
107
+
108
+ This is a **single model** trained on Materials Project data only (no ICSD).
109
+ Metrics on the best validation checkpoint (epoch 11):
110
+
111
+ | Metric | Simulated Val | RRUFF (experimental) |
112
+ |---|---|---|
113
+ | Crystal system accuracy | 74.88% | 60.35% |
114
+ | Space group accuracy | 57.31% | 38.28% |
115
+ | Lattice parameter MAE | 2.71 | — |
116
+
117
+ The paper reports higher numbers from a 10-model ensemble trained on
118
+ Materials Project + ICSD combined data. This open-weights release covers
119
+ only publicly available training data.
120
+
121
+ ## Training Details
122
+
123
+ | | |
124
+ |---|---|
125
+ | **Data** | ~146k Materials Project structures, 100 GSAS-II simulations each |
126
+ | **Augmentation** | Poisson + Gaussian noise, rescaled to [0, 100] |
127
+ | **Optimizer** | AdamW (lr=2e-4, weight_decay=0.01) |
128
+ | **Scheduler** | CyclicLR (triangular2, 6-epoch half-cycles) |
129
+ | **Loss** | CE (crystal system) + CE + GEMD (space group) + MSE (lattice params) |
130
+ | **Hardware** | 1× NVIDIA H100, float32 |
131
+ | **Batch size** | 64 |
132
+
133
+ ## Citation
134
+
135
+ ```bibtex
136
+ @article{andrejevic2026alphadiffract,
137
+ title = {AlphaDiffract: Automated Crystallographic Analysis of Powder X-ray Diffraction Data},
138
+ author = {Andrejevic, Nina and Du, Ming and Sharma, Hemant and Horwath, James P. and Luo, Aileen and Yin, Xiangyu and Prince, Michael and Toby, Brian H. and Cherukara, Mathew J.},
139
+ year = {2026},
140
+ eprint = {2603.23367},
141
+ archivePrefix = {arXiv},
142
+ primaryClass = {cond-mat.mtrl-sci},
143
+ doi = {10.48550/arXiv.2603.23367},
144
+ url = {https://arxiv.org/abs/2603.23367}
145
+ }
146
+ ```
147
+
148
+ ## License
149
+
150
+ BSD 3-Clause — Copyright 2026 UChicago Argonne, LLC.
151
+
152
+ ## Links
153
+
154
+ - [arXiv: 2603.23367](https://arxiv.org/abs/2603.23367)
155
+ - [GitHub: OpenAlphaDiffract](https://github.com/AdvancedPhotonSource/AlphaDiffract)
config.json ADDED
@@ -0,0 +1,55 @@
+ {
+   "model_type": "alphadiffract",
+   "architecture": "ConvNeXt1D-MultiTask",
+   "backbone": {
+     "dim_in": 8192,
+     "channels": [80, 80, 80],
+     "kernel_sizes": [100, 50, 25],
+     "strides": [5, 5, 5],
+     "dropout_rate": 0.3,
+     "ramped_dropout_rate": false,
+     "block_type": "convnext",
+     "pooling_type": "average",
+     "final_pool": true,
+     "use_batchnorm": false,
+     "activation": "gelu",
+     "output_type": "flatten",
+     "layer_scale_init_value": 0.0,
+     "drop_path_rate": 0.3
+   },
+   "heads": {
+     "head_dropout": 0.5,
+     "cs_hidden": [2300, 1150],
+     "sg_hidden": [2300, 1150],
+     "lp_hidden": [512, 256]
+   },
+   "tasks": {
+     "num_cs_classes": 7,
+     "num_sg_classes": 230,
+     "num_lp_outputs": 6,
+     "lp_bounds_min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+     "lp_bounds_max": [500.0, 500.0, 500.0, 180.0, 180.0, 180.0],
+     "bound_lp_with_sigmoid": true
+   },
+   "training": {
+     "optimizer": "AdamW",
+     "lr": 0.0002,
+     "weight_decay": 0.01,
+     "scheduler": "CyclicLR",
+     "scheduler_mode": "triangular2",
+     "batch_size": 64,
+     "max_epochs": 100,
+     "precision": "float32",
+     "gemd_mu": 1.0,
+     "lambda_cs": 1.0,
+     "lambda_sg": 1.0,
+     "lambda_lp": 1.0
+   },
+   "preprocessing": {
+     "input_length": 8192,
+     "floor_at_zero": true,
+     "normalize_range": [0.0, 100.0],
+     "noise_poisson_range": [1.0, 100.0],
+     "noise_gaussian_range": [0.001, 0.1]
+   }
+ }
example_inference.py ADDED
@@ -0,0 +1,44 @@
+ """
+ Example: load AlphaDiffract and run inference on a PXRD pattern.
+
+ Requirements:
+     pip install torch safetensors numpy
+ """
+
+ import numpy as np
+ import torch
+ from model import AlphaDiffract
+
+ # 1. Load model ---------------------------------------------------------------
+ model = AlphaDiffract.from_pretrained(".", device="cpu")  # or "cuda"
+
+ # 2. Prepare input ------------------------------------------------------------
+ # The model expects an 8192-point PXRD intensity pattern normalized to [0, 100].
+ # Replace this with your own data.
+ pattern = np.random.rand(8192).astype(np.float32)  # placeholder
+
+ # Normalize to [0, 100]
+ pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-10) * 100.0
+ x = torch.from_numpy(pattern).unsqueeze(0)  # shape: (1, 8192)
+
+ # 3. Inference ----------------------------------------------------------------
+ with torch.no_grad():
+     out = model(x)
+
+ cs_probs = torch.softmax(out["cs_logits"], dim=-1)
+ sg_probs = torch.softmax(out["sg_logits"], dim=-1)
+ lp = out["lp"]
+
+ # 4. Results ------------------------------------------------------------------
+ cs_idx = cs_probs.argmax(dim=-1).item()
+ sg_idx = sg_probs.argmax(dim=-1).item()
+
+ print(f"Crystal system : {AlphaDiffract.CRYSTAL_SYSTEMS[cs_idx]} "
+       f"({cs_probs[0, cs_idx]:.1%})")
+ print(f"Space group    : #{sg_idx + 1} ({sg_probs[0, sg_idx]:.1%})")
+
+ labels = ["a", "b", "c", "alpha", "beta", "gamma"]
+ units = ["A", "A", "A", "deg", "deg", "deg"]
+ print("Lattice params :")
+ for name, val, unit in zip(labels, lp[0].tolist(), units):
+     print(f"    {name:>5s} = {val:8.3f} {unit}")
maxsub.json ADDED
@@ -0,0 +1,232 @@
+ {
+ "1": [1],
+ "2": [1,2],
+ "3": [1,3,4,5],
+ "4": [1,4],
+ "5": [1,3,4,5],
+ "6": [1,6,7,8],
+ "7": [1,7,9],
+ "8": [1,6,7,8,9],
+ "9": [1,7,9],
+ "10": [2,3,6,10,11,12,13],
+ "11": [2,4,6,11,14],
+ "12": [2,5,8,10,11,12,13,14,15],
+ "13": [2,3,7,13,14,15],
+ "14": [2,4,7,14],
+ "15": [2,5,9,13,14,15],
+ "16": [3,16,17,21,22],
+ "17": [3,4,17,18,20],
+ "18": [3,4,18,19],
+ "19": [4,19],
+ "20": [4,5,17,18,19,20],
+ "21": [3,5,16,17,18,20,21,23,24],
+ "22": [5,20,21,22],
+ "23": [5,16,18,23],
+ "24": [5,17,19,24],
+ "25": [3,6,25,26,27,28,35,38,39,42],
+ "26": [4,6,7,26,29,31,36],
+ "27": [3,7,27,30,37],
+ "28": [3,6,7,28,29,30,31,32,40,41],
+ "29": [4,7,29,33],
+ "30": [3,7,30,34],
+ "31": [4,6,7,31,33],
+ "32": [3,7,32,33,34],
+ "33": [4,7,33],
+ "34": [3,7,34,43],
+ "35": [3,8,25,28,32,35,36,37,44,45,46],
+ "36": [4,8,9,26,29,31,33,36],
+ "37": [3,9,27,30,34,37],
+ "38": [5,6,8,25,26,30,31,38,40,44,46],
+ "39": [5,7,8,26,27,28,29,39,41,45,46],
+ "40": [5,6,9,28,31,33,34,40],
+ "41": [5,7,9,29,30,32,33,41],
+ "42": [5,8,35,36,37,38,39,40,41,42],
+ "43": [5,9,43],
+ "44": [5,8,25,31,34,44],
+ "45": [5,9,27,29,32,45],
+ "46": [5,8,9,26,28,30,33,46],
+ "47": [10,16,25,47,49,51,65,67,69],
+ "48": [13,16,34,48,70],
+ "49": [10,13,16,27,28,49,50,53,54,66,68],
+ "50": [13,16,30,32,48,50,52],
+ "51": [10,11,13,17,25,26,28,51,53,54,55,57,59,63,64],
+ "52": [13,14,17,30,33,34,52],
+ "53": [10,13,14,17,28,30,31,52,53,58,60],
+ "54": [13,14,17,27,29,32,52,54,56,60],
+ "55": [10,14,18,26,32,55,58,62],
+ "56": [13,14,18,27,33,56],
+ "57": [11,13,14,18,26,28,29,57,60,61,62],
+ "58": [10,14,18,31,34,58],
+ "59": [11,13,18,25,31,56,59,62],
+ "60": [13,14,18,29,30,33,60],
+ "61": [14,19,29,61],
+ "62": [11,14,19,26,31,33,62],
+ "63": [11,12,15,20,36,38,40,51,52,57,58,59,60,62,63],
+ "64": [12,14,15,20,36,39,41,53,54,55,56,57,60,61,62,64],
+ "65": [10,12,21,35,38,47,50,51,53,55,59,63,65,66,71,72,74],
+ "66": [10,15,21,37,40,48,49,52,53,56,58,66],
+ "67": [12,13,21,35,39,49,51,54,57,64,67,68,72,73,74],
+ "68": [13,15,21,37,41,50,52,54,60,68],
+ "69": [12,22,42,63,64,65,66,67,68,69],
+ "70": [15,22,43,70],
+ "71": [12,23,44,47,48,58,59,71],
+ "72": [12,15,23,45,46,49,50,55,56,57,60,72],
+ "73": [15,24,45,54,61,73],
+ "74": [12,15,24,44,46,51,52,53,62,74],
+ "75": [3,75,77,79],
+ "76": [4,76,78],
+ "77": [3,76,77,78,80],
+ "78": [4,76,78],
+ "79": [5,75,77,79],
+ "80": [5,76,78,80],
+ "81": [3,81,82],
+ "82": [5,81,82],
+ "83": [10,75,81,83,84,85,87],
+ "84": [10,77,81,84,86],
+ "85": [13,75,81,85,86],
+ "86": [13,77,81,86,88],
+ "87": [12,79,82,83,84,85,86,87],
+ "88": [15,80,82,88],
+ "89": [16,21,75,89,90,93,97],
+ "90": [18,21,75,90,94],
+ "91": [17,20,76,91,92,95],
+ "92": [19,20,76,92,96],
+ "93": [16,21,77,91,93,94,95,98],
+ "94": [18,21,77,92,94,96],
+ "95": [17,20,78,91,95,96],
+ "96": [19,20,78,92,96],
+ "97": [22,23,79,89,90,93,94,97],
+ "98": [22,24,80,91,92,95,96,98],
+ "99": [25,35,75,99,100,101,103,105,107,108],
+ "100": [32,35,75,100,102,104,106],
+ "101": [27,35,77,101,105,106],
+ "102": [34,35,77,102,109,110],
+ "103": [27,37,75,103,104],
+ "104": [34,37,75,104],
+ "105": [25,37,77,101,102,105],
+ "106": [32,37,77,106],
+ "107": [42,44,79,99,102,104,105,107],
+ "108": [42,45,79,100,101,103,106,108],
+ "109": [43,44,80,109],
+ "110": [43,45,80,110],
+ "111": [16,35,81,111,112,115,117,119,120],
+ "112": [16,37,81,112,116,118],
+ "113": [18,35,81,113,114],
+ "114": [18,37,81,114],
+ "115": [21,25,81,111,113,115,116,121],
+ "116": [21,27,81,112,114,116],
+ "117": [21,32,81,117,118],
+ "118": [21,34,81,118,122],
+ "119": [22,44,82,115,118,119],
+ "120": [22,45,82,116,117,120],
+ "121": [23,42,82,111,112,113,114,121],
+ "122": [24,43,82,122],
+ "123": [47,65,83,89,99,111,115,123,124,125,127,129,131,132,139,140],
+ "124": [49,66,83,89,103,112,116,124,126,128,130],
+ "125": [50,67,85,89,100,111,117,125,126,133,134],
+ "126": [48,68,85,89,104,112,118,126],
+ "127": [55,65,83,90,100,113,117,127,128,135,136],
+ "128": [58,66,83,90,104,114,118,128],
+ "129": [59,67,85,90,99,113,115,129,130,137,138],
+ "130": [56,68,85,90,103,114,116,130],
+ "131": [47,66,84,93,105,112,115,131,132,134,136,138],
+ "132": [49,65,84,93,101,111,116,131,132,133,135,137],
+ "133": [50,68,86,93,106,112,117,133],
+ "134": [48,67,86,93,102,111,118,134,141,142],
+ "135": [55,66,84,94,106,114,117,135],
+ "136": [58,65,84,94,102,113,118,136],
+ "137": [59,68,86,94,105,114,115,137],
+ "138": [56,67,86,94,101,113,116,138],
+ "139": [69,71,87,97,107,119,121,123,126,128,129,131,134,136,137,139],
+ "140": [69,72,87,97,108,120,121,124,125,127,130,132,133,135,138,140],
+ "141": [70,74,88,98,109,119,122,141],
+ "142": [70,73,88,98,110,120,122,142],
+ "143": [1,143,144,145,146],
+ "144": [1,144,145],
+ "145": [1,144,145],
+ "146": [1,143,144,145,146],
+ "147": [2,143,147,148],
+ "148": [2,146,147,148],
+ "149": [5,143,149,150,151,153,155],
+ "150": [5,143,149,150,152,154],
+ "151": [5,144,151,152,153],
+ "152": [5,144,151,152,154],
+ "153": [5,145,151,153,154],
+ "154": [5,145,152,153,154],
+ "155": [5,146,150,152,154,155],
+ "156": [8,143,156,157,158],
+ "157": [8,143,156,157,159,160],
+ "158": [9,143,158,159],
+ "159": [9,143,158,159,161],
+ "160": [8,146,156,160,161],
+ "161": [9,146,158,161],
+ "162": [12,147,149,157,162,163,164,166],
+ "163": [15,147,149,159,163,165,167],
+ "164": [12,147,150,156,162,164,165],
+ "165": [15,147,150,158,163,165],
+ "166": [12,148,155,160,164,166,167],
+ "167": [15,148,155,161,165,167],
+ "168": [3,143,168,171,172,173],
+ "169": [4,144,169,170],
+ "170": [4,145,169,170],
+ "171": [3,145,169,171,172],
+ "172": [3,144,170,171,172],
+ "173": [4,143,169,170,173],
+ "174": [6,143,174],
+ "175": [10,147,168,174,175,176],
+ "176": [11,147,173,174,176],
+ "177": [21,149,150,168,177,180,181,182],
+ "178": [20,151,152,169,178,179],
+ "179": [20,153,154,170,178,179],
+ "180": [21,153,154,171,178,180,181],
+ "181": [21,151,152,172,179,180,181],
+ "182": [20,149,150,173,178,179,182],
+ "183": [35,156,157,168,183,184,185,186],
+ "184": [37,158,159,168,184],
+ "185": [36,157,158,173,185,186],
+ "186": [36,156,159,173,185,186],
+ "187": [38,149,156,174,187,188,189],
+ "188": [40,149,158,174,188,190],
+ "189": [38,150,157,174,187,189,190],
+ "190": [40,150,159,174,188,190],
+ "191": [65,162,164,175,177,183,187,189,191,192,193,194],
+ "192": [66,163,165,175,177,184,188,190,192],
+ "193": [63,162,165,176,182,185,188,189,193,194],
+ "194": [63,163,164,176,182,186,187,190,193,194],
+ "195": [16,146,196,197,199],
+ "196": [22,146,195,198],
+ "197": [23,146,195],
+ "198": [19,146],
+ "199": [24,146,198],
+ "200": [47,148,195,202,204,206],
+ "201": [48,148,195,203],
+ "202": [69,148,196,200,201,205],
+ "203": [70,148,196],
+ "204": [71,148,197,200,201],
+ "205": [61,148,198],
+ "206": [73,148,199,205],
+ "207": [89,155,195,209,211],
+ "208": [93,155,195,210,214],
+ "209": [97,155,196,207,208],
+ "210": [98,155,196,212,213],
+ "211": [97,155,197,207,208],
+ "212": [96,155,198],
+ "213": [92,155,198],
+ "214": [98,155,199,212,213],
+ "215": [111,160,195,216,217,219],
+ "216": [119,160,196,215],
+ "217": [121,160,197,215,218],
+ "218": [112,161,195,220],
+ "219": [120,161,196,218],
+ "220": [122,161,199],
+ "221": [123,166,200,207,215,225,226,229],
+ "222": [126,167,201,207,218],
+ "223": [131,167,200,208,218,230],
+ "224": [134,166,201,208,215,227,228],
+ "225": [139,166,202,209,216,221,224],
+ "226": [140,167,202,209,219,222,223],
+ "227": [141,166,203,210,216],
+ "228": [142,167,203,210,219],
+ "229": [139,166,204,211,217,221,222,223,224],
+ "230": [142,167,206,214,220]
+ }
model.py ADDED
@@ -0,0 +1,408 @@
+ """
+ This file is self-contained: download it alongside `model.safetensors`,
+ `config.json`, and `maxsub.json` to load and run the model.
+ """
+
+ import json
+ from collections import deque
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # ---------------------------------------------------------------------------
+ # Utility: DropPath (Stochastic Depth)
+ # ---------------------------------------------------------------------------
+ def drop_path(
+     x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
+ ) -> torch.Tensor:
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor = random_tensor.floor()
+     return x.div(keep_prob) * random_tensor
+
+
+ class DropPath(nn.Module):
+     def __init__(self, drop_prob: float = 0.0):
+         super().__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return drop_path(x, self.drop_prob, self.training)
+
+
+ # ---------------------------------------------------------------------------
+ # ConvNeXt 1D Block
+ # ---------------------------------------------------------------------------
+ class ConvNeXtBlock1D(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         kernel_size: int,
+         drop_path: float,
+         layer_scale_init_value: float,
+         activation: nn.Module,
+     ):
+         super().__init__()
+         self.dwconv = nn.Conv1d(
+             dim, dim, kernel_size=kernel_size, padding="same", groups=dim
+         )
+         self.pwconv1 = nn.Linear(dim, 4 * dim)
+         self.act = activation() if isinstance(activation, type) else activation
+         self.pwconv2 = nn.Linear(4 * dim, dim)
+         self.gamma = (
+             nn.Parameter(layer_scale_init_value * torch.ones(dim))
+             if layer_scale_init_value > 0
+             else None
+         )
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shortcut = x
+         x = self.dwconv(x)
+         x = x.permute(0, 2, 1)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma is not None:
+             x = x * self.gamma
+         x = x.permute(0, 2, 1)
+         x = shortcut + self.drop_path(x)
+         return x
+
+
+ class ConvNextBlock1DAdaptor(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int,
+         dropout: float,
+         use_batchnorm: bool,
+         activation: nn.Module,
+         layer_scale_init_value: float,
+         drop_path_rate: float,
+         block_type: str,
+     ):
+         super().__init__()
+         if in_channels != out_channels:
+             act = activation() if isinstance(activation, type) else activation
+             self.pwconv = nn.Sequential(nn.Linear(in_channels, out_channels), act)
+         else:
+             self.pwconv = None
+
+         if block_type == "convnext":
+             self.block = ConvNeXtBlock1D(
+                 dim=out_channels,
+                 kernel_size=kernel_size,
+                 drop_path=drop_path_rate,
+                 layer_scale_init_value=layer_scale_init_value,
+                 activation=activation,
+             )
+         else:
+             self.block = None
+
+         if stride > 1:
+             self.reduction_pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
+         else:
+             self.reduction_pool = None
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.pwconv is not None:
+             x = x.permute(0, 2, 1)
+             x = self.pwconv(x)
+             x = x.permute(0, 2, 1)
+         if self.block is not None:
+             x = self.block(x)
+         if self.reduction_pool is not None:
+             x = self.reduction_pool(x)
+         return x
+
+
+ # ---------------------------------------------------------------------------
+ # MLP head builder
+ # ---------------------------------------------------------------------------
+ def make_mlp(
+     input_dim: int,
+     hidden_dims: Optional[Tuple[int, ...]],
+     output_dim: int,
+     dropout: float = 0.2,
+     output_activation: Optional[nn.Module] = None,
+ ) -> nn.Module:
+     layers: List[nn.Module] = []
+     last = input_dim
+     if hidden_dims is not None and len(hidden_dims) > 0:
+         for hd in hidden_dims:
+             layers.extend([nn.Linear(last, hd), nn.ReLU()])
+             if dropout and dropout > 0:
+                 layers.append(nn.Dropout(dropout))
+             last = hd
+     layers.append(nn.Linear(last, output_dim))
+     if output_activation is not None:
+         layers.append(output_activation)
+     return nn.Sequential(*layers)
+
+
+ # ---------------------------------------------------------------------------
+ # Backbone
+ # ---------------------------------------------------------------------------
+ class MultiscaleCNNBackbone1D(nn.Module):
+     def __init__(
+         self,
+         dim_in: int,
+         channels: Tuple[int, ...],
+         kernel_sizes: Tuple[int, ...],
+         strides: Tuple[int, ...],
+         dropout_rate: float,
+         ramped_dropout_rate: bool,
+         block_type: str,
+         pooling_type: str,
+         final_pool: bool,
+         use_batchnorm: bool,
+         activation: nn.Module,
+         output_type: str,
+         layer_scale_init_value: float,
+         drop_path_rate: float,
+     ):
+         super().__init__()
+         assert len(channels) == len(kernel_sizes) == len(strides)
+         self.dim_in = dim_in
+         self.output_type = output_type
+
+         if ramped_dropout_rate:
+             dropout_per_stage = torch.linspace(
+                 0.0, dropout_rate, steps=len(channels)
+             ).tolist()
+         else:
+             dropout_per_stage = [dropout_rate] * len(channels)
+
+         if pooling_type == "average":
+             pool_cls = nn.AvgPool1d
+             pool_kwargs = {"kernel_size": 3, "stride": 2}
+         elif pooling_type == "max":
+             pool_cls = nn.MaxPool1d
+             pool_kwargs = {"kernel_size": 2, "stride": 2}
+         else:
+             raise ValueError(f"Invalid pooling_type '{pooling_type}'")
+
+         layers: List[nn.Module] = []
+         in_ch = 1
+         for i, (out_ch, k, s) in enumerate(zip(channels, kernel_sizes, strides)):
+             stage_block = ConvNextBlock1DAdaptor(
+                 in_channels=in_ch,
+                 out_channels=out_ch,
+                 kernel_size=k,
+                 stride=s,
+                 dropout=dropout_per_stage[i],
+                 use_batchnorm=use_batchnorm,
+                 activation=activation,
+                 layer_scale_init_value=layer_scale_init_value,
+                 drop_path_rate=drop_path_rate,
+                 block_type=block_type,
+             )
+             layers.append(stage_block)
+             if i < len(channels) - 1 or final_pool:
+                 layers.append(pool_cls(**pool_kwargs))
+             in_ch = out_ch
+
+         self.net = nn.Sequential(*layers)
+
+         if self.output_type == "gap":
+             self.dim_output = channels[-1]
+         elif self.output_type == "flatten":
+             with torch.no_grad():
+                 dummy = torch.zeros(1, 1, self.dim_in)
+                 out = self.net(dummy)
+             self.dim_output = int(out.shape[1] * out.shape[2])
+         else:
+             raise ValueError(f"Invalid output_type '{self.output_type}'")
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if x.ndim == 2:
+             x = x[:, None, :]
+         x = self.net(x)
+         if self.output_type == "gap":
+             x = x.mean(dim=-1)
+         else:
+             x = x.reshape(x.shape[0], -1)
+         return x
+
+
+ # ---------------------------------------------------------------------------
+ # GEMD distance-matrix utilities
+ # ---------------------------------------------------------------------------
+ def _build_distance_matrix_from_maxsub_lut(
+     maxsub_lut: Dict[str, List[int]],
+     num_sg_classes: int,
+ ) -> torch.Tensor:
+     adjacency: List[set] = [set() for _ in range(num_sg_classes)]
+     for key, neighbors in maxsub_lut.items():
+         src = int(key) - 1
+         for raw_dst in neighbors:
+             dst = int(raw_dst) - 1
+             adjacency[src].add(dst)
+             adjacency[dst].add(src)
+
+     distance_matrix = torch.zeros(
+         (num_sg_classes, num_sg_classes), dtype=torch.float32
+     )
+     for src in range(num_sg_classes):
+         dists = [-1] * num_sg_classes
+         dists[src] = 0
+         queue = deque([src])
+         while queue:
+             cur = queue.popleft()
+             for nxt in adjacency[cur]:
+                 if dists[nxt] == -1:
+                     dists[nxt] = dists[cur] + 1
+                     queue.append(nxt)
+         distance_matrix[src] = torch.tensor(dists, dtype=torch.float32)
+     return distance_matrix
+
+
+ def load_gemd_distance_matrix(
+     path: str, num_sg_classes: int = 230
+ ) -> torch.Tensor:
+     with open(path, "r", encoding="utf-8") as f:
+         payload: Any = json.load(f)
+     if isinstance(payload, dict) and all(str(k).isdigit() for k in payload.keys()):
+         return _build_distance_matrix_from_maxsub_lut(payload, num_sg_classes)
+     elif isinstance(payload, list):
+         return torch.as_tensor(payload, dtype=torch.float32)
+     raise ValueError(f"Could not parse GEMD data from {path}")
+
+
+ # ---------------------------------------------------------------------------
+ # Full model
+ # ---------------------------------------------------------------------------
+ class AlphaDiffract(nn.Module):
+     """
+     AlphaDiffract: multi-task 1D ConvNeXt for powder X-ray diffraction
+     pattern analysis.
+
+     Predicts crystal system (7 classes), space group (230 classes), and
+     lattice parameters (6 values: a, b, c, alpha, beta, gamma).
+     """
+
+     CRYSTAL_SYSTEMS = [
+         "Triclinic",
+         "Monoclinic",
+         "Orthorhombic",
+         "Tetragonal",
+         "Trigonal",
+         "Hexagonal",
+         "Cubic",
+     ]
+
+     def __init__(self, config: dict, maxsub_path: Optional[str] = None):
+         super().__init__()
+         bb = config["backbone"]
+         heads = config["heads"]
+         tasks = config["tasks"]
+
+         activation = nn.GELU
+
+         self.backbone = MultiscaleCNNBackbone1D(
+             dim_in=bb["dim_in"],
+             channels=tuple(bb["channels"]),
+             kernel_sizes=tuple(bb["kernel_sizes"]),
+             strides=tuple(bb["strides"]),
+             dropout_rate=bb["dropout_rate"],
+             ramped_dropout_rate=bb["ramped_dropout_rate"],
+             block_type=bb["block_type"],
+             pooling_type=bb["pooling_type"],
+             final_pool=bb["final_pool"],
+             use_batchnorm=bb["use_batchnorm"],
+             activation=activation,
+             output_type=bb["output_type"],
+             layer_scale_init_value=bb["layer_scale_init_value"],
+             drop_path_rate=bb["drop_path_rate"],
+         )
+         feat_dim = self.backbone.dim_output
+
+         self.cs_head = make_mlp(
+             feat_dim, tuple(heads["cs_hidden"]), tasks["num_cs_classes"],
+             dropout=heads["head_dropout"],
+         )
+         self.sg_head = make_mlp(
+             feat_dim, tuple(heads["sg_hidden"]), tasks["num_sg_classes"],
+             dropout=heads["head_dropout"],
+         )
+         self.lp_head = make_mlp(
+             feat_dim, tuple(heads["lp_hidden"]), tasks["num_lp_outputs"],
+             dropout=heads["head_dropout"],
+         )
+
+         self.bound_lp_with_sigmoid = tasks["bound_lp_with_sigmoid"]
+         self.register_buffer(
+             "lp_min",
+             torch.tensor(tasks["lp_bounds_min"], dtype=torch.float32),
+         )
+         self.register_buffer(
+             "lp_max",
+             torch.tensor(tasks["lp_bounds_max"], dtype=torch.float32),
+         )
+
+         if maxsub_path is not None:
+             gemd = load_gemd_distance_matrix(maxsub_path)
+             self.register_buffer("gemd_distance_matrix", gemd)
+         else:
+             self.gemd_distance_matrix = None
+
+     def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """
+         Args:
+             x: PXRD pattern tensor of shape ``(batch, 8192)`` or
+                 ``(batch, 1, 8192)``, intensity-normalized to [0, 100].
+
+         Returns:
+             Dict with keys ``cs_logits``, ``sg_logits``, ``lp``.
+         """
+         feats = self.backbone(x)
+         cs_logits = self.cs_head(feats)
+         sg_logits = self.sg_head(feats)
+         lp = self.lp_head(feats)
+         if self.bound_lp_with_sigmoid:
+             lp = torch.sigmoid(lp) * (self.lp_max - self.lp_min) + self.lp_min
+         return {"cs_logits": cs_logits, "sg_logits": sg_logits, "lp": lp}
+
+     # -- convenience loaders ------------------------------------------------
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_dir: str,
+         device: str = "cpu",
+     ) -> "AlphaDiffract":
+         """Load model from a directory containing config.json,
+         model.safetensors, and maxsub.json."""
+         model_dir = Path(model_dir)
+         with open(model_dir / "config.json", "r") as f:
+             config = json.load(f)
+
+         maxsub_path = model_dir / "maxsub.json"
+         model = cls(
+             config,
+             maxsub_path=str(maxsub_path) if maxsub_path.exists() else None,
+         )
+
+         weights_path = model_dir / "model.safetensors"
+         if weights_path.exists():
+             from safetensors.torch import load_file
+             state_dict = load_file(str(weights_path), device=device)
+         else:
+             # Fallback to PyTorch format
+             pt_path = model_dir / "model.pt"
+             state_dict = torch.load(str(pt_path), map_location=device, weights_only=True)
+
+         model.load_state_dict(state_dict)
+         model.to(device)
+         model.eval()
+         return model
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4d81f2716ef8c78683d52ac51afff3eaf160c0b2b410685d0c90299bc2fd58ed
+ size 35155332