---
datasets:
- mlfoundations/dclm-baseline-1.0
base_model:
- HuggingFaceTB/SmolLM3-3B-Base
---
# Essence 3B V1.1
This is a system of two fine-tuned copies of SmolLM3-3B-Base: the 'encoder' is finetuned to turn a text into a set of embedding tokens, which the 'decoder' can reconstitute back into the original text. In addition to “vanilla” reconstruction, the model was trained on span-corruption and masked-language-modelling objectives.
We use LoRA at rank 64 on the QKVO projections along with trainable LayerNorms; the encoder additionally gets LoRA on all MLP layers and trainable token embeddings.
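For orientation, the adapter setup described above would look roughly like this in PEFT. This is a hedged sketch, not the shipped configuration; the module names assume SmolLM3's Llama-style layer naming:

```python
from peft import LoraConfig

# hypothetical reconstruction of the adapters described above (not the shipped config)
encoder_lora = LoraConfig(
    r=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention (QKVO)
        "gate_proj", "up_proj", "down_proj",     # encoder-only: all MLP layers
    ],
    # fully trained (non-LoRA) modules: LayerNorms and token embeddings
    modules_to_save=["input_layernorm", "post_attention_layernorm", "embed_tokens"],
)

decoder_lora = LoraConfig(
    r=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # QKVO only
    modules_to_save=["input_layernorm", "post_attention_layernorm"],
)
```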
The model was trained to encode text into anywhere from 1 to 128 embedding tokens.
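At 4 embedding tokens, for example, a 100-token passage is represented at 25x compression.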
## Simple Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch import nn
import torch
from huggingface_hub import hf_hub_download
device = torch.device("cuda:0")
dtype = torch.bfloat16
base_model_id = "HuggingFaceTB/SmolLM3-3B-Base"
compressor_id = "midwestern-simulation/essence-3b-v1.1"
# === MODEL LOADING ===
tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side='left')
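# left padding matters when batching: real tokens stay right-aligned, so the
# last n positions of the encoder output are always the embedding slots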
encoder = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"":device}, torch_dtype=dtype)
decoder = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"":device}, torch_dtype=dtype)
encoder = PeftModel.from_pretrained(encoder, compressor_id, subfolder="encoder")
decoder = PeftModel.from_pretrained(decoder, compressor_id, subfolder="decoder")
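# learned linear map from encoder hidden states into the decoder's input
# embedding space (2048 = SmolLM3-3B hidden size)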
projector = nn.Linear(2048, 2048).to(device).to(dtype)
projector.load_state_dict(torch.load(hf_hub_download(repo_id=compressor_id, filename="projector.pt"), map_location=device))
# === MODEL INFERENCE ===
text = "mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow"
n_embed_tokens = 4  # can be any value from 1 to 128; more tokens preserve more detail
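# the encoder prompt wraps the document, closes it with an end-of-document marker,
# and appends n_embed_tokens <|im_end|> placeholders; the hidden states at those
# placeholder positions become the compressed representation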
encoder_input = text.strip() + f"\n[[/END DOCUMENT]]\n[[START SUMMARY ntoks={n_embed_tokens}]]" + "<|im_end|>" * n_embed_tokens
tokenized = tokenizer(encoder_input, return_tensors='pt', add_special_tokens=False)
tokenized = {k: v.to(device) for k, v in tokenized.items()}
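# run the encoder backbone (.model.model reaches through the PEFT wrapper to the
# transformer, which returns hidden states rather than logits) and keep the last n positions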
with torch.no_grad():
    encoding = encoder.model.model(**tokenized).last_hidden_state[:, -n_embed_tokens:, :]
    encoding = projector(encoding)
tokenized_prefix = tokenizer("\n[[/END SUMMARY]]\n[[START DOCUMENT]]\n", return_tensors="pt", add_special_tokens=False)
prefix_embeds = decoder.model.model.embed_tokens(tokenized_prefix['input_ids'].to(device))
inputs_embeds = torch.cat([encoding, prefix_embeds], 1)
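# the decoder generates conditioned on the projected embedding tokens followed by the prefix prompt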
output = decoder.generate(
    inputs_embeds=inputs_embeds,
    do_sample=True,
    temperature=0.7,
    top_k=128,
    min_new_tokens=8,
    max_new_tokens=1024,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(output[0]))
# mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow
# [[/END DOCUMENT]]<|end_of_text|>
```
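For repeated use, the steps above can be wrapped into a pair of helpers. This is a minimal sketch reusing the objects loaded above; `compress` and `decompress` are names introduced here, not part of the repository:

```python
def compress(text: str, n_embed_tokens: int = 4) -> torch.Tensor:
    """Encode `text` into `n_embed_tokens` embedding vectors of shape [1, n, 2048]."""
    prompt = (
        text.strip()
        + f"\n[[/END DOCUMENT]]\n[[START SUMMARY ntoks={n_embed_tokens}]]"
        + "<|im_end|>" * n_embed_tokens
    )
    batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        hidden = encoder.model.model(**batch).last_hidden_state[:, -n_embed_tokens:, :]
        return projector(hidden)

def decompress(encoding: torch.Tensor, **gen_kwargs) -> str:
    """Decode embedding vectors back into text."""
    prefix = tokenizer("\n[[/END SUMMARY]]\n[[START DOCUMENT]]\n", return_tensors="pt", add_special_tokens=False)
    prefix_embeds = decoder.model.model.embed_tokens(prefix["input_ids"].to(device))
    inputs_embeds = torch.cat([encoding, prefix_embeds], 1)
    output = decoder.generate(
        inputs_embeds=inputs_embeds,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        **gen_kwargs,
    )
    return tokenizer.decode(output[0])

print(decompress(compress(text, 8), do_sample=True, temperature=0.7, max_new_tokens=256))
```

Because the encoding is just a `[1, n, 2048]` tensor, it can be cached or stored independently of the original text and decoded later.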