Trained on 554M tokens for 1 epoch with a learning rate of 0.00987.
Training data: Brown corpus; quotes (Wikiquote, azquote, gracious quotes, english quotes); idioms (scraped); definitions (WordNet); wiki_text; mini pile.
Trained on RunPod for 5 days using a 3090.
code: https://gist.github.com/thistleknot/368ab298edf596ef50d2cfdcbec66fd1
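As a rough sanity check on the training figures above, the implied average throughput can be estimated. This is a back-of-the-envelope sketch only: it assumes the 5 days were continuous wall-clock training time and that both figures are approximate.

```python
# Figures taken from the description above; both are approximate.
tokens = 554_000_000          # total training tokens (1 epoch)
seconds = 5 * 24 * 60 * 60    # 5 days, assumed to be continuous wall-clock time

tokens_per_second = tokens / seconds
print(f"{tokens_per_second:,.0f} tokens/s")  # prints "1,282 tokens/s"
```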
from transformers import AutoTokenizer, AutoModelForCausalLM
# Specify the path to the directory where the model is stored
#model_dir = r"C:\Users\User\Documents\wiki\wiki\data science\nlp\research\mamba_brown_trained_556m\mamba_brown_trained\mamba_brown_trained"
model_dir = "/home/user/mamba_brown_trained"
# Load the tokenizer and model (a causal language model, for text generation)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)
model.to('cuda')
# Now, you can use the model and tokenizer for inference
input_text = "Once upon a time"
# Tokenize the input
inputs = tokenizer(input_text, return_tensors="pt").to('cuda')
# Generate output tokens using the model
output_ids = model.generate(**inputs, max_length=50)
# Decode the generated token IDs back into text
decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Print the generated output text
print(decoded_output)
# Example output:
# Once upon a time, the world is changing.
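The snippet above moves the model and inputs to 'cuda' unconditionally, which fails on a machine without a CUDA GPU. A minimal sketch of a CPU fallback follows; the helper name `pick_device` is hypothetical and not part of the original code.

```python
def pick_device() -> str:
    """Return "cuda" when PyTorch reports a usable GPU, otherwise "cpu"."""
    try:
        import torch  # imported lazily so the helper also runs where PyTorch is absent
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(device)

# Then, with the objects loaded earlier:
# model.to(device)
# inputs = tokenizer(input_text, return_tensors="pt").to(device)
```

Generation on CPU will work but is likely to be much slower than on a GPU.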
# Second example: a longer prompt with repetition controls and sampling
input_text = "The Fulton County Grand Fair was set for Friday at"
inputs = tokenizer(input_text, return_tensors="pt").to('cuda')
# Generate output tokens using the model with repetition controls
output_ids = model.generate(
    **inputs,
    max_length=256,  # Max total length in tokens (prompt + generated)
    repetition_penalty=1.2,  # Penalize tokens that have already appeared
    no_repeat_ngram_size=3,  # Prevent any 3-gram from repeating
    do_sample=True,  # Required for temperature/top_k/top_p to take effect
    temperature=0.9,  # Adjust randomness (lower means more deterministic)
    top_k=50,  # Only sample from the 50 most likely tokens
    top_p=0.9,  # Nucleus sampling: keep the smallest token set with cumulative probability >= 0.9
)
# Decode the generated token IDs back into text
decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Print the generated output text
print(decoded_output)
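In transformers, temperature, top_k and top_p only influence generation when do_sample=True; without it, generate decodes greedily, while repetition_penalty and no_repeat_ngram_size apply in both modes. A minimal sketch of a helper that keeps the two modes apart is shown here. The helper name `generation_kwargs` and its defaults are hypothetical, chosen to mirror the values used above. It uses max_new_tokens, which counts only generated tokens, whereas max_length also counts the prompt.

```python
from typing import Any, Dict


def generation_kwargs(sample: bool = True, max_new_tokens: int = 200) -> Dict[str, Any]:
    """Build keyword arguments for model.generate() (hypothetical helper)."""
    kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,  # budget for generated tokens only
        "repetition_penalty": 1.2,         # applies to greedy and sampled decoding
        "no_repeat_ngram_size": 3,         # applies to greedy and sampled decoding
        "do_sample": sample,
    }
    if sample:
        # Sampling-only settings; passing them with do_sample=False has no effect.
        kwargs.update(temperature=0.9, top_k=50, top_p=0.9)
    return kwargs


# Usage with the model and inputs loaded earlier:
# output_ids = model.generate(**inputs, **generation_kwargs(sample=True))
```

Calling `generation_kwargs(sample=False)` gives deterministic greedy output, which can be easier to compare across runs.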
Model tree for LaferriereJC/jamba_550M_trained
Base model: ai21labs/AI21-Jamba-1.5-Mini