import torch
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM

model_path = "Crystalcareai/GemMoE-Medium-v0.4"

# Load the model sharded across available devices in fp16.
# NOTE(review): attn_implementation="flash_attention_2" requires the flash-attn
# package and a compatible GPU; trust_remote_code=True executes Python code
# shipped with the Hub repo — only enable for repos you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # FIX: trailing comma was missing (SyntaxError)
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Stream decoded tokens to stdout as they are generated, hiding the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Wrap the user prompt in the model's instruction format and tokenize it.
prompt_template = "[INST] {prompt} [/INST]"
prompt = "You're standing on the surface of the Earth. "\
        "You walk one mile south, one mile west and one mile north. "\
        "You end up exactly where you started. Where are you?"

# FIX: use the model's device instead of hard-coding .cuda() — with
# device_map="auto" the input embeddings may live on any device, and
# .cuda() crashes outright on CPU-only hosts.
tokens = tokenizer(
    prompt_template.format(prompt=prompt),
    return_tensors='pt'
).input_ids.to(model.device)

# Generate up to 512 new tokens; output is printed live by the streamer
# and also returned as a tensor of token ids.
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512,
)