import tiktoken
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from ssllm_hf import SSLLMConfig, SSLLMForCausalLM

# Initialize the model from its configuration
config = SSLLMConfig.from_pretrained('sausheong/ssllm_hf')
model = SSLLMForCausalLM(config)

# Download and load the model weights
model_path = hf_hub_download(repo_id='sausheong/ssllm_hf', filename='model.safetensors')
state_dict = load_file(model_path)
# strict=False tolerates keys that are missing from or unexpected in the checkpoint
model.load_state_dict(state_dict, strict=False)

# Move to GPU if available and switch to inference mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()

# cl100k_base is the tiktoken encoding used by GPT-3.5/GPT-4 models
tokenizer = tiktoken.get_encoding('cl100k_base')


def generate_text(prompt, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=40):
    # Encode the prompt into token IDs
    input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
    attention_mask = torch.ones_like(input_ids)

    # Sample a continuation from the model
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=100257,  # 100257 is <|endoftext|> in cl100k_base
            eos_token_id=100257,
        )

    # generate() returns prompt + continuation; decode only the new tokens
    new_tokens = outputs[0][input_ids.shape[1]:].tolist()
    generated = tokenizer.decode(new_tokens)
    print(f"{prompt}{generated}")
    print(f"\nTokens generated: {len(new_tokens)}")


if __name__ == "__main__":
    prompt = "In a small village nestled between mountains,"
    print(f"PROMPT: {prompt}\n--")
    generate_text(prompt)
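
    # A few alternative sampling settings (illustrative only; these values
    # are assumptions, not from the original script):
    # generate_text(prompt, temperature=0.2, top_k=10)    # more focused output
    # generate_text(prompt, temperature=1.0, top_p=0.95)  # more varied output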