|
import torch |
|
|
|
from tokenizers import Tokenizer |
|
|
|
|
|
from pathlib import Path |
|
from config import get_config, get_weights_file_path |
|
from train import get_model |
|
|
|
def get_tokenizer(config) -> Tokenizer:
    """Load the tokenizer stored at ``config['tokenizer_file']``.

    Args:
        config: Mapping that provides the tokenizer file location under the
            ``'tokenizer_file'`` key.

    Returns:
        The object returned by ``Tokenizer.from_file`` for that file.

    Raises:
        FileNotFoundError: If nothing exists at the configured path.
    """
    tokenizers_path = Path(config['tokenizer_file'])

    # Fail early with ONE formatted message. Passing the text and the path as
    # two positional arguments makes OSError interpret them as
    # (errno, strerror), which renders as "[Errno Cant find ...] <path>".
    if not tokenizers_path.exists():
        raise FileNotFoundError(f"Cant find tokenizer file : {tokenizers_path}")

    print("Loading tokenizer from ", tokenizers_path)
    return Tokenizer.from_file(str(tokenizers_path))
|
|
|
|
|
# ---------------------------------------------------------------------------
# Module-level setup: everything below runs once, at import time.
# ---------------------------------------------------------------------------

# Settings are read from the JSON file next to the script.
config = get_config("./openweb.config.json")

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = get_tokenizer(config)

# Ids the tokenizer assigns to the special tokens used by this script.
pad_token_id = tokenizer.token_to_id("<pad>")
eos_token_id = tokenizer.token_to_id("</s>")
user_token_id = tokenizer.token_to_id("<user>")
ai_token_id = tokenizer.token_to_id("<ai>")

# Build the model for the tokenizer's vocabulary size and move it to `device`.
model = get_model(config, tokenizer.get_vocab_size())
model = model.to(device)

# Checkpoint location is derived from the config's 'preload' entry.
model_path = get_weights_file_path(config, config["preload"])

# Switch to evaluation mode before the weights are loaded.
model.eval()

# The checkpoint is deserialized onto the CPU; only its 'model_state_dict'
# entry is used to populate the model.
state = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state["model_state_dict"])
|
|
|
def generate_response(
    prompt: str,
    *,
    temperature: float = 0.7,
    top_k: int = 50,
    max_new_tokens: int = 1024,
) -> str:
    """Sample a reply to *prompt* from the module-level model and echo it.

    The prompt is encoded with the module-level ``tokenizer`` and framed as
    ``<user> prompt-ids <ai>``. Tokens are then drawn one at a time with
    temperature-scaled top-k sampling. Each decoded piece is written to
    stdout as soon as it is produced and also collected for the return value.

    Generation stops when the end-of-sequence token is sampled, when
    ``max_new_tokens`` tokens have been produced, or when the running input
    reaches 2000 tokens. The running input is clipped to the last
    ``config['seq_len']`` tokens after every step.

    Args:
        prompt: User text to answer.
        temperature: Positive divisor applied to the logits before sampling.
        top_k: Number of highest-scoring candidates to sample from; clamped
            to the number of logits the model produces.
        max_new_tokens: Upper bound on the number of sampled tokens.

    Returns:
        The concatenation of the decoded pieces, in generation order.

    Raises:
        ValueError: If ``temperature`` is not positive or ``top_k`` < 1.
        SystemExit: If the framed prompt is longer than ``config['seq_len']``.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    print("Prompt : ", prompt)

    seq_len = config['seq_len']

    # Frame the prompt exactly once. Extending the encoded list with a list
    # that itself contains the encoded ids would feed the prompt twice.
    prompt_ids = tokenizer.encode(prompt).ids
    input_tokens = [user_token_id, *prompt_ids, ai_token_id]

    if len(input_tokens) > seq_len:
        print(f"exceeding max length of input : {seq_len}")
        # Same exception and exit status as the site-provided exit() helper,
        # without depending on the `site` module being loaded.
        raise SystemExit

    # Shape (1, n): a single sequence with a leading batch axis.
    decoder_input = torch.tensor(
        input_tokens, dtype=torch.long, device=device
    ).unsqueeze(0)

    # Hard stop on the running input length; it can only trigger when
    # config['seq_len'] is at least this large, because of the clipping below.
    max_total_tokens = 2000

    pieces = []
    # flush so the streamed text shows up immediately even though no newline
    # is written until generation ends.
    print("Output : ", end="", flush=True)

    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if decoder_input.shape[1] >= max_total_tokens:
                break

            # Only the last position along dim 1 is projected to logits.
            # NOTE(review): presumably model.decode returns
            # (batch, seq, features) -- confirm against the model definition.
            out = model.decode(decoder_input)
            logits = model.project(out[:, -1]) / temperature

            # Never ask topk for more candidates than there are logits.
            k = min(top_k, logits.size(-1))
            top_k_logits, top_k_indices = torch.topk(logits, k)
            probs = torch.softmax(top_k_logits, dim=-1)

            # multinomial picks a position inside the top-k slice; gather
            # maps it back to the corresponding vocabulary id.
            choice = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices.gather(-1, choice)
            token_id = next_token.item()

            # Emit only the newly decoded piece, not the whole reply so far.
            piece = tokenizer.decode([token_id])
            print(piece, end="", flush=True)
            pieces.append(piece)

            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            if decoder_input.shape[1] > seq_len:
                # Keep only the most recent seq_len tokens as context.
                decoder_input = decoder_input[:, -seq_len:]

            if token_id == eos_token_id:
                break

    print()
    return "".join(pieces)