---
license: apache-2.0
language:
- en
pipeline_tag: text-classification
inference: false
---
Monarch Mixer-BERT
The 80M checkpoint for M2-BERT-base from the paper "Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture". This model was pretrained with a sequence length of 2048 and fine-tuned for long-context retrieval.
Check out our blog post for more on how we trained this model for long sequences.
This model was trained by Jon Saad-Falcon, Dan Fu, and Simran Arora.
Check out our GitHub for instructions on how to download and fine-tune it!
How to use
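You can load this model with the Hugging Face Transformers `AutoModelForSequenceClassification` class. Pass `trust_remote_code=True`, because the M2-BERT architecture is defined in custom code hosted with the checkpoint: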
```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "togethercomputer/m2-bert-80M-2k-retrieval",
    trust_remote_code=True
)
```
Loading prints a long error message about unused parameters for FlashFFTConv; this is expected. If you'd like to load the model with FlashFFTConv, check out our GitHub.
This model generates embeddings for retrieval. The embeddings have a dimensionality of 768:
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

max_seq_length = 2048
testing_string = "Every morning, I make a cup of coffee to start my day."

model = AutoModelForSequenceClassification.from_pretrained(
    "togethercomputer/m2-bert-80M-2k-retrieval",
    trust_remote_code=True
)

# The model uses the standard bert-base-uncased tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    model_max_length=max_seq_length
)

# Pad (or truncate) every input to the full 2048-token context
input_ids = tokenizer(
    [testing_string],
    return_tensors="pt",
    padding="max_length",
    return_token_type_ids=False,
    truncation=True,
    max_length=max_seq_length
)

outputs = model(**input_ids)
# One 768-dimensional embedding per input string
embeddings = outputs['sentence_embedding']
```
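Retrieval with these embeddings usually comes down to ranking candidate documents by their similarity to a query embedding. The sketch below assumes cosine similarity as the scoring function (a common choice, not something this card specifies); `cosine_similarity` and `rank_by_cosine` are hypothetical helpers, not part of the model's code. Toy 3-dimensional vectors stand in for real 768-dimensional embeddings so the example runs without downloading the model.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors of equal length."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimensionality")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        # A zero vector has no direction; treat it as unrelated
        return 0.0
    return dot / (norm_a * norm_b)


def rank_by_cosine(
    query: Sequence[float], docs: List[Sequence[float]]
) -> List[Tuple[int, float]]:
    """Return (doc_index, score) pairs sorted from most to least similar."""
    scores = [(i, cosine_similarity(query, d)) for i, d in enumerate(docs)]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


# Hypothetical toy vectors standing in for 768-dimensional embeddings
query_vec = [1.0, 0.0, 0.0]
doc_vecs = [[0.0, 1.0, 0.0], [0.9, 0.1, 0.0], [-1.0, 0.0, 0.0]]
print(rank_by_cosine(query_vec, doc_vecs))
```

To use real embeddings, convert each row of the `sentence_embedding` output to a Python list first (for example with `embeddings[0].tolist()`, assuming the output is a PyTorch tensor of shape `(batch, 768)`). For large corpora, a vectorized or indexed search would replace this pure-Python loop.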
You can also get embeddings from this model using the Together API as follows (you can find your API key here):
```python
import os
import requests


def generate_together_embeddings(text: str, model_api_string: str, api_key: str):
    url = "https://api.together.xyz/api/v1/embeddings"
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    session = requests.Session()
    response = session.post(
        url,
        headers=headers,
        json={
            "input": text,
            "model": model_api_string
        }
    )
    if response.status_code != 200:
        raise ValueError(f"Request failed with status code {response.status_code}: {response.text}")
    # The embedding vector is the first entry of the "data" list
    return response.json()['data'][0]['embedding']


embedding = generate_together_embeddings(
    'Hello world',
    'togethercomputer/m2-bert-80M-2k-retrieval',
    os.environ['TOGETHER_API_KEY'],  # API key read from the environment
)
print(embedding[:10])  # first 10 of the 768 dimensions
```
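The function above indexes the response as `['data'][0]['embedding']`. As a hedged sketch, the helper below makes that assumed response shape explicit and checks that the returned vector has the 768 dimensions stated above. `parse_embedding_response` is a hypothetical name, the real API response may carry other fields that are simply ignored here, and the sample body is a synthetic placeholder rather than actual API output.

```python
import json
from typing import List

EXPECTED_DIM = 768  # embedding dimensionality stated in this model card


def parse_embedding_response(body: str, expected_dim: int = EXPECTED_DIM) -> List[float]:
    """Extract the first embedding from a body shaped like
    {"data": [{"embedding": [...]}, ...]} and check its length."""
    payload = json.loads(body)
    try:
        embedding = payload["data"][0]["embedding"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError(f"unexpected response shape: {payload!r}") from exc
    if len(embedding) != expected_dim:
        raise ValueError(f"expected {expected_dim} dimensions, got {len(embedding)}")
    return [float(x) for x in embedding]


# Synthetic placeholder body, only to exercise the parser
sample_body = json.dumps({"data": [{"embedding": [0.1] * EXPECTED_DIM}]})
vector = parse_embedding_response(sample_body)
print(len(vector))
```

In `generate_together_embeddings`, the same check could be applied by passing `response.text` to this helper instead of indexing `response.json()` directly.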