n0w0f committed on
Commit
283f249
·
verified ·
1 Parent(s): 7949a14

Update README for v2: NL queries, 1024 ctx, LaCLIP architecture

Files changed (1)
  1. README.md +158 -138
README.md CHANGED
@@ -1,111 +1,139 @@
1
- # MatText Aligned Embeddings: Multi-Modal Material Retrieval
2
 
3
- **A CLIP-style multi-modal embedding model that aligns 10 different material text representations into a shared 128-d vector space for cross-modal retrieval.**
4
 
5
- Query with *any* modality (composition, CIF, SLICES, natural language, z-matrix...) → retrieve materials with similar properties across *all* modalities.
 
 
 
 
 
 
 
 
 
 
6
 
7
  ## 🏗️ Architecture
8
 
9
  ```
10
- ┌────────────────────────────────────────────────────────────┐
- │                       MatTextEncoder                        │
- │                                                            │
- │ ┌────────────────────────────────────────────────────┐     │
- │ │ Shared Backbone: ModernBERT-base (150M params)     │     │
- │ │ - 8192 token context window (handles long CIFs)    │     │
- │ │ - Mean pooling → 768-d representation              │     │
- │ └────────────────────────────────────────────────────┘     │
- │                           │                                │
- │          ┌────────────────┼────────────────┐               │
- │          ▼                ▼                ▼               │
- │   ┌──────────────┐ ┌──────────────┐ ┌──────────────┐       │
- │   │ Projection   │ │ Projection   │ │ Projection   │ ...   │
- │   │ composition  │ │ cif_sym      │ │ slices       │       │
- │   │ 768→768→128  │ │ 768→768→128  │ │ 768→768→128  │       │
- │   └──────┬───────┘ └──────┬───────┘ └──────┬───────┘       │
- │          ▼                ▼                ▼               │
- │    128-d L2-norm    128-d L2-norm    128-d L2-norm         │
- │                                                            │
- │              ──── Shared Embedding Space ────              │
- └────────────────────────────────────────────────────────────┘
31
  ```
32
 
33
- ### Key Design Decisions
34
-
35
- | Decision | Choice | Rationale |
36
- |----------|--------|-----------|
37
- | Backbone | ModernBERT-base | 8192 ctx handles long CIFs; fast RoPE attention |
38
- | Projection | 2-layer MLP per modality | MultiMat recipe: modality-specific heads preserve specialization |
39
- | Embedding dim | 128 | Standard for contrastive learning; compact for FAISS |
40
- | Loss | AllPairsCLIP + Property-MSE | Aligns all N(N-1)/2 modality pairs; property regularization |
41
- | Temperature | Learnable (init 0.07) | CLIP standard; learned τ improves convergence |
42
-
43
- ## 📊 Modalities Supported
44
-
45
- | Modality | Column | Example | Query Type |
46
- |----------|--------|---------|------------|
47
- | Composition | `composition` | `Fe2O3` | "Find iron oxides" |
48
- | Atom Sequence | `atom_sequences` | `Fe Fe Fe O O O` | Element lists |
49
- | CIF (symmetrized) | `cif_symmetrized` | Full CIF text | Paste CIF data |
50
- | CIF (P1) | `cif_p1` | Full CIF in P1 | Paste CIF data |
51
- | Z-matrix | `zmatrix` | `Fe\nO 1 2.0\nO 1 2.0 2 90` | Internal coords |
52
- | Atom Seq++ | `atom_sequences_plusplus` | `Fe O 3.57 3.57 90 90` | Elements + lattice |
53
- | SLICES | `slices` | `Fe O 0 1 o o o` | SLICES encoding |
54
- | Crystal Text (LLM) | `crystal_text_llm` | `3.6 3.6 3.6\n90 90 90\nFe...` | Gruver format |
55
- | Local Environment | `local_env` | SMILES-like env | Local bonding |
56
- | Natural Language | `robocrys_rep` | "FeO crystallizes in..." | Plain English |
57
- | **Property Query** | property text | "bandgap: 1.5 eV" | Property search |
58
 
59
- ## 🧪 Training Recipe
60
 
61
- Based on these key papers:
 
 
 
 
 
62
 
63
- 1. **MultiMat** (AllPairsCLIP, [arxiv:2312.00111](https://arxiv.org/abs/2312.00111)): Sum of symmetric InfoNCE over all modality pairs
64
- 2. **MatExpert** ([arxiv:2410.21317](https://arxiv.org/abs/2410.21317)): Property↔structure contrastive alignment
65
- 3. **CrystalCLR** ([arxiv:2211.13408](https://arxiv.org/abs/2211.13408)): Composition similarity loss
66
- 4. **SupReMix** ([arxiv:2309.16633](https://arxiv.org/abs/2309.16633)): Property-label-aware soft contrastive
 
67
 
68
  ### Two-Phase Training
69
 
70
- **Phase 1 — Multi-modal alignment** (pretrain100k_v2, 50k samples):
71
- - AllPairsCLIP loss across all 10 modalities
72
- - Random modality sampling (4/10 per step) for VRAM efficiency
73
- - Each step aligns C(4,2)=6 modality pairs
74
 
75
- **Phase 2 — Property-conditioned alignment** (bandgap + form_energy, 50k samples):
76
- - Same CLIP loss + property similarity MSE loss
77
- - Property text "composition: Fe2O3 | bandgap: 2.1000" aligned with structure representations
78
- - Materials with similar property values cluster in embedding space
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  ### Hyperparameters
81
 
82
- ```
83
  encoder: answerdotai/ModernBERT-base
84
  embed_dim: 128
85
- max_length: 512 tokens
86
- batch_size: 32 × 8 grad_accum = 256 effective
87
- learning_rate: 2e-5 (cosine decay, 10% warmup)
88
  temperature: learnable (init 0.07)
89
  epochs: 3 per phase
90
  optimizer: AdamW (weight_decay=0.01)
91
- fp16: True
92
  gradient_checkpointing: True
 
93
  ```
94
 
95
  ## 🚀 Quick Start
96
 
97
- ### Training
98
 
99
  ```bash
100
- pip install torch transformers datasets faiss-cpu huggingface_hub trackio
101
 
102
- # Local GPU
103
- python train_mattext_embeddings.py
104
 
105
- # HF Jobs (recommended: a10g-large, 24GB VRAM)
106
- # Set timeout to 6h
107
  ```
108
 
 
 
 
 
 
109
  ### Inference & Search
110
 
111
  ```python
@@ -113,119 +141,111 @@ import torch
113
  import faiss
114
  import json
115
  import numpy as np
116
- from transformers import AutoModel, AutoTokenizer
117
-
118
- # Load model
119
  from train_mattext_embeddings import MatTextEncoder, Config, search_vector_db
120
 
 
121
  config = Config()
122
  config.device = "cuda" if torch.cuda.is_available() else "cpu"
123
-
124
  model = MatTextEncoder(config)
125
  model.load_state_dict(torch.load("mattext-embeddings/model.pt", map_location=config.device))
126
- model = model.to(config.device)
127
- model.eval()
128
-
129
  tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
130
 
131
  # Load FAISS indices
132
  indices = {}
133
- for mod in ["composition", "crystal_text_llm", "slices", "cif_symmetrized"]:
134
  index = faiss.read_index(f"mattext-embeddings/faiss/{mod}.index")
135
  with open(f"mattext-embeddings/faiss/{mod}_metadata.json") as f:
136
  metadata = json.load(f)
137
  indices[mod] = {"index": index, "metadata": metadata}
138
-
139
- # Search!
140
- results = search_vector_db("Fe2O3", "composition", model, tokenizer, indices, config, k=5)
141
- for score, meta in results:
142
- print(f"Score: {score:.4f} | {meta['composition']}")
143
  ```
144
 
145
- ### Cross-Modal Query Examples
146
 
147
  ```python
148
- # Query by composition → find across all modalities
149
- search_vector_db("SiO2", "composition", model, tokenizer, indices, config)
 
 
 
150
 
151
- # Query by natural language → find materials
152
- search_vector_db("perovskite with high bandgap", "robocrys_rep", model, tokenizer, indices, config)
 
153
 
154
- # Query by SLICES representation
155
- search_vector_db("Si O 0 1 o o o", "slices", model, tokenizer, indices, config)
156
 
157
- # Query by CIF data
158
- search_vector_db("data_SiO2\n_symmetry P1\n...", "cif_symmetrized", model, tokenizer, indices, config)
159
 
160
- # Property-conditioned query
161
- search_vector_db("composition: Si | bandgap: 1.1200", "property", model, tokenizer, indices, config)
 
 
 
162
  ```
163
 
164
- ## 🔬 Evaluation Metrics
 
 
 
 
 
 
 
 
 
 
165
 
166
- Cross-modal Recall@k: for each material, embed in modality A, retrieve in modality B, check if correct match is in top-k.
167
 
168
- | Pair | R@1 | R@5 | R@10 |
169
- |------|-----|-----|------|
170
- | composition → crystal_text_llm | TBD | TBD | TBD |
- | composition → cif_symmetrized | TBD | TBD | TBD |
- | slices → crystal_text_llm | TBD | TBD | TBD |
- | robocrys_rep → composition | TBD | TBD | TBD |
174
 
175
  *Results populated after training.*
176
 
177
  ## 🧩 Extending: Graph Embeddings
178
 
179
- The architecture supports adding graph neural network (GNN) embeddings:
180
 
181
  ```python
182
- # Add a GNN projection head
183
- from torch_geometric.nn import SchNet, DimeNet # or CGCNN
184
 
185
  class GraphEncoder(nn.Module):
186
  def __init__(self, embed_dim=128):
187
  super().__init__()
188
- self.gnn = SchNet(hidden_channels=256, num_filters=128, num_interactions=6)
189
  self.proj = ModalityProjection(256, embed_dim)
190
 
191
  def forward(self, data):
192
- # data: PyG Data with pos, z (atomic numbers), batch
193
  h = self.gnn(data.z, data.pos, data.batch)
194
  return self.proj(h)
195
 
196
- # Add to MatTextEncoder:
197
- model.graph_encoder = GraphEncoder(config.embed_dim)
198
- model.projections["graph"] = model.graph_encoder.proj
199
-
200
- # Training: treat graph embeddings as another modality in AllPairsCLIP
201
  ```
202
 
203
- For graph embeddings, convert CIF → PyG Data (using `pymatgen` + `torch_geometric`):
204
- ```python
205
- from pymatgen.core import Structure
206
- from torch_geometric.data import Data
207
-
208
- def cif_to_graph(cif_string, cutoff=5.0):
209
- struct = Structure.from_str(cif_string, fmt="cif")
210
- # Get neighbors within cutoff
211
- neighbors = struct.get_all_neighbors(cutoff)
212
- # Build edge_index, pos, z ...
213
- return Data(z=atomic_numbers, pos=positions, edge_index=edge_index)
214
- ```
215
 
216
  ## 📚 References
217
 
218
- - **MatText**: [arxiv:2406.17295](https://arxiv.org/abs/2406.17295) — Dataset and text representations
- - **MultiMat**: [arxiv:2312.00111](https://arxiv.org/abs/2312.00111) — AllPairsCLIP for materials
- - **MatExpert**: [arxiv:2410.21317](https://arxiv.org/abs/2410.21317) — Property↔structure alignment
- - **CrystalCLR**: [arxiv:2211.13408](https://arxiv.org/abs/2211.13408) — Contrastive learning for crystals
- - **SupReMix**: [arxiv:2309.16633](https://arxiv.org/abs/2309.16633) — Property-aware hard negatives
- - **Symile**: [arxiv:2411.01053](https://arxiv.org/abs/2411.01053) — Total-correlation loss for M modalities
 
224
 
225
  ## 📄 License
226
 
227
  MIT
228
-
229
- ## 🔗 Dataset
-
- [n0w0f/MatText](https://huggingface.co/datasets/n0w0f/MatText) — 100k+ crystal structures in 10 text representations
 
1
+ # MatText Aligned Embeddings v2: Multi-Modal Material Retrieval with Natural Language Queries
2
 
3
+ **A CLIP-style multi-modal embedding model that aligns 10+ material text representations into a shared 128-d vector space. Query with natural language ("oxide with high bandgap"), composition, CIF, SLICES, or any modality → retrieve matching materials.**
4
 
5
+ ## 🆕 v2 Key Features
6
+
7
+ | Feature | v1 | v2 |
8
+ |---------|----|----|
9
+ | Context length | 512 tokens | **1024 tokens** (captures long CIFs) |
10
+ | Natural language queries | ❌ | **✅ "oxide with high bandgap"** |
11
+ | Property-aware retrieval | Basic | **LaCLIP-style diverse NL descriptions** |
12
+ | GPU optimization | fp16 / 24GB | **bf16 / 80GB A100 optimized** |
13
+ | Effective batch size | 256 | **288** |
14
+ | Modalities per step | 4 | **5** |
15
+ | Flash Attention 2 | ❌ | **✅ (auto-detect)** |
16
 
17
  ## 🏗️ Architecture
18
 
19
  ```
20
+ ┌──────────────────────────────────────────────────────────────────────┐
+ │                     MatTextEncoder (157M params)                     │
+ │                                                                      │
+ │ ┌────────────────────────────────────────────────────────────┐       │
+ │ │ Shared Backbone: ModernBERT-base (150M params, 8192 ctx)   │       │
+ │ │ Mean pooling → 768-d representation                        │       │
+ │ │ Gradient checkpointing + bf16                              │       │
+ │ └────────────────────────────────────────────────────────────┘       │
+ │                           │                                          │
+ │      ┌───────────┬────────┴────────┬────────────────┐                │
+ │      ▼           ▼                 ▼                ▼                │
+ │ ┌─────────┐ ┌──────────┐ ┌───────────────────┐ ┌──────────┐          │
+ │ │comp     │ │cif_sym   │ │nl_property_desc   │ │property  │ ...×12   │
+ │ │768→768  │ │768→768   │ │768→768→128        │ │768→768   │          │
+ │ │→128     │ │→128      │ │"oxide with high   │ │→128      │          │
+ │ │         │ │          │ │ bandgap" queries  │ │          │          │
+ │ └────┬────┘ └────┬─────┘ └─────────┬─────────┘ └────┬─────┘          │
+ │      ▼           ▼                 ▼                ▼                │
+ │   128-d L2     128-d L2         128-d L2         128-d L2            │
+ │                                                                      │
+ │                ──── Shared 128-d Embedding Space ────                │
+ │           (FAISS IndexFlatIP for cosine similarity search)           │
+ └──────────────────────────────────────────────────────────────────────┘
43
  ```
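+
+ Every projection head L2-normalizes its output, so inner product equals cosine similarity and a plain `IndexFlatIP` index per modality is enough. A minimal sketch of building one such index (random vectors stand in for real head outputs):
+
+ ```python
+ import faiss
+ import numpy as np
+
+ # Stand-in embeddings; in practice these come from a per-modality projection head.
+ embs = np.random.rand(1000, 128).astype(np.float32)
+ embs /= np.linalg.norm(embs, axis=1, keepdims=True)   # unit vectors -> inner product == cosine
+
+ index = faiss.IndexFlatIP(128)                        # exact inner-product index
+ index.add(embs)
+ scores, ids = index.search(embs[:1], 5)               # top-5 most similar materials
+ ```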
44
 
45
+ ### 12 Projection Heads
 
46
 
47
+ | # | Head | Input | Purpose |
48
+ |---|------|-------|---------|
49
+ | 1 | `composition` | "Fe2O3" | Formula queries |
50
+ | 2 | `atom_sequences` | "Fe Fe O O O" | Element list queries |
51
+ | 3 | `cif_symmetrized` | Full CIF | Paste CIF data |
52
+ | 4 | `cif_p1` | CIF in P1 | P1 space group CIF |
53
+ | 5 | `zmatrix` | Z-matrix coords | Internal coordinates |
54
+ | 6 | `atom_sequences_plusplus` | Elements + lattice | Atom sequence + cell |
55
+ | 7 | `slices` | SLICES encoding | Compact structure encoding |
56
+ | 8 | `crystal_text_llm` | Gruver format | Lattice + coords text |
57
+ | 9 | `local_env` | SMILES-like env | Local bonding environment |
58
+ | 10 | `robocrys_rep` | NL description | "FeO crystallizes in..." |
59
+ | 11 | **`nl_property_description`** | **Free-form NL** | **"oxide with high bandgap"** |
60
+ | 12 | `property` | Structured props | "bandgap: 2.1 eV" |
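+
+ For orientation, a minimal sketch of how the shared backbone and per-modality heads fit together (illustrative only; the real `MatTextEncoder` in `train_mattext_embeddings.py` may differ in details such as dropout or initialization):
+
+ ```python
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoModel
+
+ class ModalityProjection(nn.Module):
+     """2-layer MLP head: 768 -> 768 -> 128, L2-normalized."""
+     def __init__(self, in_dim=768, embed_dim=128):
+         super().__init__()
+         self.net = nn.Sequential(nn.Linear(in_dim, in_dim), nn.GELU(), nn.Linear(in_dim, embed_dim))
+
+     def forward(self, x):
+         return F.normalize(self.net(x), dim=-1)
+
+ class MatTextEncoderSketch(nn.Module):
+     """Shared ModernBERT backbone + one projection head per modality."""
+     def __init__(self, modalities, encoder_name="answerdotai/ModernBERT-base"):
+         super().__init__()
+         self.backbone = AutoModel.from_pretrained(encoder_name)
+         self.projections = nn.ModuleDict({m: ModalityProjection() for m in modalities})
+
+     def forward(self, input_ids, attention_mask, modality):
+         h = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+         mask = attention_mask.unsqueeze(-1).float()
+         pooled = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-6)   # mean pooling over tokens
+         return self.projections[modality](pooled)                  # 128-d, L2-normalized
+ ```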
61
+
62
+ ## 🔍 How NL Queries Work
63
+
64
+ The key innovation is a **LaCLIP-style** training approach ([arxiv:2305.20088](https://arxiv.org/abs/2305.20088)):
65
 
66
+ 1. **During Phase 2 training**, for each material with known properties (bandgap, formation energy), we generate **diverse natural language descriptions** from templates:
67
+ - `"A wide bandgap oxide suitable for UV applications, bandgap 3.20 eV"`
68
+ - `"TiO2: oxide semiconductor with wide band gap of 3.20 electron volts"`
69
+ - `"This binary oxide (TiO2) exhibits a wide bandgap of approximately 3.20 eV"`
70
+
71
+ 2. These NL descriptions are passed through a **dedicated `nl_property_description` projection head** and aligned with ALL structure modalities via InfoNCE.
72
 
73
+ 3. **At inference**, when you query `"oxide with high bandgap"`, the model maps it through the same NL head into the shared embedding space, and FAISS finds the nearest materials — those that were trained to be close to similar descriptions.
74
+
75
+ This is distinct from `robocrys_rep` (which describes crystal *structure*: "FeO crystallizes in the rock salt structure..."). The NL query head describes *properties* ("wide bandgap oxide").
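+
+ The description generator can be as simple as string templates over binned property values. A hypothetical sketch in the spirit of the examples above (the actual templates live in the training script and will differ):
+
+ ```python
+ import random
+
+ TEMPLATES = [
+     "A {gap_class} bandgap {family} suitable for {use}, bandgap {gap:.2f} eV",
+     "{formula}: {family} semiconductor with {gap_class} band gap of {gap:.2f} electron volts",
+     "This {family} ({formula}) exhibits a {gap_class} bandgap of approximately {gap:.2f} eV",
+ ]
+
+ def describe(formula: str, gap: float) -> str:
+     """Generate one diverse NL property description for a material."""
+     gap_class = "wide" if gap > 3.0 else "moderate" if gap > 1.0 else "narrow"
+     use = {"wide": "UV applications", "moderate": "visible-light absorption", "narrow": "IR detection"}[gap_class]
+     family = "oxide" if "O" in formula else "compound"   # crude proxy, for illustration only
+     return random.choice(TEMPLATES).format(formula=formula, gap=gap, gap_class=gap_class, family=family, use=use)
+
+ print(describe("TiO2", 3.20))   # e.g. "A wide bandgap oxide suitable for UV applications, bandgap 3.20 eV"
+ ```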
76
+
77
+ ## 🧪 Training Recipe
78
 
79
  ### Two-Phase Training
80
 
81
+ **Phase 1 — Multi-modal alignment** (pretrain100k_v2, 60k samples, 3 epochs):
+ - AllPairsCLIP loss across 10 modalities
+ - Random modality sampling (5/10 per step) — always includes composition + crystal_text_llm (see the sketch below)
+ - Effective batch size 288
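+
+ A rough sketch of the per-step modality sampling (the exact sampling logic in the training script may differ):
+
+ ```python
+ import random
+
+ ALL_MODALITIES = [
+     "composition", "atom_sequences", "cif_symmetrized", "cif_p1", "zmatrix",
+     "atom_sequences_plusplus", "slices", "crystal_text_llm", "local_env", "robocrys_rep",
+ ]
+ ALWAYS = ["composition", "crystal_text_llm"]
+
+ def sample_step_modalities(k=5):
+     rest = [m for m in ALL_MODALITIES if m not in ALWAYS]
+     return ALWAYS + random.sample(rest, k - len(ALWAYS))   # 5 modalities -> C(5,2) = 10 CLIP pairs per step
+ ```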
85
 
86
+ **Phase 2 — Property-conditioned + NL query alignment** (bandgap + formation_energy, 60k samples, 3 epochs):
87
+ - AllPairsCLIP loss (structure modalities)
88
+ - **NL description ↔ structure InfoNCE** (the key NL query loss)
89
+ - Property ↔ composition/crystal_text_llm InfoNCE ([MatExpert](https://arxiv.org/abs/2410.21317))
90
+ - SupReMix-style property similarity MSE ([arxiv:2309.16633](https://arxiv.org/abs/2309.16633))
91
+ - Loss weights: `L = L_clip + 0.3 * L_property + 0.5 * L_nl`
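+
+ Putting the pieces together, a minimal sketch of the Phase 2 objective (symmetric InfoNCE terms combined with the weights above; the SupReMix property-MSE term and exact reductions are omitted here and may differ in the actual script):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def info_nce(za, zb, temperature):
+     """Symmetric InfoNCE over a batch of matched, L2-normalized embeddings."""
+     logits = za @ zb.t() / temperature
+     targets = torch.arange(za.size(0), device=za.device)   # positives sit on the diagonal
+     return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
+
+ def phase2_loss(struct_embs, nl_emb, prop_emb, temperature):
+     """struct_embs: dict modality -> [B, 128]; nl_emb, prop_emb: [B, 128]."""
+     mods = list(struct_embs)
+     l_clip = sum(info_nce(struct_embs[a], struct_embs[b], temperature)        # AllPairsCLIP
+                  for i, a in enumerate(mods) for b in mods[i + 1:])
+     anchors = [m for m in ("composition", "crystal_text_llm") if m in struct_embs]
+     l_property = sum(info_nce(prop_emb, struct_embs[m], temperature) for m in anchors)
+     l_nl = sum(info_nce(nl_emb, struct_embs[m], temperature) for m in mods)   # NL <-> every structure modality
+     return l_clip + 0.3 * l_property + 0.5 * l_nl
+ ```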
92
+
93
+ ### Based On
94
+
95
+ | Paper | Contribution | ArXiv |
96
+ |-------|-------------|-------|
97
+ | **MultiMat** | AllPairsCLIP loss | [2312.00111](https://arxiv.org/abs/2312.00111) |
98
+ | **MatExpert** | Property↔structure InfoNCE | [2410.21317](https://arxiv.org/abs/2410.21317) |
99
+ | **LaCLIP** | LLM text augmentation for CLIP | [2305.20088](https://arxiv.org/abs/2305.20088) |
100
+ | **SupReMix** | Property-label-aware soft contrastive | [2309.16633](https://arxiv.org/abs/2309.16633) |
101
+ | **CrystalCLR** | Composition similarity | [2211.13408](https://arxiv.org/abs/2211.13408) |
102
 
103
  ### Hyperparameters
104
 
105
+ ```yaml
106
  encoder: answerdotai/ModernBERT-base
107
  embed_dim: 128
108
+ max_length: 1024 tokens
109
+ batch_size: 48 × 6 grad_accum = 288 effective
110
+ learning_rate: 2e-5 (phase 1), 1e-5 (phase 2)
111
  temperature: learnable (init 0.07)
112
  epochs: 3 per phase
113
  optimizer: AdamW (weight_decay=0.01)
114
+ precision: bf16 (A100) / fp16 (T4/V100)
115
  gradient_checkpointing: True
116
+ max_modalities_per_step: 5
117
  ```
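+
+ The learnable temperature is usually parameterized on a log scale for stability. A small sketch of how the config above could translate to PyTorch (a stand-in linear layer replaces the full encoder; the actual script may parameterize τ differently):
+
+ ```python
+ import math
+ import torch
+ import torch.nn as nn
+
+ logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / 0.07)))   # learnable temperature, init 0.07
+ encoder = nn.Linear(768, 128)                                     # stand-in for MatTextEncoder
+ optimizer = torch.optim.AdamW(
+     list(encoder.parameters()) + [logit_scale], lr=2e-5, weight_decay=0.01
+ )
+ temperature = 1.0 / logit_scale.exp()                             # plug into the InfoNCE logits
+ ```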
118
 
119
  ## 🚀 Quick Start
120
 
121
+ ### Training (your GPU)
122
 
123
  ```bash
124
+ pip install torch transformers datasets faiss-cpu huggingface_hub trackio accelerate
125
 
126
+ # Optional but recommended for A100/H100:
127
+ pip install flash-attn --no-build-isolation
128
 
129
+ python train_mattext_embeddings.py
 
130
  ```
131
 
132
+ The script auto-detects:
133
+ - GPU capability (bf16 for Ampere+, fp16 otherwise)
134
+ - Flash Attention 2 availability
135
+ - CUDA vs CPU
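+
+ Roughly, that detection boils down to a few standard checks (a sketch; the actual logic in the script may differ):
+
+ ```python
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()   # Ampere+ -> bf16, else fp16
+
+ try:
+     import flash_attn  # noqa: F401
+     attn_implementation = "flash_attention_2"
+ except ImportError:
+     attn_implementation = "sdpa"                                  # PyTorch fallback
+
+ print(device, "bf16" if use_bf16 else "fp16", attn_implementation)
+ ```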
136
+
137
  ### Inference & Search
138
 
139
  ```python
 
141
  import faiss
142
  import json
143
  import numpy as np
144
+ from transformers import AutoTokenizer
 
 
145
  from train_mattext_embeddings import MatTextEncoder, Config, search_vector_db
146
 
147
+ # Load
148
  config = Config()
149
  config.device = "cuda" if torch.cuda.is_available() else "cpu"
 
150
  model = MatTextEncoder(config)
151
  model.load_state_dict(torch.load("mattext-embeddings/model.pt", map_location=config.device))
152
+ model = model.to(config.device).eval()
 
 
153
  tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
154
 
155
  # Load FAISS indices
156
  indices = {}
157
+ for mod in ["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"]:
      index = faiss.read_index(f"mattext-embeddings/faiss/{mod}.index")
      with open(f"mattext-embeddings/faiss/{mod}_metadata.json") as f:
          metadata = json.load(f)
      indices[mod] = {"index": index, "metadata": metadata}
 
 
 
 
 
162
  ```
163
 
164
+ ### Query Examples
165
 
166
  ```python
167
+ # 🔍 Natural language property queries (THE KEY FEATURE)
168
+ search_vector_db("oxide with high bandgap", "nl_property_description", model, tokenizer, indices, config)
169
+ search_vector_db("stable ternary nitride", "nl_property_description", model, tokenizer, indices, config)
170
+ search_vector_db("narrow bandgap semiconductor for IR", "nl_property_description", model, tokenizer, indices, config)
171
+ search_vector_db("metallic binary compound", "nl_property_description", model, tokenizer, indices, config)
172
 
173
+ # 🧪 Composition queries
174
+ search_vector_db("Fe2O3", "composition", model, tokenizer, indices, config)
175
+ search_vector_db("BaTiO3", "composition", model, tokenizer, indices, config)
176
 
177
+ # 📖 Structure description queries
178
+ search_vector_db("perovskite with octahedral coordination", "robocrys_rep", model, tokenizer, indices, config)
179
 
180
+ # 📊 Structured property queries
181
+ search_vector_db("composition: TiO2 | bandgap: 3.2000", "property", model, tokenizer, indices, config)
182
 
183
+ # 🔬 CIF queries (paste your CIF)
184
+ search_vector_db("data_TiO2\n_symmetry P1\n_cell 4.59 4.59 2.96 90 90 90", "cif_symmetrized", ...)
185
+
186
+ # 🧬 SLICES queries
187
+ search_vector_db("Ti O 0 1 o o o", "slices", model, tokenizer, indices, config)
188
  ```
189
 
190
+ ## 📊 Evaluation Metrics
191
+
192
+ Cross-modal Recall@k on test set:
193
+
194
+ | Pair | R@1 | R@5 | R@10 | R@20 |
195
+ |------|-----|-----|------|------|
196
+ | composition → crystal_text_llm | TBD | TBD | TBD | TBD |
+ | composition → cif_symmetrized | TBD | TBD | TBD | TBD |
+ | composition → slices | TBD | TBD | TBD | TBD |
+ | slices → crystal_text_llm | TBD | TBD | TBD | TBD |
+ | robocrys_rep → composition | TBD | TBD | TBD | TBD |
201
 
202
+ NL Query Results:
203
 
204
+ | Query | Top-1 Match | Score |
205
+ |-------|------------|-------|
206
+ | "oxide with high bandgap" | TBD | TBD |
207
+ | "narrow bandgap semiconductor" | TBD | TBD |
208
+ | "stable binary oxide" | TBD | TBD |
 
209
 
210
  *Results populated after training.*
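+
+ Recall@k is computed by embedding each test material in modality A, searching the FAISS index built on modality B, and checking whether the matching material appears in the top k. A sketch, assuming queries and the index cover the same materials in the same order:
+
+ ```python
+ import numpy as np
+
+ def recall_at_k(query_embs, target_index, k=5):
+     """query_embs: [N, 128] L2-normalized array; target_index: FAISS index over the same N materials."""
+     _, nn_ids = target_index.search(query_embs.astype(np.float32), k)
+     gold = np.arange(len(query_embs))
+     return float((nn_ids == gold[:, None]).any(axis=1).mean())
+ ```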
211
 
212
  ## 🧩 Extending: Graph Embeddings
213
 
214
+ The architecture is plug-and-play for new modalities:
215
 
216
  ```python
217
+ # Add a GNN modality (assumes ModalityProjection is defined in train_mattext_embeddings.py)
+ import torch.nn as nn
+ from torch_geometric.nn import SchNet
+ from train_mattext_embeddings import ModalityProjection
219
 
220
  class GraphEncoder(nn.Module):
      def __init__(self, embed_dim=128):
          super().__init__()
+         self.gnn = SchNet(hidden_channels=256)
          self.proj = ModalityProjection(256, embed_dim)

      def forward(self, data):
          h = self.gnn(data.z, data.pos, data.batch)
          return self.proj(h)
229
 
230
+ # Register as a new modality
+ graph_encoder = GraphEncoder(embed_dim=128)
+ model.projections["graph"] = graph_encoder.proj
+ # It gets aligned automatically through AllPairsCLIP
 
 
233
  ```
234
 
235
+ ## 📦 Dataset
+
+ [n0w0f/MatText](https://huggingface.co/datasets/n0w0f/MatText) — 100k+ crystal structures in 10+ text representations
 
 
 
 
 
 
 
 
 
238
 
239
  ## 📚 References
240
 
241
+ - **MatText**: [arxiv:2406.17295](https://arxiv.org/abs/2406.17295)
242
+ - **MultiMat**: [arxiv:2312.00111](https://arxiv.org/abs/2312.00111)
243
+ - **MatExpert**: [arxiv:2410.21317](https://arxiv.org/abs/2410.21317)
244
+ - **LaCLIP**: [arxiv:2305.20088](https://arxiv.org/abs/2305.20088)
245
+ - **SupReMix**: [arxiv:2309.16633](https://arxiv.org/abs/2309.16633)
246
+ - **CrystalCLR**: [arxiv:2211.13408](https://arxiv.org/abs/2211.13408)
247
+ - **Symile**: [arxiv:2411.01053](https://arxiv.org/abs/2411.01053)
248
 
249
  ## 📄 License
250
 
251
  MIT