Norod78 committed on
Commit
a5f152a
1 Parent(s): cbeb2d3

Upload TinyStories-3M-val-Hebrew-inference.py

TinyStories-3M-val-Hebrew-inference.py ADDED
@@ -0,0 +1,96 @@
+ import logging
+
+ import numpy as np
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+     datefmt="%m/%d/%Y %H:%M:%S",
+     level=logging.INFO,
+ )
+ logger = logging.getLogger(__name__)
+
+ # model_id = "./TinyStories-3M-val-Hebrew"
+ model_id = "Norod78/TinyStories-3M-val-Hebrew"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ # model = AutoModelForCausalLM.from_pretrained("./Hebrew_GPT3_XL", from_tf=True)
+ model = AutoModelForCausalLM.from_pretrained(model_id)
+
+ # Alternative Hebrew prompts (English translations in trailing comments):
+ # prompt_text = "אתמול, בדרך הביתה, גיליתי ש"  # "Yesterday, on the way home, I discovered that"
+ # prompt_text = "פעם, לפני ש"  # "Once, before"
+ # prompt_text = "הסוד השמור ביותר של תעשיית היופי"  # "The beauty industry's best-kept secret"
+ # prompt_text = "<|startoftext|>"
+ prompt_text = "\n"
+ stop_token = "<|endoftext|>"
+ new_lines = "\n\n\n"
+ seed = 1000
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
+
+ logger.info(f"device: {device}, n_gpu: {n_gpu}")
+
+ # Seed all RNGs so the sampled stories are reproducible
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if n_gpu > 0:
+     torch.cuda.manual_seed_all(seed)
+
+ model.to(device)
+ # model.half()  # optional: fp16 inference on GPU
+
+
+ def process_output_sequences(output_sequences):
+     # Remove the batch dimension when returning multiple sequences
+     if len(output_sequences.shape) > 2:
+         output_sequences.squeeze_()
+
+     for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
+         print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
+         generated_sequence = generated_sequence.tolist()
+         # Decode token ids back to text
+         text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+         text = text.replace("<|startoftext|>", "").replace(" ; ", "\n")
+         # Truncate at the stop token, if present (find() returns -1 when absent)
+         stop_idx = text.find(stop_token)
+         if stop_idx != -1:
+             text = text[:stop_idx]
+         # Truncate at the first run of three newlines, if present
+         new_lines_idx = text.find(new_lines)
+         if new_lines_idx != -1:
+             text = text[:new_lines_idx]
+         print(text)
+         print("------")
+
+
+ def encode_prompt(text):
+     encoded_prompt = tokenizer.encode(
+         text, add_special_tokens=True, return_tensors="pt")
+     encoded_prompt = encoded_prompt.to(device)
+     # An empty prompt encodes to a zero-length tensor; generate() expects None instead
+     if encoded_prompt.size()[-1] == 0:
+         input_ids = None
+     else:
+         input_ids = encoded_prompt
+     return input_ids
+
+
+ input_ids = encode_prompt(prompt_text)
+ input_ids_len = input_ids.size()[-1]
+ max_length = min(input_ids_len + 192, 1023)  # stay within the model's context window
+
+ output_sequences = model.generate(
+     input_ids=input_ids,
+     max_length=max_length,
+     temperature=0.98,
+     top_k=40,
+     top_p=0.92,
+     repetition_penalty=2.0,
+     do_sample=True,
+     num_return_sequences=5,
+ )
+
+ process_output_sequences(output_sequences)
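For a quick smoke test of the uploaded model, the same sampling settings can also be driven through the transformers text-generation pipeline. This is a minimal sketch of an alternative entry point, not part of the committed file; the script itself calls model.generate() directly:

import logging

from transformers import pipeline

logging.disable(logging.WARNING)  # silence pad_token_id notices for a cleaner demo

# The pipeline forwards sampling kwargs to model.generate() under the hood
generator = pipeline("text-generation", model="Norod78/TinyStories-3M-val-Hebrew")
stories = generator(
    "\n",
    max_length=192,
    do_sample=True,
    temperature=0.98,
    top_k=40,
    top_p=0.92,
    repetition_penalty=2.0,
    num_return_sequences=5,
)
for story in stories:
    print(story["generated_text"])
    print("------")

Unlike the script above, the pipeline does not strip the "<|startoftext|>" / "<|endoftext|>" markers or rewrite " ; " separators as newlines, so its raw output will look slightly different from what process_output_sequences() prints.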