---
datasets:
  - mlfoundations/dclm-baseline-1.0
base_model:
  - HuggingFaceTB/SmolLM3-3B-Base
---

# Essence 3B V1.1

This is a system built from two finetuned copies of SmolLM3-3B-Base: an 'encoder' that turns a text into a set of embedding tokens, and a 'decoder' that reconstitutes the original text from those tokens. In addition to “vanilla” reconstruction, the model was trained on span-corruption and masked-language-modelling objectives.

We use LoRA at rank 64 on the QKVO projections along with trainable LayerNorms; the encoder additionally gets LoRA on all MLP layers and trainable token embeddings.
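In PEFT terms, the adapter setup described above corresponds roughly to the following. This is a sketch, not the authors' training script; the target-module names are assumptions based on SmolLM3's Llama-style layer naming:

```python
from peft import LoraConfig

# Hypothetical reconstruction of the adapter configuration described above.
# Module names assume SmolLM3's Llama-style naming and are not taken from
# the authors' training code.
encoder_lora = LoraConfig(
    r=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention QKVO
        "gate_proj", "up_proj", "down_proj",     # all MLP layers (encoder only)
    ],
    modules_to_save=[
        "input_layernorm", "post_attention_layernorm", "norm",  # trainable LayerNorms
        "embed_tokens",                                         # trainable token embeddings
    ],
)

decoder_lora = LoraConfig(
    r=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    modules_to_save=["input_layernorm", "post_attention_layernorm", "norm"],
)
```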

The model was trained to encode text into anywhere from 1 to 128 embedding tokens.

## Simple Usage

```python
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from huggingface_hub import hf_hub_download

device = torch.device("cuda:0")
dtype = torch.bfloat16
base_model_id = "HuggingFaceTB/SmolLM3-3B-Base"
compressor_id = "midwestern-simulation/essence-3b-v1.1"

# === MODEL LOADING ===

tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side='left')
encoder = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"":device}, torch_dtype=dtype)
decoder = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"":device}, torch_dtype=dtype)

encoder = PeftModel.from_pretrained(encoder, compressor_id, subfolder="encoder")
decoder = PeftModel.from_pretrained(decoder, compressor_id, subfolder="decoder")

# learned linear projection mapping encoder hidden states (2048-dim for
# SmolLM3-3B) into the decoder's input embedding space
projector = nn.Linear(2048, 2048).to(device).to(dtype)
projector.load_state_dict(torch.load(
    hf_hub_download(repo_id=compressor_id, filename="projector.pt"),
    map_location=device,
))


# === MODEL INFERENCE ===

text = "mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow"
n_embed_tokens = 4 # for best performance, can be any within the range of 1-128

encoder_input = text.strip() + f"\n[[/END DOCUMENT]]\n[[START SUMMARY ntoks={n_embed_tokens}]]" + "<|im_end|>" * n_embed_tokens

tokenized = tokenizer(encoder_input, return_tensors='pt', add_special_tokens=False)
tokenized = {k: v.to(device) for k, v in tokenized.items()}
# run the encoder backbone and keep the hidden states at the placeholder positions
encoding = encoder.model.model(**tokenized).last_hidden_state[:, -n_embed_tokens:, :]
# project into the decoder's embedding space
encoding = projector(encoding)

# prepend the embedding tokens to a decoding prefix and generate the reconstruction
tokenized_prefix = tokenizer("\n[[/END SUMMARY]]\n[[START DOCUMENT]]\n", return_tensors="pt", add_special_tokens=False)
prefix_embeds = decoder.model.model.embed_tokens(tokenized_prefix['input_ids'].to(device))
inputs_embeds = torch.cat([encoding, prefix_embeds], 1)
output = decoder.generate(
    inputs_embeds=inputs_embeds,
    temperature=0.7,
    max_new_tokens=1024,
    do_sample=True,
    top_k=128,
    min_new_tokens=8,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id
)
print(tokenizer.decode(output[0]))
# expected output:
# mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow
# [[/END DOCUMENT]]<|end_of_text|>
```
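
For repeated use, the same steps can be wrapped into a pair of helpers. This is a minimal sketch reusing the objects loaded above; `compress` and `decompress` are illustrative names, not part of the repository:

```python
def compress(text: str, n_embed_tokens: int = 4) -> torch.Tensor:
    """Encode `text` into a (1, n_embed_tokens, 2048) tensor of embedding tokens."""
    prompt = (
        text.strip()
        + f"\n[[/END DOCUMENT]]\n[[START SUMMARY ntoks={n_embed_tokens}]]"
        + "<|im_end|>" * n_embed_tokens
    )
    tokenized = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    tokenized = {k: v.to(device) for k, v in tokenized.items()}
    with torch.no_grad():
        hidden = encoder.model.model(**tokenized).last_hidden_state
        return projector(hidden[:, -n_embed_tokens:, :])

def decompress(encoding: torch.Tensor, **generate_kwargs) -> str:
    """Reconstruct text from embedding tokens produced by `compress`."""
    prefix = tokenizer("\n[[/END SUMMARY]]\n[[START DOCUMENT]]\n",
                       return_tensors="pt", add_special_tokens=False)
    prefix_embeds = decoder.model.model.embed_tokens(prefix["input_ids"].to(device))
    inputs_embeds = torch.cat([encoding, prefix_embeds], 1)
    output = decoder.generate(
        inputs_embeds=inputs_embeds,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        **generate_kwargs,
    )
    return tokenizer.decode(output[0])

embedding = compress(text, n_embed_tokens=8)
print(decompress(embedding, do_sample=True, temperature=0.7, top_k=128, max_new_tokens=256))
```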