"""Interactive text generation with a GPT-2 model loaded from a local checkpoint.

Importing this module loads the tokenizer and the model (see the module-level
``Inference`` instance below).  Running it as a script starts a small
read-generate-print loop.
"""
import os
import time
import warnings

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


class Inference():
    """Wraps a GPT-2 tokenizer/model pair and generates text in small chunks.

    The tokenizer is the stock ``distilgpt2`` one; the model weights are read
    from the ``SaveState`` directory that sits next to this file.
    """

    # Maximum number of tokens requested from the model per generate() call.
    # Small chunks let the callback / stop_token check run between calls.
    CHUNK_SIZE = 50

    def __init__(self, silent=False) -> None:
        """Load tokenizer and model, and put the model in eval mode.

        Args:
            silent: When False, print how long loading took.
        """
        start_time = time.perf_counter()
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        self.model = GPT2LMHeadModel.from_pretrained(self.local_file_path("SaveState"))
        self.model.eval()
        if not silent:
            print(f"Model Loading Took {time.perf_counter()-start_time} Seconds")

    def local_file_path(self, path):
        """Return *path* resolved relative to the directory containing this file."""
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def generate(self, prompt, max_length=2000, temperature=0.5, do_sample=True,
                 stop_token=None, callback=None, silent=True):
        """Continue *prompt* and return the prompt plus continuation as text.

        Generation proceeds in chunks of at most ``CHUNK_SIZE`` new tokens so
        that ``callback`` and ``stop_token`` can be serviced between chunks.

        Args:
            prompt: Text to continue.  An empty prompt is seeded with the
                tokenizer's BOS (or EOS) token so the model has an input.
            max_length: Upper bound on the total token count (prompt included).
            temperature: Sampling temperature forwarded to the model.
            do_sample: Whether to sample (True) or decode greedily (False).
            stop_token: Optional substring; generation stops after the first
                chunk whose arrival makes the decoded text contain it.  The
                text is not truncated at the stop token.
            callback: Optional callable invoked with the decoded string of
                every newly generated token, in order.
            silent: When False, print how long generation took.

        Returns:
            The decoded text (special tokens removed).
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            start_time = time.perf_counter()
            eos_id = self.tokenizer.eos_token_id

            input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
            if input_ids.shape[1] == 0:
                # Nothing to condition on: start from BOS, like generate()
                # itself does when called without inputs.
                seed_id = self.tokenizer.bos_token_id
                if seed_id is None:
                    seed_id = eos_id
                input_ids = torch.tensor([[seed_id]], dtype=torch.long)
            generated_text = input_ids

            # Number of positions the model can attend to; the context fed to
            # the model must leave room for the tokens about to be generated.
            window = getattr(self.model.config, "n_positions", 1024)

            while generated_text.shape[1] < max_length:
                length = min(self.CHUNK_SIZE, max_length - generated_text.shape[1])
                # Feed everything generated so far (clipped to the window)
                # rather than only the previous chunk, so context is kept.
                context = generated_text[:, -max(1, window - length):]
                with torch.no_grad():
                    outputs = self.model.generate(
                        context,
                        # pad id == eos id here, so the mask must be explicit.
                        attention_mask=torch.ones_like(context),
                        # max_length would count the context too; this is the
                        # budget for *new* tokens only.
                        max_new_tokens=length,
                        temperature=temperature,
                        do_sample=do_sample,
                        pad_token_id=eos_id,
                    )
                # The output echoes the context; keep only what was appended.
                new_tokens = outputs[0][context.shape[1]:]
                if new_tokens.numel() == 0:
                    # No progress is possible; avoid spinning forever.
                    break

                new_ids = new_tokens.tolist()
                if callback is not None:
                    for token_id in new_ids:
                        callback(self.tokenizer.decode([token_id]))

                generated_text = torch.cat((generated_text, new_tokens.unsqueeze(0)), dim=-1)

                if eos_id is not None and eos_id in new_ids:
                    # The model ended the sequence; do not generate past it.
                    break
                if stop_token is not None and stop_token in self.tokenizer.decode(generated_text[0]):
                    break

            if not silent:
                print(f"Generation Took {time.perf_counter()-start_time} Seconds")
            return self.tokenizer.decode(generated_text[0], skip_special_tokens=True)


# NOTE: this rebinds the class name to a ready-to-use instance, so importing
# the module loads the model and `Inference.generate(...)` works directly.
Inference = Inference()


def spec(stre):
    """Print *stre* without a trailing newline (usable as a ``generate`` callback)."""
    # Flush so pieces show up immediately instead of waiting for a newline.
    print(stre, end="", flush=True)


if __name__ == "__main__":
    while True:
        try:
            user_prompt = input(">>> ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly on Ctrl-D / Ctrl-C instead of a traceback.
            print()
            break
        print(Inference.generate(user_prompt, max_length=100, temperature=0.8, silent=True))