import logging

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# model_id = "./TinyStories-3M-val-Hebrew"
model_id = "Norod78/TinyStories-3M-val-Hebrew"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained("./Hebrew_GPT3_XL", from_tf=True)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Alternative Hebrew prompts ("Yesterday, on the way home, I discovered that",
# "Once, before", "The beauty industry's best-kept secret"):
# prompt_text = "אתמול, בדרך הביתה, גיליתי ש"
# prompt_text = "פעם, לפני ש"
# prompt_text = "הסוד השמור ביותר של תעשיית היופי"
# prompt_text = "<|startoftext|>"
prompt_text = "\n"

stop_token = "<|endoftext|>"
new_lines = "\n\n\n"
seed = 1000

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
logger.info(f"device: {device}, n_gpu: {n_gpu}")

# Seed all RNGs so runs are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

model.to(device)
# model.half()


def process_output_sequences(output_sequences):
    # Remove the batch dimension when returning multiple sequences.
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode the token IDs back into text.
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        text = text.replace("<|startoftext|>", "").replace(" ; ", "\n")

        # Truncate at the stop token, if present. Guarding against find()
        # returning -1 avoids silently chopping off the last character when
        # the stop token does not appear in the output.
        if stop_token:
            stop_idx = text.find(stop_token)
            if stop_idx != -1:
                text = text[:stop_idx]

        # Truncate at the first run of three newlines, if present.
        if new_lines:
            newline_idx = text.find(new_lines)
            if newline_idx != -1:
                text = text[:newline_idx]

        print(text)
    print("------")


def encode_prompt(text):
    encoded_prompt = tokenizer.encode(
        text, add_special_tokens=True, return_tensors="pt"
    )
    encoded_prompt = encoded_prompt.to(device)
    # An empty prompt encodes to zero tokens; return None so that generate()
    # starts from the model's BOS token instead.
    return encoded_prompt if encoded_prompt.size()[-1] > 0 else None


input_ids = encode_prompt(prompt_text)
input_ids_len = input_ids.size()[-1] if input_ids is not None else 0

# Generate up to 192 new tokens, capping the total length at 1023 tokens.
max_length = min(input_ids_len + 192, 1023)

output_sequences = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    temperature=0.98,
    top_k=40,
    top_p=0.92,
    repetition_penalty=2.0,
    do_sample=True,
    num_return_sequences=5,
)

process_output_sequences(output_sequences)
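
# Optional follow-up run (not part of the original script): a minimal sketch
# showing that encode_prompt() and process_output_sequences() are reusable
# with a different prompt. The prompt below is taken from the author's
# commented-out alternatives above ("Once, before"); the sampling parameters
# are kept identical to the run above so only the prompt changes.
alt_prompt = "פעם, לפני ש"
alt_input_ids = encode_prompt(alt_prompt)
alt_input_ids_len = alt_input_ids.size()[-1] if alt_input_ids is not None else 0
alt_max_length = min(alt_input_ids_len + 192, 1023)

alt_output_sequences = model.generate(
    input_ids=alt_input_ids,
    max_length=alt_max_length,
    temperature=0.98,
    top_k=40,
    top_p=0.92,
    repetition_penalty=2.0,
    do_sample=True,
    num_return_sequences=5,
)

process_output_sequences(alt_output_sequences)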