NorHsangPha commited on
Commit
77e720d
β€’
1 Parent(s): cb45693

Initial: initial commit

Browse files
Files changed (3) hide show
  1. app.py +30 -0
  2. gpt2.py +108 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr

from gpt2 import generate_text, GENERATE_EXAMPLES

# Single-tab Gradio UI wrapping the GPT-2 generator defined in gpt2.py.
# The dropdown values are dispatch keys consumed by generate_text(); they
# must match gpt2.py exactly (including the historical "beem" spelling).
gpt_generate = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input text"),
        gr.Dropdown(
            [
                "sample_outputs",
                "greedy_search",
                "beem_search",
                "top_k_search",
                "top_p_search",
            ],
            label="Search method",
            value="sample_outputs",
        ),
    ],
    outputs=gr.Textbox(label="Generated text"),
    examples=GENERATE_EXAMPLES,
    title="GPT-2 Text generator Demo",
    description="Generate text using GPT-2.",
    allow_flagging="never",
)

with gr.Blocks() as demo:
    gpt_generate.render()

# Fix: guard the launch so importing this module (e.g. from tests or other
# apps) does not start a web server. `python app.py` behaves as before.
if __name__ == "__main__":
    demo.launch()
gpt2.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+
6
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS
# (guarded with hasattr for torch builds that predate the mps backend),
# falling back to CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif (
    hasattr(torch.backends, "mps")
    and torch.backends.mps.is_available()
    and torch.backends.mps.is_built()
):
    _device_name = "mps"
else:
    _device_name = "cpu"

device = torch.device(_device_name)
16
+
17
+ print(f"running device: {device}")
18
+ auth_token = os.environ.get("TOKEN_READ_SECRET") or True
19
+
20
+ tokenizer = AutoTokenizer.from_pretrained(
21
+ "NorHsangPha/shan_gpt2_news", token=auth_token
22
+ )
23
+ model = AutoModelForCausalLM.from_pretrained(
24
+ "NorHsangPha/shan_gpt2_news", pad_token_id=tokenizer.eos_token_id, token=auth_token
25
+ ).to(device)
26
+
27
+
28
def greedy_search(model_inputs, max_new_tokens):
    """Deterministic decoding: always take the highest-probability token.

    Args:
        model_inputs: tokenizer output (input_ids etc.) already on `device`.
        max_new_tokens: generation budget in tokens.

    Returns:
        The decoded sequence with special tokens stripped.
    """
    output_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
32
+
33
+
34
def beem_search(model_inputs, max_new_tokens):
    """Beam-search decoding with 5 beams.

    The "beem" spelling is kept deliberately: the UI dropdown and the
    generate_text() dispatch both use this exact name.

    Args:
        model_inputs: tokenizer output (input_ids etc.) already on `device`.
        max_new_tokens: generation budget in tokens.

    Returns:
        The decoded best beam with special tokens stripped.
    """
    beam_output = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        num_beams=5,
        no_repeat_ngram_size=2,  # forbid repeating any 2-gram
        early_stopping=True,
        # Fix: the original asked for num_return_sequences=5 but only ever
        # decoded index 0. The top-scoring beam is returned first either way,
        # so requesting one sequence yields the same text with less work.
    )

    return tokenizer.decode(beam_output[0], skip_special_tokens=True)
45
+
46
+
47
def sample_outputs(model_inputs, max_new_tokens):
    """Ancestral sampling over the full vocabulary (top_k=0) at T=0.6.

    Args:
        model_inputs: tokenizer output (input_ids etc.) already on `device`.
        max_new_tokens: generation budget in tokens.

    Returns:
        One decoded sample with special tokens stripped (non-deterministic).
    """
    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_k": 0,
        "temperature": 0.6,
    }
    output_ids = model.generate(**model_inputs, **sampling_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
57
+
58
+
59
def top_k_search(model_inputs, max_new_tokens):
    """Top-k sampling: draw each token from the 50 most likely candidates.

    Args:
        model_inputs: tokenizer output (input_ids etc.) already on `device`.
        max_new_tokens: generation budget in tokens.

    Returns:
        One decoded sample with special tokens stripped (non-deterministic).
    """
    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
65
+
66
+
67
def top_p_search(model_inputs, max_new_tokens):
    """Nucleus (top-p) sampling with p=0.92 and no top-k cutoff.

    Args:
        model_inputs: tokenizer output (input_ids etc.) already on `device`.
        max_new_tokens: generation budget in tokens.

    Returns:
        One decoded sample with special tokens stripped (non-deterministic).
    """
    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": 0.92,
        "top_k": 0,
    }
    output_ids = model.generate(**model_inputs, **sampling_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
77
+
78
+
79
def generate_text(input_text, search_method="sample_outputs", max_new_tokens=120):
    """Generate text from a prompt with the chosen decoding strategy.

    Args:
        input_text: prompt string to condition the model on.
        search_method: one of "greedy_search", "beem_search", "top_k_search",
            "top_p_search"; any other value falls back to "sample_outputs",
            matching the original match/case wildcard.
        max_new_tokens: generation budget; previously hard-coded to 120,
            now a backward-compatible keyword parameter with the same default.

    Returns:
        The decoded generated text.
    """
    model_inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Dispatch table replaces the match/case chain; unknown keys fall back
    # to sampling exactly as the original `case _` did.
    strategies = {
        "greedy_search": greedy_search,
        "beem_search": beem_search,
        "top_k_search": top_k_search,
        "top_p_search": top_p_search,
    }
    decode = strategies.get(search_method, sample_outputs)
    return decode(model_inputs, max_new_tokens)
100
+
101
+
102
# (input_text, search_method) pairs rendered as clickable examples in the UI.
# Four entries reuse the same Shan-language prompt, so it is named once.
_SHARED_PROMPT = "α€•α’α€„α€Ία€α€­α€―α΅α€Ία€Έα€žα€­α€―α΅α€Ία€Έα€žα€­α€°α€α€Ί"

GENERATE_EXAMPLES = [
    ["α€™α‚‚α€Ία‚‡α€žα€―α€„α€ΊαΆα‚ƒα‚ˆ", "sample_outputs"],
    [_SHARED_PROMPT, "greedy_search"],
    [_SHARED_PROMPT, "top_k_search"],
    [_SHARED_PROMPT, "top_p_search"],
    [_SHARED_PROMPT, "beem_search"],
]
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ torchaudio