Upload folder using huggingface_hub
Browse files- LICENSE +13 -0
- README.md +155 -0
- config.json +55 -0
- example_inference.py +44 -0
- maxsub.json +232 -0
- model.py +408 -0
- model.safetensors +3 -0
LICENSE
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD 3-Clause License
|
| 2 |
+
|
| 3 |
+
Copyright 2026 UChicago Argonne, LLC. All rights reserved.
|
| 4 |
+
|
| 5 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 6 |
+
|
| 7 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 8 |
+
|
| 9 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 10 |
+
|
| 11 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
| 12 |
+
|
| 13 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: bsd-3-clause
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
tags:
|
| 6 |
+
- pytorch
|
| 7 |
+
- materials-science
|
| 8 |
+
- crystallography
|
| 9 |
+
- x-ray-diffraction
|
| 10 |
+
- pxrd
|
| 11 |
+
- convnext
|
| 12 |
+
- arxiv:2603.23367
|
| 13 |
+
datasets:
|
| 14 |
+
- materials-project
|
| 15 |
+
metrics:
|
| 16 |
+
- accuracy
|
| 17 |
+
- mae
|
| 18 |
+
pipeline_tag: other
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
# AlphaDiffract — Open Weights
|
| 22 |
+
|
| 23 |
+
[arXiv](https://arxiv.org/abs/2603.23367) | [GitHub](https://github.com/AdvancedPhotonSource/AlphaDiffract)
|
| 24 |
+
|
| 25 |
+
**Automated crystallographic analysis of powder X-ray diffraction data.**
|
| 26 |
+
|
| 27 |
+
AlphaDiffract is a multi-task 1D ConvNeXt model that takes a powder X-ray diffraction (PXRD) pattern and simultaneously predicts:
|
| 28 |
+
|
| 29 |
+
| Output | Description |
|
| 30 |
+
|---|---|
|
| 31 |
+
| **Crystal system** | 7-class classification (Triclinic → Cubic) |
|
| 32 |
+
| **Space group** | 230-class classification |
|
| 33 |
+
| **Lattice parameters** | 6 values: a, b, c (Å), α, β, γ (°) |
|
| 34 |
+
|
| 35 |
+
This release contains a **single model** trained exclusively on
|
| 36 |
+
[Materials Project](https://next-gen.materialsproject.org/) structures
|
| 37 |
+
(publicly available data). It is *not* the 10-model ensemble reported in
|
| 38 |
+
the paper — see [Performance](#performance) for details.
|
| 39 |
+
|
| 40 |
+
## Quick Start
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
pip install torch safetensors numpy
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
from model import AlphaDiffract
|
| 48 |
+
import torch, numpy as np
|
| 49 |
+
|
| 50 |
+
model = AlphaDiffract.from_pretrained(".", device="cpu")
|
| 51 |
+
|
| 52 |
+
# 8192-point intensity pattern, normalized to [0, 100]
|
| 53 |
+
pattern = np.load("my_pattern.npy").astype(np.float32)
|
| 54 |
+
pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-10) * 100.0
|
| 55 |
+
x = torch.from_numpy(pattern).unsqueeze(0)
|
| 56 |
+
|
| 57 |
+
with torch.no_grad():
|
| 58 |
+
out = model(x)
|
| 59 |
+
|
| 60 |
+
cs_probs = torch.softmax(out["cs_logits"], dim=-1)
|
| 61 |
+
sg_probs = torch.softmax(out["sg_logits"], dim=-1)
|
| 62 |
+
lp = out["lp"] # [a, b, c, alpha, beta, gamma]
|
| 63 |
+
|
| 64 |
+
print("Crystal system:", AlphaDiffract.CRYSTAL_SYSTEMS[cs_probs.argmax().item()])
|
| 65 |
+
print("Space group: #", sg_probs.argmax().item() + 1)
|
| 66 |
+
print("Lattice params:", lp[0].tolist())
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
See `example_inference.py` for a complete runnable example.
|
| 70 |
+
|
| 71 |
+
## Files
|
| 72 |
+
|
| 73 |
+
| File | Description |
|
| 74 |
+
|---|---|
|
| 75 |
+
| `model.safetensors` | Model weights (safetensors format, ~35 MB) |
|
| 76 |
+
| `model.py` | Standalone model definition (pure PyTorch, no Lightning) |
|
| 77 |
+
| `config.json` | Architecture and training hyperparameters |
|
| 78 |
+
| `maxsub.json` | Space-group subgroup graph (230×230, used as a registered buffer) |
|
| 79 |
+
| `example_inference.py` | End-to-end inference example |
|
| 80 |
+
| `LICENSE` | BSD 3-Clause |
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
## Input Format
|
| 84 |
+
|
| 85 |
+
- **Length:** 8192 equally-spaced intensity values
|
| 86 |
+
- **2θ range:** 5–20° (monochromatic, 20 keV)
|
| 87 |
+
- **Preprocessing:** floor negatives at zero, then rescale to [0, 100]
|
| 88 |
+
- **Shape:** `(batch, 8192)` or `(batch, 1, 8192)`
|
| 89 |
+
|
| 90 |
+
## Architecture
|
| 91 |
+
|
| 92 |
+
1D ConvNeXt backbone adapted from [Liu et al. (2022)](https://arxiv.org/abs/2201.03545):
|
| 93 |
+
|
| 94 |
+
```
|
| 95 |
+
Input (8192) → [ConvNeXt Block × 3 with AvgPool] → Flatten (560-d)
|
| 96 |
+
├─ CS head: MLP 560→2300→1150→7 (crystal system)
|
| 97 |
+
├─ SG head: MLP 560→2300→1150→230 (space group)
|
| 98 |
+
└─ LP head: MLP 560→512→256→6 (lattice parameters, sigmoid-bounded)
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
- **Parameters:** 8,734,989
|
| 102 |
+
- **Activation:** GELU
|
| 103 |
+
- **Stochastic depth:** 0.3
|
| 104 |
+
- **Head dropout:** 0.5
|
| 105 |
+
|
| 106 |
+
## Performance
|
| 107 |
+
|
| 108 |
+
This is a **single model** trained on Materials Project data only (no ICSD).
|
| 109 |
+
Metrics on the best validation checkpoint (epoch 11):
|
| 110 |
+
|
| 111 |
+
| Metric | Simulated Val | RRUFF (experimental) |
|
| 112 |
+
|---|---|---|
|
| 113 |
+
| Crystal system accuracy | 74.88% | 60.35% |
|
| 114 |
+
| Space group accuracy | 57.31% | 38.28% |
|
| 115 |
+
| Lattice parameter MAE | 2.71 | — |
|
| 116 |
+
|
| 117 |
+
The paper reports higher numbers from a 10-model ensemble trained on
|
| 118 |
+
Materials Project + ICSD combined data. This open-weights release covers
|
| 119 |
+
only publicly available training data.
|
| 120 |
+
|
| 121 |
+
## Training Details
|
| 122 |
+
|
| 123 |
+
| | |
|
| 124 |
+
|---|---|
|
| 125 |
+
| **Data** | ~146k Materials Project structures, 100 GSAS-II simulations each |
|
| 126 |
+
| **Augmentation** | Poisson + Gaussian noise, rescaled to [0, 100] |
|
| 127 |
+
| **Optimizer** | AdamW (lr=2e-4, weight_decay=0.01) |
|
| 128 |
+
| **Scheduler** | CyclicLR (triangular2, 6-epoch half-cycles) |
|
| 129 |
+
| **Loss** | CE (crystal system) + CE + GEMD (space group) + MSE (lattice params) |
|
| 130 |
+
| **Hardware** | 1× NVIDIA H100, float32 |
|
| 131 |
+
| **Batch size** | 64 |
|
| 132 |
+
|
| 133 |
+
## Citation
|
| 134 |
+
|
| 135 |
+
```bibtex
|
| 136 |
+
@article{andrejevic2026alphadiffract,
|
| 137 |
+
title = {AlphaDiffract: Automated Crystallographic Analysis of Powder X-ray Diffraction Data},
|
| 138 |
+
author = {Andrejevic, Nina and Du, Ming and Sharma, Hemant and Horwath, James P. and Luo, Aileen and Yin, Xiangyu and Prince, Michael and Toby, Brian H. and Cherukara, Mathew J.},
|
| 139 |
+
year = {2026},
|
| 140 |
+
eprint = {2603.23367},
|
| 141 |
+
archivePrefix = {arXiv},
|
| 142 |
+
primaryClass = {cond-mat.mtrl-sci},
|
| 143 |
+
doi = {10.48550/arXiv.2603.23367},
|
| 144 |
+
url = {https://arxiv.org/abs/2603.23367}
|
| 145 |
+
}
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## License
|
| 149 |
+
|
| 150 |
+
BSD 3-Clause — Copyright 2026 UChicago Argonne, LLC.
|
| 151 |
+
|
| 152 |
+
## Links
|
| 153 |
+
|
| 154 |
+
- [arXiv: 2603.23367](https://arxiv.org/abs/2603.23367)
|
| 155 |
+
- [GitHub: OpenAlphaDiffract](https://github.com/AdvancedPhotonSource/AlphaDiffract)
|
config.json
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "alphadiffract",
|
| 3 |
+
"architecture": "ConvNeXt1D-MultiTask",
|
| 4 |
+
"backbone": {
|
| 5 |
+
"dim_in": 8192,
|
| 6 |
+
"channels": [80, 80, 80],
|
| 7 |
+
"kernel_sizes": [100, 50, 25],
|
| 8 |
+
"strides": [5, 5, 5],
|
| 9 |
+
"dropout_rate": 0.3,
|
| 10 |
+
"ramped_dropout_rate": false,
|
| 11 |
+
"block_type": "convnext",
|
| 12 |
+
"pooling_type": "average",
|
| 13 |
+
"final_pool": true,
|
| 14 |
+
"use_batchnorm": false,
|
| 15 |
+
"activation": "gelu",
|
| 16 |
+
"output_type": "flatten",
|
| 17 |
+
"layer_scale_init_value": 0.0,
|
| 18 |
+
"drop_path_rate": 0.3
|
| 19 |
+
},
|
| 20 |
+
"heads": {
|
| 21 |
+
"head_dropout": 0.5,
|
| 22 |
+
"cs_hidden": [2300, 1150],
|
| 23 |
+
"sg_hidden": [2300, 1150],
|
| 24 |
+
"lp_hidden": [512, 256]
|
| 25 |
+
},
|
| 26 |
+
"tasks": {
|
| 27 |
+
"num_cs_classes": 7,
|
| 28 |
+
"num_sg_classes": 230,
|
| 29 |
+
"num_lp_outputs": 6,
|
| 30 |
+
"lp_bounds_min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
| 31 |
+
"lp_bounds_max": [500.0, 500.0, 500.0, 180.0, 180.0, 180.0],
|
| 32 |
+
"bound_lp_with_sigmoid": true
|
| 33 |
+
},
|
| 34 |
+
"training": {
|
| 35 |
+
"optimizer": "AdamW",
|
| 36 |
+
"lr": 0.0002,
|
| 37 |
+
"weight_decay": 0.01,
|
| 38 |
+
"scheduler": "CyclicLR",
|
| 39 |
+
"scheduler_mode": "triangular2",
|
| 40 |
+
"batch_size": 64,
|
| 41 |
+
"max_epochs": 100,
|
| 42 |
+
"precision": "float32",
|
| 43 |
+
"gemd_mu": 1.0,
|
| 44 |
+
"lambda_cs": 1.0,
|
| 45 |
+
"lambda_sg": 1.0,
|
| 46 |
+
"lambda_lp": 1.0
|
| 47 |
+
},
|
| 48 |
+
"preprocessing": {
|
| 49 |
+
"input_length": 8192,
|
| 50 |
+
"floor_at_zero": true,
|
| 51 |
+
"normalize_range": [0.0, 100.0],
|
| 52 |
+
"noise_poisson_range": [1.0, 100.0],
|
| 53 |
+
"noise_gaussian_range": [0.001, 0.1]
|
| 54 |
+
}
|
| 55 |
+
}
|
example_inference.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example: load AlphaDiffract and run inference on a PXRD pattern.
|
| 3 |
+
|
| 4 |
+
Requirements:
|
| 5 |
+
pip install torch safetensors numpy
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from model import AlphaDiffract
|
| 11 |
+
|
| 12 |
+
# 1. Load model ---------------------------------------------------------------
|
| 13 |
+
model = AlphaDiffract.from_pretrained(".", device="cpu") # or "cuda"
|
| 14 |
+
|
| 15 |
+
# 2. Prepare input -------------------------------------------------------------
|
| 16 |
+
# The model expects an 8192-point PXRD intensity pattern normalized to [0, 100].
|
| 17 |
+
# Replace this with your own data.
|
| 18 |
+
pattern = np.random.rand(8192).astype(np.float32) # placeholder
|
| 19 |
+
|
| 20 |
+
# Normalize to [0, 100]
|
| 21 |
+
pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-10) * 100.0
|
| 22 |
+
x = torch.from_numpy(pattern).unsqueeze(0) # shape: (1, 8192)
|
| 23 |
+
|
| 24 |
+
# 3. Inference -----------------------------------------------------------------
|
| 25 |
+
with torch.no_grad():
|
| 26 |
+
out = model(x)
|
| 27 |
+
|
| 28 |
+
cs_probs = torch.softmax(out["cs_logits"], dim=-1)
|
| 29 |
+
sg_probs = torch.softmax(out["sg_logits"], dim=-1)
|
| 30 |
+
lp = out["lp"]
|
| 31 |
+
|
| 32 |
+
# 4. Results -------------------------------------------------------------------
|
| 33 |
+
cs_idx = cs_probs.argmax(dim=-1).item()
|
| 34 |
+
sg_idx = sg_probs.argmax(dim=-1).item()
|
| 35 |
+
|
| 36 |
+
print(f"Crystal system : {AlphaDiffract.CRYSTAL_SYSTEMS[cs_idx]} "
|
| 37 |
+
f"({cs_probs[0, cs_idx]:.1%})")
|
| 38 |
+
print(f"Space group : #{sg_idx + 1} ({sg_probs[0, sg_idx]:.1%})")
|
| 39 |
+
|
| 40 |
+
labels = ["a", "b", "c", "alpha", "beta", "gamma"]
|
| 41 |
+
units = ["A", "A", "A", "deg", "deg", "deg"]
|
| 42 |
+
print("Lattice params :")
|
| 43 |
+
for name, val, unit in zip(labels, lp[0].tolist(), units):
|
| 44 |
+
print(f" {name:>5s} = {val:8.3f} {unit}")
|
maxsub.json
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"1": [1],
|
| 3 |
+
"2": [1,2],
|
| 4 |
+
"3": [1,3,4,5],
|
| 5 |
+
"4": [1,4],
|
| 6 |
+
"5": [1,3,4,5],
|
| 7 |
+
"6": [1,6,7,8],
|
| 8 |
+
"7": [1,7,9],
|
| 9 |
+
"8": [1,6,7,8,9],
|
| 10 |
+
"9": [1,7,9],
|
| 11 |
+
"10": [2,3,6,10,11,12,13],
|
| 12 |
+
"11": [2,4,6,11,14],
|
| 13 |
+
"12": [2,5,8,10,11,12,13,14,15],
|
| 14 |
+
"13": [2,3,7,13,14,15],
|
| 15 |
+
"14": [2,4,7,14],
|
| 16 |
+
"15": [2,5,9,13,14,15],
|
| 17 |
+
"16": [3,16,17,21,22],
|
| 18 |
+
"17": [3,4,17,18,20],
|
| 19 |
+
"18": [3,4,18,19],
|
| 20 |
+
"19": [4,19],
|
| 21 |
+
"20": [4,5,17,18,19,20],
|
| 22 |
+
"21": [3,5,16,17,18,20,21,23,24],
|
| 23 |
+
"22": [5,20,21,22],
|
| 24 |
+
"23": [5,16,18,23],
|
| 25 |
+
"24": [5,17,19,24],
|
| 26 |
+
"25": [3,6,25,26,27,28,35,38,39,42],
|
| 27 |
+
"26": [4,6,7,26,29,31,36],
|
| 28 |
+
"27": [3,7,27,30,37],
|
| 29 |
+
"28": [3,6,7,28,29,30,31,32,40,41],
|
| 30 |
+
"29": [4,7,29,33],
|
| 31 |
+
"30": [3,7,30,34],
|
| 32 |
+
"31": [4,6,7,31,33],
|
| 33 |
+
"32": [3,7,32,33,34],
|
| 34 |
+
"33": [4,7,33],
|
| 35 |
+
"34": [3,7,34,43],
|
| 36 |
+
"35": [3,8,25,28,32,35,36,37,44,45,46],
|
| 37 |
+
"36": [4,8,9,26,29,31,33,36],
|
| 38 |
+
"37": [3,9,27,30,34,37],
|
| 39 |
+
"38": [5,6,8,25,26,30,31,38,40,44,46],
|
| 40 |
+
"39": [5,7,8,26,27,28,29,39,41,45,46],
|
| 41 |
+
"40": [5,6,9,28,31,33,34,40],
|
| 42 |
+
"41": [5,7,9,29,30,32,33,41],
|
| 43 |
+
"42": [5,8,35,36,37,38,39,40,41,42],
|
| 44 |
+
"43": [5,9,43],
|
| 45 |
+
"44": [5,8,25,31,34,44],
|
| 46 |
+
"45": [5,9,27,29,32,45],
|
| 47 |
+
"46": [5,8,9,26,28,30,33,46],
|
| 48 |
+
"47": [10,16,25,47,49,51,65,67,69],
|
| 49 |
+
"48": [13,16,34,48,70],
|
| 50 |
+
"49": [10,13,16,27,28,49,50,53,54,66,68],
|
| 51 |
+
"50": [13,16,30,32,48,50,52],
|
| 52 |
+
"51": [10,11,13,17,25,26,28,51,53,54,55,57,59,63,64],
|
| 53 |
+
"52": [13,14,17,30,33,34,52],
|
| 54 |
+
"53": [10,13,14,17,28,30,31,52,53,58,60],
|
| 55 |
+
"54": [13,14,17,27,29,32,52,54,56,60],
|
| 56 |
+
"55": [10,14,18,26,32,55,58,62],
|
| 57 |
+
"56": [13,14,18,27,33,56],
|
| 58 |
+
"57": [11,13,14,18,26,28,29,57,60,61,62],
|
| 59 |
+
"58": [10,14,18,31,34,58],
|
| 60 |
+
"59": [11,13,18,25,31,56,59,62],
|
| 61 |
+
"60": [13,14,18,29,30,33,60],
|
| 62 |
+
"61": [14,19,29,61],
|
| 63 |
+
"62": [11,14,19,26,31,33,62],
|
| 64 |
+
"63": [11,12,15,20,36,38,40,51,52,57,58,59,60,62,63],
|
| 65 |
+
"64": [12,14,15,20,36,39,41,53,54,55,56,57,60,61,62,64],
|
| 66 |
+
"65": [10,12,21,35,38,47,50,51,53,55,59,63,65,66,71,72,74],
|
| 67 |
+
"66": [10,15,21,37,40,48,49,52,53,56,58,66],
|
| 68 |
+
"67": [12,13,21,35,39,49,51,54,57,64,67,68,72,73,74],
|
| 69 |
+
"68": [13,15,21,37,41,50,52,54,60,68],
|
| 70 |
+
"69": [12,22,42,63,64,65,66,67,68,69],
|
| 71 |
+
"70": [15,22,43,70],
|
| 72 |
+
"71": [12,23,44,47,48,58,59,71],
|
| 73 |
+
"72": [12,15,23,45,46,49,50,55,56,57,60,72],
|
| 74 |
+
"73": [15,24,45,54,61,73],
|
| 75 |
+
"74": [12,15,24,44,46,51,52,53,62,74],
|
| 76 |
+
"75": [3,75,77,79],
|
| 77 |
+
"76": [4,76,78],
|
| 78 |
+
"77": [3,76,77,78,80],
|
| 79 |
+
"78": [4,76,78],
|
| 80 |
+
"79": [5,75,77,79],
|
| 81 |
+
"80": [5,76,78,80],
|
| 82 |
+
"81": [3,81,82],
|
| 83 |
+
"82": [5,81,82],
|
| 84 |
+
"83": [10,75,81,83,84,85,87],
|
| 85 |
+
"84": [10,77,81,84,86],
|
| 86 |
+
"85": [13,75,81,85,86],
|
| 87 |
+
"86": [13,77,81,86,88],
|
| 88 |
+
"87": [12,79,82,83,84,85,86,87],
|
| 89 |
+
"88": [15,80,82,88],
|
| 90 |
+
"89": [16,21,75,89,90,93,97],
|
| 91 |
+
"90": [18,21,75,90,94],
|
| 92 |
+
"91": [17,20,76,91,92,95],
|
| 93 |
+
"92": [19,20,76,92,96],
|
| 94 |
+
"93": [16,21,77,91,93,94,95,98],
|
| 95 |
+
"94": [18,21,77,92,94,96],
|
| 96 |
+
"95": [17,20,78,91,95,96],
|
| 97 |
+
"96": [19,20,78,92,96],
|
| 98 |
+
"97": [22,23,79,89,90,93,94,97],
|
| 99 |
+
"98": [22,24,80,91,92,95,96,98],
|
| 100 |
+
"99": [25,35,75,99,100,101,103,105,107,108],
|
| 101 |
+
"100": [32,35,75,100,102,104,106],
|
| 102 |
+
"101": [27,35,77,101,105,106],
|
| 103 |
+
"102": [34,35,77,102,109,110],
|
| 104 |
+
"103": [27,37,75,103,104],
|
| 105 |
+
"104": [34,37,75,104],
|
| 106 |
+
"105": [25,37,77,101,102,105],
|
| 107 |
+
"106": [32,37,77,106],
|
| 108 |
+
"107": [42,44,79,99,102,104,105,107],
|
| 109 |
+
"108": [42,45,79,100,101,103,106,108],
|
| 110 |
+
"109": [43,44,80,109],
|
| 111 |
+
"110": [43,45,80,110],
|
| 112 |
+
"111": [16,35,81,111,112,115,117,119,120],
|
| 113 |
+
"112": [16,37,81,112,116,118],
|
| 114 |
+
"113": [18,35,81,113,114],
|
| 115 |
+
"114": [18,37,81,114],
|
| 116 |
+
"115": [21,25,81,111,113,115,116,121],
|
| 117 |
+
"116": [21,27,81,112,114,116],
|
| 118 |
+
"117": [21,32,81,117,118],
|
| 119 |
+
"118": [21,34,81,118,122],
|
| 120 |
+
"119": [22,44,82,115,118,119],
|
| 121 |
+
"120": [22,45,82,116,117,120],
|
| 122 |
+
"121": [23,42,82,111,112,113,114,121],
|
| 123 |
+
"122": [24,43,82,122],
|
| 124 |
+
"123": [47,65,83,89,99,111,115,123,124,125,127,129,131,132,139,140],
|
| 125 |
+
"124": [49,66,83,89,103,112,116,124,126,128,130],
|
| 126 |
+
"125": [50,67,85,89,100,111,117,125,126,133,134],
|
| 127 |
+
"126": [48,68,85,89,104,112,118,126],
|
| 128 |
+
"127": [55,65,83,90,100,113,117,127,128,135,136],
|
| 129 |
+
"128": [58,66,83,90,104,114,118,128],
|
| 130 |
+
"129": [59,67,85,90,99,113,115,129,130,137,138],
|
| 131 |
+
"130": [56,68,85,90,103,114,116,130],
|
| 132 |
+
"131": [47,66,84,93,105,112,115,131,132,134,136,138],
|
| 133 |
+
"132": [49,65,84,93,101,111,116,131,132,133,135,137],
|
| 134 |
+
"133": [50,68,86,93,106,112,117,133],
|
| 135 |
+
"134": [48,67,86,93,102,111,118,134,141,142],
|
| 136 |
+
"135": [55,66,84,94,106,114,117,135],
|
| 137 |
+
"136": [58,65,84,94,102,113,118,136],
|
| 138 |
+
"137": [59,68,86,94,105,114,115,137],
|
| 139 |
+
"138": [56,67,86,94,101,113,116,138],
|
| 140 |
+
"139": [69,71,87,97,107,119,121,123,126,128,129,131,134,136,137,139],
|
| 141 |
+
"140": [69,72,87,97,108,120,121,124,125,127,130,132,133,135,138,140],
|
| 142 |
+
"141": [70,74,88,98,109,119,122,141],
|
| 143 |
+
"142": [70,73,88,98,110,120,122,142],
|
| 144 |
+
"143": [1,143,144,145,146],
|
| 145 |
+
"144": [1,144,145],
|
| 146 |
+
"145": [1,144,145],
|
| 147 |
+
"146": [1,143,144,145,146],
|
| 148 |
+
"147": [2,143,147,148],
|
| 149 |
+
"148": [2,146,147,148],
|
| 150 |
+
"149": [5,143,149,150,151,153,155],
|
| 151 |
+
"150": [5,143,149,150,152,154],
|
| 152 |
+
"151": [5,144,151,152,153],
|
| 153 |
+
"152": [5,144,151,152,154],
|
| 154 |
+
"153": [5,145,151,153,154],
|
| 155 |
+
"154": [5,145,152,153,154],
|
| 156 |
+
"155": [5,146,150,152,154,155],
|
| 157 |
+
"156": [8,143,156,157,158],
|
| 158 |
+
"157": [8,143,156,157,159,160],
|
| 159 |
+
"158": [9,143,158,159],
|
| 160 |
+
"159": [9,143,158,159,161],
|
| 161 |
+
"160": [8,146,156,160,161],
|
| 162 |
+
"161": [9,146,158,161],
|
| 163 |
+
"162": [12,147,149,157,162,163,164,166],
|
| 164 |
+
"163": [15,147,149,159,163,165,167],
|
| 165 |
+
"164": [12,147,150,156,162,164,165],
|
| 166 |
+
"165": [15,147,150,158,163,165],
|
| 167 |
+
"166": [12,148,155,160,164,166,167],
|
| 168 |
+
"167": [15,148,155,161,165,167],
|
| 169 |
+
"168": [3,143,168,171,172,173],
|
| 170 |
+
"169": [4,144,169,170],
|
| 171 |
+
"170": [4,145,169,170],
|
| 172 |
+
"171": [3,145,169,171,172],
|
| 173 |
+
"172": [3,144,170,171,172],
|
| 174 |
+
"173": [4,143,169,170,173],
|
| 175 |
+
"174": [6,143,174],
|
| 176 |
+
"175": [10,147,168,174,175,176],
|
| 177 |
+
"176": [11,147,173,174,176],
|
| 178 |
+
"177": [21,149,150,168,177,180,181,182],
|
| 179 |
+
"178": [20,151,152,169,178,179],
|
| 180 |
+
"179": [20,153,154,170,178,179],
|
| 181 |
+
"180": [21,153,154,171,178,180,181],
|
| 182 |
+
"181": [21,151,152,172,179,180,181],
|
| 183 |
+
"182": [20,149,150,173,178,179,182],
|
| 184 |
+
"183": [35,156,157,168,183,184,185,186],
|
| 185 |
+
"184": [37,158,159,168,184],
|
| 186 |
+
"185": [36,157,158,173,185,186],
|
| 187 |
+
"186": [36,156,159,173,185,186],
|
| 188 |
+
"187": [38,149,156,174,187,188,189],
|
| 189 |
+
"188": [40,149,158,174,188,190],
|
| 190 |
+
"189": [38,150,157,174,187,189,190],
|
| 191 |
+
"190": [40,150,159,174,188,190],
|
| 192 |
+
"191": [65,162,164,175,177,183,187,189,191,192,193,194],
|
| 193 |
+
"192": [66,163,165,175,177,184,188,190,192],
|
| 194 |
+
"193": [63,162,165,176,182,185,188,189,193,194],
|
| 195 |
+
"194": [63,163,164,176,182,186,187,190,193,194],
|
| 196 |
+
"195": [16,146,196,197,199],
|
| 197 |
+
"196": [22,146,195,198],
|
| 198 |
+
"197": [23,146,195],
|
| 199 |
+
"198": [19,146],
|
| 200 |
+
"199": [24,146,198],
|
| 201 |
+
"200": [47,148,195,202,204,206],
|
| 202 |
+
"201": [48,148,195,203],
|
| 203 |
+
"202": [69,148,196,200,201,205],
|
| 204 |
+
"203": [70,148,196],
|
| 205 |
+
"204": [71,148,197,200,201],
|
| 206 |
+
"205": [61,148,198],
|
| 207 |
+
"206": [73,148,199,205],
|
| 208 |
+
"207": [89,155,195,209,211],
|
| 209 |
+
"208": [93,155,195,210,214],
|
| 210 |
+
"209": [97,155,196,207,208],
|
| 211 |
+
"210": [98,155,196,212,213],
|
| 212 |
+
"211": [97,155,197,207,208],
|
| 213 |
+
"212": [96,155,198],
|
| 214 |
+
"213": [92,155,198],
|
| 215 |
+
"214": [98,155,199,212,213],
|
| 216 |
+
"215": [111,160,195,216,217,219],
|
| 217 |
+
"216": [119,160,196,215],
|
| 218 |
+
"217": [121,160,197,215,218],
|
| 219 |
+
"218": [112,161,195,220],
|
| 220 |
+
"219": [120,161,196,218],
|
| 221 |
+
"220": [122,161,199],
|
| 222 |
+
"221": [123,166,200,207,215,225,226,229],
|
| 223 |
+
"222": [126,167,201,207,218],
|
| 224 |
+
"223": [131,167,200,208,218,230],
|
| 225 |
+
"224": [134,166,201,208,215,227,228],
|
| 226 |
+
"225": [139,166,202,209,216,221,224],
|
| 227 |
+
"226": [140,167,202,209,219,222,223],
|
| 228 |
+
"227": [141,166,203,210,216],
|
| 229 |
+
"228": [142,167,203,210,219],
|
| 230 |
+
"229": [139,166,204,211,217,221,222,223,224],
|
| 231 |
+
"230": [142,167,206,214,220]
|
| 232 |
+
}
|
model.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is self-contained: download it alongside `model.safetensors`,
|
| 3 |
+
`config.json`, and `maxsub.json` to load and run the model.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from collections import deque
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Utility: DropPath (Stochastic Depth)
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Randomly zero entire samples of a batch (stochastic depth).

    Each sample survives with probability ``1 - drop_prob``; survivors are
    rescaled by ``1 / (1 - drop_prob)`` so the expected activation is
    unchanged. With ``training=False`` or ``drop_prob == 0`` this is the
    identity.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One draw per sample, broadcast over every remaining dimension.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0.
    mask = (keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)).floor()
    return x.div(keep_prob) * mask
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DropPath(nn.Module):
    """Stochastic-depth module: drops whole samples via :func:`drop_path`.

    Identity when ``self.training`` is False or ``drop_prob == 0``.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        # Probability of dropping a whole sample during training.
        self.drop_prob = drop_prob

    def extra_repr(self) -> str:
        # Surface the hyperparameter in printed module summaries
        # (previously repr showed a bare "DropPath()").
        return f"drop_prob={self.drop_prob}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return drop_path(x, self.drop_prob, self.training)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# ConvNeXt 1D Block
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
class ConvNeXtBlock1D(nn.Module):
    """One 1-D ConvNeXt block: depthwise conv, then a channel-wise MLP.

    A depthwise Conv1d mixes information along the sequence axis; a
    two-layer pointwise MLP (expansion factor 4) mixes channels. Optional
    layer scale (``gamma``) and stochastic depth are applied before the
    residual addition. Note: unlike the reference 2-D block, no
    normalization layer is present here.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int,
        drop_path: float,
        layer_scale_init_value: float,
        activation: nn.Module,
    ):
        super().__init__()
        # groups=dim -> each channel convolved independently (depthwise).
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=kernel_size, padding="same", groups=dim
        )
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        # Accept either an activation class or an already-built instance.
        self.act = activation() if isinstance(activation, type) else activation
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        # (N, C, L) -> (N, L, C): the pointwise MLP acts on the channel dim.
        x = x.transpose(1, 2)
        x = self.pwconv2(self.act(self.pwconv1(x)))
        if self.gamma is not None:
            x = x * self.gamma
        # Back to (N, C, L) before the residual sum.
        x = x.transpose(1, 2)
        return residual + self.drop_path(x)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ConvNextBlock1DAdaptor(nn.Module):
    """Backbone stage: optional channel projection -> block -> downsampling.

    When the channel count changes, a pointwise Linear plus activation
    projects the input first. ``block_type == "convnext"`` inserts a
    :class:`ConvNeXtBlock1D`; any other value skips the block. When
    ``stride > 1`` an AvgPool1d downsamples the sequence at the end.

    ``dropout`` and ``use_batchnorm`` are accepted for config
    compatibility but are not used by this adaptor.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        dropout: float,
        use_batchnorm: bool,
        activation: nn.Module,
        layer_scale_init_value: float,
        drop_path_rate: float,
        block_type: str,
    ):
        super().__init__()
        self.pwconv = None
        if in_channels != out_channels:
            # Channel projection; activation may be a class or an instance.
            act = activation() if isinstance(activation, type) else activation
            self.pwconv = nn.Sequential(nn.Linear(in_channels, out_channels), act)

        self.block = None
        if block_type == "convnext":
            self.block = ConvNeXtBlock1D(
                dim=out_channels,
                kernel_size=kernel_size,
                drop_path=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                activation=activation,
            )

        self.reduction_pool = (
            nn.AvgPool1d(kernel_size=stride, stride=stride) if stride > 1 else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pwconv is not None:
            # Linear wants channels last: (N, C, L) -> (N, L, C) and back.
            x = self.pwconv(x.permute(0, 2, 1)).permute(0, 2, 1)
        if self.block is not None:
            x = self.block(x)
        if self.reduction_pool is not None:
            x = self.reduction_pool(x)
        return x
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
# MLP head builder
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
def make_mlp(
    input_dim: int,
    hidden_dims: Optional[Tuple[int, ...]],
    output_dim: int,
    dropout: float = 0.2,
    output_activation: Optional[nn.Module] = None,
) -> nn.Module:
    """Build a Linear/ReLU(/Dropout) MLP ending in a final Linear projection.

    Args:
        input_dim: size of the input feature vector.
        hidden_dims: hidden-layer widths; ``None`` or empty yields a single
            Linear layer.
        output_dim: size of the output.
        dropout: dropout probability appended after each hidden ReLU
            (skipped when falsy or <= 0).
        output_activation: optional module appended after the final Linear.

    Returns:
        An ``nn.Sequential`` containing the assembled layers.
    """
    modules: List[nn.Module] = []
    width = input_dim
    for hidden in hidden_dims or ():
        modules.append(nn.Linear(width, hidden))
        modules.append(nn.ReLU())
        if dropout and dropout > 0:
            modules.append(nn.Dropout(dropout))
        width = hidden
    modules.append(nn.Linear(width, output_dim))
    if output_activation is not None:
        modules.append(output_activation)
    return nn.Sequential(*modules)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
# Backbone
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
class MultiscaleCNNBackbone1D(nn.Module):
    """Stack of ConvNeXt adaptor stages with inter-stage pooling for 1D signals.

    The input is a single-channel sequence of length ``dim_in``; the output is
    either a global-average-pooled feature vector (``output_type == "gap"``)
    or the flattened final feature map (``output_type == "flatten"``). The
    resulting feature size is exposed as ``dim_output``.
    """

    def __init__(
        self,
        dim_in: int,
        channels: Tuple[int, ...],
        kernel_sizes: Tuple[int, ...],
        strides: Tuple[int, ...],
        dropout_rate: float,
        ramped_dropout_rate: bool,
        block_type: str,
        pooling_type: str,
        final_pool: bool,
        use_batchnorm: bool,
        activation: nn.Module,
        output_type: str,
        layer_scale_init_value: float,
        drop_path_rate: float,
    ):
        super().__init__()
        assert len(channels) == len(kernel_sizes) == len(strides)
        self.dim_in = dim_in
        self.output_type = output_type

        n_stages = len(channels)
        # Optionally ramp dropout linearly from 0 at the first stage up to
        # dropout_rate at the last; otherwise use a flat rate everywhere.
        if ramped_dropout_rate:
            stage_dropouts = torch.linspace(0.0, dropout_rate, steps=n_stages).tolist()
        else:
            stage_dropouts = [dropout_rate] * n_stages

        if pooling_type == "average":
            def make_pool() -> nn.Module:
                return nn.AvgPool1d(kernel_size=3, stride=2)
        elif pooling_type == "max":
            def make_pool() -> nn.Module:
                return nn.MaxPool1d(kernel_size=2, stride=2)
        else:
            raise ValueError(f"Invalid pooling_type '{pooling_type}'")

        stages: List[nn.Module] = []
        prev_ch = 1  # the raw pattern enters as a single channel
        for idx, (ch, ksize, stride) in enumerate(zip(channels, kernel_sizes, strides)):
            stages.append(
                ConvNextBlock1DAdaptor(
                    in_channels=prev_ch,
                    out_channels=ch,
                    kernel_size=ksize,
                    stride=stride,
                    dropout=stage_dropouts[idx],
                    use_batchnorm=use_batchnorm,
                    activation=activation,
                    layer_scale_init_value=layer_scale_init_value,
                    drop_path_rate=drop_path_rate,
                    block_type=block_type,
                )
            )
            # Pool between stages; after the last stage only when requested.
            if idx < n_stages - 1 or final_pool:
                stages.append(make_pool())
            prev_ch = ch

        self.net = nn.Sequential(*stages)

        if self.output_type == "gap":
            self.dim_output = channels[-1]
        elif self.output_type == "flatten":
            # Probe with a zero dummy input to discover the flattened size.
            with torch.no_grad():
                probe = self.net(torch.zeros(1, 1, self.dim_in))
            self.dim_output = int(probe.shape[1] * probe.shape[2])
        else:
            raise ValueError(f"Invalid output_type '{self.output_type}'")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, length) or (batch, 1, length) to (batch, dim_output)."""
        if x.ndim == 2:
            x = x.unsqueeze(1)
        x = self.net(x)
        if self.output_type == "gap":
            return x.mean(dim=-1)
        return x.reshape(x.shape[0], -1)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
