AstroBERT Small: Domain-specialized small models
Small models trained on ArXiv abstracts categorized as astro-ph and Wikipedia articles labeled as astronomy-related.
These domain-specialized small models often perform as well as models 10-100x larger. This demonstrates that narrowing a model down to a specific domain requires fewer parameters than building a model for all problems.
The following new models are released as part of this effort. All models have an Apache 2.0 license.
| Model | Description |
|---|---|
| AstroBERT Small | Base 22.7M parameter language model |
| AstroBERT Small Embeddings | Small Sentence Transformers model for embeddings |
Building a Strong Baseline
The first step was to train a 22.7M-parameter, encoder-only BERT model on ArXiv abstracts categorized as astro-ph and Wikipedia articles labeled as astronomy-related.
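As a sanity check on the 22.7M figure, the parameter count of a BERT encoder can be worked out from its configuration. The sketch below uses a hypothetical helper (`bert_parameter_count` is not part of any library) and assumes the `BertConfig` defaults the training script relies on: a 30,522-token vocabulary, 512 positions and 2 token types. With the script's hidden size of 384, 6 layers and an intermediate size of 1536, it reproduces 22.7M. This is also the shape of all-MiniLM-L6-v2; the number of attention heads does not change the count.

```python
def bert_parameter_count(vocab_size=30522, hidden=384, layers=6, intermediate=1536,
                         max_positions=512, type_vocab=2, pooler=True):
    """Count the weights, biases and LayerNorm parameters of a BERT encoder."""
    layer_norm = 2 * hidden  # gamma and beta

    # Token, position and segment embeddings plus the embedding LayerNorm
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm

    # Q, K, V and output projections (hidden x hidden with bias) plus LayerNorm
    attention = 4 * (hidden * hidden + hidden) + layer_norm

    # Feed-forward up and down projections plus LayerNorm
    feed_forward = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden) + layer_norm

    total = embeddings + layers * (attention + feed_forward)

    # Optional pooler: one hidden x hidden dense layer with bias
    if pooler:
        total += hidden * hidden + hidden

    return total

print(bert_parameter_count())              # 22713216
print(bert_parameter_count(pooler=False))  # 22565376
```

If we read the transformers source correctly, `BertForMaskedLM` creates its encoder without the pooler layer, so the `Total parameters` line printed by the training script would match the `pooler=False` figure of about 22.6M.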
The model was trained using masked language modeling with the following code.
```python
import csv

from datasets import concatenate_datasets, load_dataset, load_from_disk
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM
from txtai.pipeline import HFTrainer

def loadids(domain):
    """
    Load ids of Wikipedia articles labeled with domain.

    Returns (ids in domain, total number of labeled articles).
    """
    labels, count = set(), 0
    with open("/data/sources/wikipedia/labels/labels.csv", mode="r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row["label"] == domain:
                labels.add(row["id"])

            # Count every labeled article, not only those in the target domain
            count += 1

    return labels, count

# Load target domain
domain = "astronomy"
uids, total = loadids(domain)

# Filter by domain labels
dataset = load_dataset("neuml/wikipedia-20260401", split="train")
dataset = dataset.filter(lambda x: x["title"] in uids)
dataset = dataset.remove_columns([col for col in dataset.column_names if col != "text"])

# Add arxiv abstracts
dataset = concatenate_datasets([dataset, load_from_disk(f"datasets/arxiv-{domain}")])

# Calculate number of epochs based on size of filtered dataset: the smaller the
# domain dataset is relative to all labeled articles, the more epochs it gets
epochs = 3 * int(total / len(dataset))
print(f"Calculated {epochs} epochs")

# Standard tokenizer (replaced by the domain-specific tokenizer below)
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Domain-specific tokenizer
tokenizer = AutoTokenizer.from_pretrained(f"tokenizers/{domain}")

# Configuration - bert small (vocab_size is left at the BertConfig default of 30,522)
config = BertConfig(
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1536
)

# Model to train
model = BertForMaskedLM(config)
print(config)
print("Total parameters:", sum(p.numel() for p in model.bert.parameters()))

train = HFTrainer()

#
# Train using MLM
#
# Settings copied from original BERT training - override when HF Trainer defaults don't match
# - BERT Paper (pg. 13): https://arxiv.org/pdf/1810.04805
# - BERT Tiny Paper: https://arxiv.org/pdf/1908.08962
# - BERT Parameters: https://github.com/google-research/bert/blob/master/optimization.py#L59
# - HF Trainer defaults: https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments
train((model, tokenizer), dataset, task="language-modeling", output_dir="output",
      fp16=True, learning_rate=1e-3, per_device_train_batch_size=64, num_train_epochs=epochs,
      warmup_steps=2500, weight_decay=0.01, adam_epsilon=1e-6,
      tokenizers=True, dataloader_num_workers=20,
      save_strategy="steps", save_steps=5000, logging_steps=500,
)
```
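The training call above passes `task="language-modeling"` to txtai's `HFTrainer`, which handles token masking internally (we believe through Hugging Face's `DataCollatorForLanguageModeling`). To make the objective concrete, here is a minimal NumPy sketch of the masking scheme from the BERT paper: about 15% of non-special tokens become prediction targets; of those, 80% are replaced with `[MASK]`, 10% with a random token and 10% are left unchanged. `mask_tokens` is a hypothetical helper, and the token ids in the example come from the `bert-base-uncased` vocabulary; the domain tokenizer's ids may differ.

```python
import numpy as np

def mask_tokens(input_ids, vocab_size, mask_id, special_ids, mlm_probability=0.15, seed=None):
    """BERT-style masking. Returns (masked inputs, labels); -100 marks positions the loss ignores."""
    rng = np.random.default_rng(seed)
    inputs = np.array(input_ids, dtype=np.int64)
    labels = np.full_like(inputs, -100)

    # Choose ~mlm_probability of the non-special tokens as prediction targets
    eligible = ~np.isin(inputs, special_ids)
    selected = eligible & (rng.random(inputs.shape) < mlm_probability)
    labels[selected] = inputs[selected]

    # 80% of the targets are replaced with the [MASK] token
    masked = selected & (rng.random(inputs.shape) < 0.8)
    inputs[masked] = mask_id

    # Half of the remaining 20% (10% overall) get a random token; the rest stay unchanged
    randomized = selected & ~masked & (rng.random(inputs.shape) < 0.5)
    inputs[randomized] = rng.integers(0, vocab_size, size=int(randomized.sum()))

    return inputs, labels

# Hypothetical ids from the bert-base-uncased vocabulary: 0 = [PAD], 101 = [CLS], 102 = [SEP], 103 = [MASK]
ids = [[101] + list(range(2000, 2030)) + [102]]
inputs, labels = mask_tokens(ids, vocab_size=30522, mask_id=103, special_ids=[0, 101, 102, 103], seed=0)
print(inputs)
print(labels)
```

The model only receives a loss on the selected positions. Leaving some targets unchanged or randomized keeps it from assuming that every `[MASK]` token is the only thing worth predicting.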
The model is intended to be further fine-tuned for downstream tasks such as text classification, entity extraction and sentence embeddings.
Training a Small Embeddings model
Next, a Sentence Transformers model was fine-tuned to generate vector embeddings. The training dataset was generated from a random sample of ArXiv abstracts labeled as astro-ph.
The model was trained by distilling embeddings from the larger Qwen3-Embedding-8B model using EmbedDistillLoss over the generated training dataset.
As noted in the paper Well-Read Students Learn Better: On the Importance of Pre-training Compact Models, it's important that the base model is pretrained on a large corpus of relevant documents prior to distillation.
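This article doesn't spell out the formula behind `EmbedDistillLoss`. As a rough mental model, and as an assumption rather than a description of the library, `distance_metric="cosine"` with `projection_dim=4096` suggests that the 384-dimensional student output is mapped by a learned linear projection into the 4096-dimensional space of Qwen3-Embedding-8B and compared with the teacher embedding using cosine distance. The NumPy sketch below computes such a loss for the label layout built by the `compute` function in the training code, where each row stores the teacher's query and document embeddings stacked along axis 1. `distill_loss` and the projection matrix are hypothetical.

```python
import numpy as np

def cosine_distance(a, b):
    """Row-wise 1 - cosine similarity."""
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return 1.0 - np.sum(a * b, axis=-1)

def distill_loss(student_query, student_document, labels, projection):
    """
    student_query, student_document: (batch, student_dim) student embeddings
    labels: (batch, 2, teacher_dim) teacher embeddings, [:, 0] = query, [:, 1] = document
    projection: (student_dim, teacher_dim) learned projection (hypothetical)
    """
    query_loss = cosine_distance(student_query @ projection, labels[:, 0])
    document_loss = cosine_distance(student_document @ projection, labels[:, 1])

    # Average the query and document distances over the batch
    return float(np.mean(query_loss + document_loss) / 2)

# Random example with the dimensions used in this article
rng = np.random.default_rng(0)
batch, student_dim, teacher_dim = 4, 384, 4096
projection = rng.normal(scale=student_dim ** -0.5, size=(student_dim, teacher_dim))
loss = distill_loss(
    rng.normal(size=(batch, student_dim)),
    rng.normal(size=(batch, student_dim)),
    rng.normal(size=(batch, 2, teacher_dim)),
    projection,
)
print(loss)
```

The loss is 0 when each projected student embedding points in the same direction as its teacher embedding and 2 when it points the opposite way. In a setup like this, the projection would only be needed during training; inference would use the 384-dimensional student embeddings directly.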
The training code is shown below.
```python
import json
import logging

import numpy as np
from datasets import Dataset, load_from_disk, concatenate_datasets
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    models,
    losses
)

logging.basicConfig(level=logging.INFO)

def load(path):
    rows = []
    with open(path, "r", encoding="utf-8") as inputs:
        for line in inputs:
            # Each line is a JSON array: [query, document]
            row = json.loads(line)
            rows.append({
                "query": row[0],
                "document": row[1],
            })

    return rows

def run(domain):
    # Training prompts
    prompts = {
        "query": "query: ",
        "document": "document: "
    }

    # Embeddings model
    embeddings = models.Transformer("astrobert-small")

    # Pooling model
    pooling = models.Pooling(embeddings.get_embedding_dimension())

    # Create sentence-transformers model
    model = SentenceTransformer(modules=[embeddings, pooling], prompts=prompts)

    # Teacher model
    teacher = SentenceTransformer("Qwen/Qwen3-Embedding-8B")

    # Load training data
    train = load(f"training/{domain}-similarity-train.jsonl") + load(f"training/{domain}-similarity-train-questions.jsonl")
    train = Dataset.from_list(train)

    def compute(batch):
        # Teacher embeddings for each (query, document) pair, stored in the label column
        embed1 = teacher.encode(batch["query"], prompt_name="query", show_progress_bar=False, batch_size=8)
        embed2 = teacher.encode(batch["document"], prompt_name="document", show_progress_bar=False, batch_size=8)
        return {"label": np.stack([embed1, embed2], axis=1).tolist()}

    # Build separate shards
    shards = 10
    for x in range(shards):
        shard = train.shard(num_shards=shards, index=x)
        shard = shard.map(
            compute,
            batched=True,
            batch_size=1_000,
            writer_batch_size=1_000,
            new_fingerprint=f"embeddings-{x}"
        )
        shard.save_to_disk(f"training/{domain}-embeddings/shard-{x}")

    train = concatenate_datasets([
        load_from_disk(f"training/{domain}-embeddings/shard-{x}")
        for x in range(shards)
    ])

    path = f"{domain}bert-embeddings"
    args = SentenceTransformerTrainingArguments(
        output_dir=path,
        num_train_epochs=25,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        gradient_accumulation_steps=2,
        fp16=True,
        learning_rate=3e-4,
        save_steps=0,
        logging_steps=500,
        dataloader_num_workers=20,
        prompts={
            "query": prompts["query"],
            "document": prompts["document"],
        }
    )

    # Create the trainer & start training
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train,
        loss=losses.EmbedDistillLoss(model, distance_metric="cosine", projection_dim=4096),
    )
    trainer.train()

    model.save(path)

# Train model
run("astronomy")
```
Evaluation Results
A BEIR-compatible dataset was generated to facilitate the evaluation process. This is a separate random sample of ArXiv abstracts alongside generated user queries.
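The article doesn't show how the evaluation dataset was written. For reference, here is a minimal sketch of the standard BEIR file layout as we understand it (the layout BEIR's `GenericDataLoader` reads): `corpus.jsonl`, `queries.jsonl` and a tab-separated `qrels/<split>.tsv` file. `write_beir` is a hypothetical helper, and the sample record and query are invented for illustration.

```python
import csv
import json
import os
import tempfile

def write_beir(path, documents, queries, qrels, split="test"):
    """
    Write a dataset in the BEIR file layout.

    documents: {doc_id: (title, text)}
    queries: {query_id: text}
    qrels: {query_id: {doc_id: relevance}}
    """
    os.makedirs(os.path.join(path, "qrels"), exist_ok=True)

    # Corpus: one JSON object per line with _id, title and text
    with open(os.path.join(path, "corpus.jsonl"), "w", encoding="utf-8") as f:
        for uid, (title, text) in documents.items():
            f.write(json.dumps({"_id": uid, "title": title, "text": text}) + "\n")

    # Queries: one JSON object per line with _id and text
    with open(os.path.join(path, "queries.jsonl"), "w", encoding="utf-8") as f:
        for uid, text in queries.items():
            f.write(json.dumps({"_id": uid, "text": text}) + "\n")

    # Relevance judgments: tab-separated query-id, corpus-id, score with a header row
    with open(os.path.join(path, "qrels", f"{split}.tsv"), "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow(["query-id", "corpus-id", "score"])
        for qid, judgments in qrels.items():
            for uid, score in judgments.items():
                writer.writerow([qid, uid, score])

# Hypothetical example: one abstract and one generated query
path = tempfile.mkdtemp()
write_beir(
    path,
    documents={"doc1": ("Hypothetical abstract title", "Hypothetical abstract text about exoplanet transits.")},
    queries={"q1": "How are exoplanets detected with the transit method?"},
    qrels={"q1": {"doc1": 1}},
)
print(sorted(os.listdir(path)))  # ['corpus.jsonl', 'qrels', 'queries.jsonl']
```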
Evaluation results are shown below. NDCG is used as the evaluation metric.
| Model | Parameters | NDCG | Index Time | Search Time | Disk |
|---|---|---|---|---|---|
| AstroBERT Small Embeddings | 22.7M | 69.09 | 9.9s | 0.42s | 16 MB |
| all-MiniLM-L6-v2 | 22.7M | 40.45 | 12.50s | 0.38s | 16 MB |
| DenseOn | 149M | 61.46 | 67.35s | 0.77s | 31 MB |
| EmbeddingGemma | 300M | 57.44 | 86.17s | 1.43s | 31 MB |
| Qwen3-Embedding-0.6B | 600M | 65.73 | 114.17s | 2.20s | 41 MB |
| Qwen3-Embedding-4B | 4000M | 71.14 | 545.28s | 9.89s | 103 MB |
| Qwen3-Embedding-8B | 8000M | 73.84 | 941.82s | 17.24s | 164 MB |
This model is a solid performer for its size. It beats the same-sized all-MiniLM-L6-v2 model by nearly 29 NDCG points (69.09 vs 40.45), and it also outscores DenseOn (149M) and EmbeddingGemma (300M). It beats the 600M-parameter Qwen3-Embedding-0.6B model, which is over 25x larger. It scores 4.75 points below Qwen3-Embedding-8B, the model it's distilled from, which is over 350x larger.
This model is a good fit for CPU-only setups without trading off much accuracy. It shows how small models can excel in specialized domains while requiring far less compute and disk space: in this evaluation, it built its index in 9.9s, compared with 941.82s for Qwen3-Embedding-8B.
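NDCG (normalized discounted cumulative gain) rewards placing relevant documents near the top of the results. The article doesn't state the rank cutoff. BEIR-style evaluations commonly report NDCG@10, so `k` is left as a parameter here. The following is a minimal sketch with linear relevance gains, which to our knowledge matches the trec_eval definition that BEIR uses through pytrec_eval. The function names are hypothetical. The table values appear to be NDCG scaled by 100, so `mean_ndcg` does the same.

```python
import math

def dcg(gains):
    """Discounted cumulative gain: linear gains with a log2(rank + 1) discount."""
    return sum(gain / math.log2(rank + 1) for rank, gain in enumerate(gains, start=1))

def ndcg_at_k(ranking, judgments, k=10):
    """
    ranking: document ids in the order returned by the search
    judgments: {doc_id: relevance} for one query
    """
    gains = [judgments.get(uid, 0) for uid in ranking[:k]]

    # Ideal DCG: all judged documents sorted by relevance
    ideal = dcg(sorted(judgments.values(), reverse=True)[:k])
    return dcg(gains) / ideal if ideal > 0 else 0.0

def mean_ndcg(rankings, qrels, k=10):
    """Average NDCG@k over all judged queries, scaled by 100."""
    scores = [ndcg_at_k(rankings.get(qid, []), judgments, k) for qid, judgments in qrels.items()]
    return 100 * sum(scores) / len(scores)

# Hypothetical example: q1 ranks its relevant document first, q2 ranks it second
qrels = {"q1": {"d1": 1}, "q2": {"d2": 1}}
rankings = {"q1": ["d1", "d3"], "q2": ["d3", "d2"]}
print(round(mean_ndcg(rankings, qrels), 2))  # 81.55
```

Dropping the relevant document from first to second place cuts that query's score from 1.0 to about 0.63. This is why NDCG is sensitive to ranking quality and not just to whether a relevant document was retrieved.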
Wrapping up
This article introduced the new AstroBERT Small series of models. They demonstrate that narrowing a model down to a specific domain requires fewer parameters than building a model for all problems.
If you're interested in building custom models like this for your data or domain area, feel free to reach out!
NeuML is the company behind txtai and we provide AI consulting services around our stack. Schedule a meeting or send a message to learn more.
We're also building an easy and secure way to run hosted txtai applications with txtai.cloud.

