tuanle committed
Commit 7989304
1 Parent(s): 5e92cc6

added generate.py

Files changed (1)
  1. generate.py +45 -0
generate.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load the fine-tuned Vietnamese news GPT-2 model, preferring a GPU when available.
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print("Loading model...")
+ tokenizer = AutoTokenizer.from_pretrained("tuanle/VN-News-GPT2", cache_dir="cache/")
+ model = AutoModelForCausalLM.from_pretrained("tuanle/VN-News-GPT2", cache_dir="cache/").to(device)
+ print("Model is ready to serve...")
+
+ def generate(category, headline,
+              min_len=60,
+              max_len=768,
+              num_beams=5,
+              num_return_sequences=3,
+              top_k=50,
+              top_p=1):
+     """
+     top_p: if set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+     top_k: the number of highest-probability vocabulary tokens to keep for top-k filtering.
+     num_beams: number of beams for beam search; 1 means no beam search.
+     """
+     text = f"<|startoftext|> {category} <|headline|> {headline}"  # prompt format the model was trained on
+
+     input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
+
+     sample_outputs = model.generate(input_ids,
+                                     do_sample=True,
+                                     max_length=max_len,
+                                     min_length=min_len,
+                                     # temperature=0.8,
+                                     top_k=top_k,
+                                     top_p=top_p,
+                                     num_beams=num_beams,
+                                     early_stopping=True,
+                                     no_repeat_ngram_size=2,
+                                     num_return_sequences=num_return_sequences)
+
+     outputs = []
+     for i, sample_output in enumerate(sample_outputs):
+         temp = tokenizer.decode(sample_output.tolist())
+         print(f">> Generated text {i+1}\n\n{temp}")
+         print('\n---')
+         outputs.append(temp)
+     return outputs
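
A minimal usage sketch, assuming the function above is available in the same session; the category and headline strings below are hypothetical examples, since the exact category labels the model expects are not documented in this file:

    # Hypothetical inputs; the real category labels depend on the training data.
    results = generate(category="Thể thao",
                       headline="Đội tuyển Việt Nam thắng trận mở màn",
                       num_return_sequences=1)
    print(results[0])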