+
# GEMD distance-matrix utilities
|
| 240 |
+
# ---------------------------------------------------------------------------
|
| 241 |
+
def _build_distance_matrix_from_maxsub_lut(
|
| 242 |
+
maxsub_lut: Dict[str, List[int]],
|
| 243 |
+
num_sg_classes: int,
|
| 244 |
+
) -> torch.Tensor:
|
| 245 |
+
adjacency: List[set] = [set() for _ in range(num_sg_classes)]
|
| 246 |
+
for key, neighbors in maxsub_lut.items():
|
| 247 |
+
src = int(key) - 1
|
| 248 |
+
for raw_dst in neighbors:
|
| 249 |
+
dst = int(raw_dst) - 1
|
| 250 |
+
adjacency[src].add(dst)
|
| 251 |
+
adjacency[dst].add(src)
|
| 252 |
+
|
| 253 |
+
distance_matrix = torch.zeros(
|
| 254 |
+
(num_sg_classes, num_sg_classes), dtype=torch.float32
|
| 255 |
+
)
|
| 256 |
+
for src in range(num_sg_classes):
|
| 257 |
+
dists = [-1] * num_sg_classes
|
| 258 |
+
dists[src] = 0
|
| 259 |
+
queue = deque([src])
|
| 260 |
+
while queue:
|
| 261 |
+
cur = queue.popleft()
|
| 262 |
+
for nxt in adjacency[cur]:
|
| 263 |
+
if dists[nxt] == -1:
|
| 264 |
+
dists[nxt] = dists[cur] + 1
|
| 265 |
+
queue.append(nxt)
|
| 266 |
+
distance_matrix[src] = torch.tensor(dists, dtype=torch.float32)
|
| 267 |
+
return distance_matrix
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def load_gemd_distance_matrix(
    path: str, num_sg_classes: int = 230
) -> torch.Tensor:
    """Load a GEMD space-group distance matrix from a JSON file.

    Accepts either a maxsub-style adjacency dict (all keys digit strings),
    from which the matrix is computed via BFS, or a pre-computed matrix
    stored as nested lists.

    Raises:
        ValueError: if the payload matches neither supported format.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload: Any = json.load(handle)

    # A list payload is already the matrix itself.
    if isinstance(payload, list):
        return torch.as_tensor(payload, dtype=torch.float32)
    if isinstance(payload, dict) and all(str(k).isdigit() for k in payload):
        return _build_distance_matrix_from_maxsub_lut(payload, num_sg_classes)
    raise ValueError(f"Could not parse GEMD data from {path}")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# ---------------------------------------------------------------------------
|
| 283 |
+
# Full model
|
| 284 |
+
# ---------------------------------------------------------------------------
|
| 285 |
+
class AlphaDiffract(nn.Module):
    """
    AlphaDiffract: multi-task 1D ConvNeXt for powder X-ray diffraction
    pattern analysis.

    Predicts crystal system (7 classes), space group (230 classes), and
    lattice parameters (6 values: a, b, c, alpha, beta, gamma).
    """

    CRYSTAL_SYSTEMS = [
        "Triclinic",
        "Monoclinic",
        "Orthorhombic",
        "Tetragonal",
        "Trigonal",
        "Hexagonal",
        "Cubic",
    ]

    def __init__(self, config: dict, maxsub_path: Optional[str] = None):
        """
        Args:
            config: dict with "backbone", "heads", and "tasks" sections.
            maxsub_path: optional path to the maxsub JSON used to build the
                GEMD space-group distance matrix.
        """
        super().__init__()
        bb = config["backbone"]
        heads = config["heads"]
        tasks = config["tasks"]

        activation = nn.GELU

        self.backbone = MultiscaleCNNBackbone1D(
            dim_in=bb["dim_in"],
            channels=tuple(bb["channels"]),
            kernel_sizes=tuple(bb["kernel_sizes"]),
            strides=tuple(bb["strides"]),
            dropout_rate=bb["dropout_rate"],
            ramped_dropout_rate=bb["ramped_dropout_rate"],
            block_type=bb["block_type"],
            pooling_type=bb["pooling_type"],
            final_pool=bb["final_pool"],
            use_batchnorm=bb["use_batchnorm"],
            activation=activation,
            output_type=bb["output_type"],
            layer_scale_init_value=bb["layer_scale_init_value"],
            drop_path_rate=bb["drop_path_rate"],
        )
        feat_dim = self.backbone.dim_output

        # One MLP head per task, all fed from the shared backbone features.
        self.cs_head = make_mlp(
            feat_dim, tuple(heads["cs_hidden"]), tasks["num_cs_classes"],
            dropout=heads["head_dropout"],
        )
        self.sg_head = make_mlp(
            feat_dim, tuple(heads["sg_hidden"]), tasks["num_sg_classes"],
            dropout=heads["head_dropout"],
        )
        self.lp_head = make_mlp(
            feat_dim, tuple(heads["lp_hidden"]), tasks["num_lp_outputs"],
            dropout=heads["head_dropout"],
        )

        # Lattice-parameter bounds as buffers so they follow the model across
        # devices and are persisted in the state dict.
        self.bound_lp_with_sigmoid = tasks["bound_lp_with_sigmoid"]
        self.register_buffer(
            "lp_min",
            torch.tensor(tasks["lp_bounds_min"], dtype=torch.float32),
        )
        self.register_buffer(
            "lp_max",
            torch.tensor(tasks["lp_bounds_max"], dtype=torch.float32),
        )

        if maxsub_path is not None:
            # Size the GEMD matrix to the configured number of space-group
            # classes (instead of relying on the hard-coded default of 230)
            # so it always matches the sg_head output dimension.
            gemd = load_gemd_distance_matrix(
                maxsub_path, num_sg_classes=tasks["num_sg_classes"]
            )
            self.register_buffer("gemd_distance_matrix", gemd)
        else:
            self.gemd_distance_matrix = None

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: PXRD pattern tensor of shape ``(batch, 8192)`` or
                ``(batch, 1, 8192)``, intensity-normalized to [0, 100].

        Returns:
            Dict with keys ``cs_logits``, ``sg_logits``, ``lp``.
        """
        feats = self.backbone(x)
        cs_logits = self.cs_head(feats)
        sg_logits = self.sg_head(feats)
        lp = self.lp_head(feats)
        if self.bound_lp_with_sigmoid:
            # Squash raw outputs into [lp_min, lp_max] per lattice parameter.
            lp = torch.sigmoid(lp) * (self.lp_max - self.lp_min) + self.lp_min
        return {"cs_logits": cs_logits, "sg_logits": sg_logits, "lp": lp}

    # -- convenience loaders ------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        device: str = "cpu",
    ) -> "AlphaDiffract":
        """Load model from a directory containing config.json,
        model.safetensors, and maxsub.json.

        Args:
            model_dir: directory holding the config, weights, and optional
                maxsub.json.
            device: torch device string the weights are loaded onto.

        Returns:
            The model in eval mode on ``device``.
        """
        model_dir = Path(model_dir)
        with open(model_dir / "config.json", "r") as f:
            config = json.load(f)

        maxsub_path = model_dir / "maxsub.json"
        model = cls(
            config,
            maxsub_path=str(maxsub_path) if maxsub_path.exists() else None,
        )

        weights_path = model_dir / "model.safetensors"
        if weights_path.exists():
            from safetensors.torch import load_file
            state_dict = load_file(str(weights_path), device=device)
        else:
            # Fallback to PyTorch format; weights_only=True avoids
            # arbitrary-code execution from untrusted checkpoints.
            pt_path = model_dir / "model.pt"
            state_dict = torch.load(
                str(pt_path), map_location=device, weights_only=True
            )

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d81f2716ef8c78683d52ac51afff3eaf160c0b2b410685d0c90299bc2fd58ed
|
| 3 |
+
size 35155332
|