# 10M-LLM / inference_fine_tune.py
import torch
from tokenizers import Tokenizer
from pathlib import Path
from config import get_config, get_weights_file_path
from train import get_model
def get_tokenizer(config) -> Tokenizer:
    tokenizers_path = Path(config['tokenizer_file'])
    if tokenizers_path.exists():
        print("Loading tokenizer from ", tokenizers_path)
        tokenizer = Tokenizer.from_file(str(tokenizers_path))
        return tokenizer
    else:
        raise FileNotFoundError(f"Can't find tokenizer file: {tokenizers_path}")
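# The config file (openweb.config.json) is expected to provide at least these keys,
# based on how they are used below:
#   tokenizer_file - path to the saved tokenizers.Tokenizer JSON file
#   preload        - checkpoint tag passed to get_weights_file_path
#   seq_len        - maximum sequence length used by the model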
config = get_config("./openweb.config.json")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = get_tokenizer(config)
pad_token_id = tokenizer.token_to_id("<pad>")
eos_token_id = tokenizer.token_to_id("</s>")
user_token_id = tokenizer.token_to_id("<user>")
ai_token_id = tokenizer.token_to_id("<ai>")
model = get_model(config, tokenizer.get_vocab_size()).to(device)
model_path = get_weights_file_path(config,config['preload'])
state = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(state['model_state_dict'])
model.eval()  # switch to inference mode once the checkpoint weights are loaded
def generate_response(prompt: str):
    print("Prompt : ", prompt)
    word = ""
    # Wrap the prompt with the chat special tokens: <user> prompt <ai>
    input_tokens = [user_token_id] + tokenizer.encode(prompt).ids + [ai_token_id]
    if len(input_tokens) > config['seq_len']:
        print(f"exceeding max length of input : {config['seq_len']}")
        exit()
    decoder_input = torch.tensor(input_tokens).to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)  # add a batch dimension
    temperature = 0.7
    top_k = 50
    i = 0
    print("Output : ", end="")
    while decoder_input.shape[1] < 2000:
        # Apply causal mask based on current decoder_input length
        # decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
        # Get model output for the current context (no gradients needed at inference time)
        with torch.no_grad():
            out = model.decode(decoder_input)
            logits = model.project(out[:, -1])  # logits for the last position only
        # Temperature-scaled top-k sampling: keep the k most likely tokens,
        # renormalise, and draw the next token from that distribution
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)
        # Stream only the newly generated token, not the whole accumulated text
        token_text = tokenizer.decode([next_token.item()])
        word += token_text
        print(token_text, end="", flush=True)
        i += 1
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        # Keep the context within the model's maximum sequence length
        if decoder_input.shape[1] > config['seq_len']:
            decoder_input = decoder_input[:, -config['seq_len']:]
        if next_token.item() == eos_token_id or i >= 1024:
            break
    print()
    return word
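# Minimal usage sketch (an illustration, not part of the original script): read
# prompts from stdin and stream the model's reply via generate_response().
if __name__ == "__main__":
    while True:
        try:
            user_prompt = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if user_prompt:
            generate_response(user_prompt)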