lewiswu1209 committed on
Commit
147e546
•
0 Parent(s):

initial commit

Browse files
Files changed (4) hide show
  1. .gitattributes +27 -0
  2. README.md +13 -0
  3. app.py +55 -0
  4. requirements.txt +2 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Gpt2 Chinese Couplet
3
+ emoji: 🚀
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.0.26
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+
4
+ import gradio as gr
5
+ import torch.nn.functional as F
6
+
7
+ from transformers import BertTokenizer, GPT2LMHeadModel
8
+
9
# Hub checkpoint id shared by the tokenizer and the language model.
_CHECKPOINT = "uer/gpt2-chinese-couplet"

tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
model = GPT2LMHeadModel.from_pretrained(_CHECKPOINT)
# The model is only used for generation here, so put it in evaluation mode.
model.eval()
12
+
13
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask a 1-D logits vector for top-k and/or nucleus (top-p) sampling.

    Args:
        logits: 1-D tensor of unnormalised scores. It is modified IN PLACE
            and the same tensor object is returned.
        top_k: Keep only the ``top_k`` highest-scoring entries. Values
            ``<= 0`` disable top-k filtering; values larger than the vector
            length are clamped to the length.
        top_p: Nucleus threshold. Entries are kept, in descending score
            order, up to and including the first one at which the cumulative
            softmax probability exceeds ``top_p``. Values ``<= 0`` or ``>= 1``
            disable nucleus filtering (``1`` means "keep everything").
        filter_value: Value written into every removed position.

    Returns:
        The filtered ``logits`` tensor (the same object that was passed in).

    Raises:
        ValueError: If ``logits`` is not one-dimensional.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if logits.dim() != 1:
        raise ValueError(
            "top_k_top_p_filtering expects 1-D logits, got %d dimensions" % logits.dim()
        )

    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Everything strictly below the k-th best score is dropped.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    # top_p >= 1 keeps the whole distribution by definition. Skipping it also
    # avoids a useless sort and guards against the float cumulative sum
    # creeping slightly above 1.0 and masking tail tokens.
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the first entry that crosses the threshold is
        # kept, and never drop the single most likely entry.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        logits[sorted_indices[sorted_indices_to_remove]] = filter_value

    return logits
28
+
29
def generate(input_text):
    """Sample a continuation for ``input_text`` from the loaded GPT-2 model.

    The prompt is ``[CLS]`` followed by the tokens of ``input_text + "-"``
    (the ``"-"`` presumably separates the two halves of a couplet for this
    checkpoint -- confirm against the model card). Tokens are then sampled
    one at a time with top-k filtering (k=8) until ``[SEP]`` is produced or
    100 tokens have been generated.

    Args:
        input_text: Text typed by the user.

    Returns:
        The generated tokens joined into a single string, without the prompt
        and without the terminating ``[SEP]``.
    """
    # Loop-invariant lookups, resolved once.
    unk_id = tokenizer.convert_tokens_to_ids('[UNK]')
    sep_id = tokenizer.sep_token_id

    prompt_ids = [tokenizer.cls_token_id]
    prompt_ids.extend(tokenizer.encode(input_text + "-", add_special_tokens=False))
    input_ids = torch.tensor([prompt_ids])

    generated = []
    # Pure inference: no autograd graph is needed, so avoid building one on
    # every forward pass.
    with torch.no_grad():
        for _ in range(100):
            output = model(input_ids)

            # Scores for the token following the last position.
            next_token_logits = output.logits[0, -1, :]
            # Never emit the unknown-token placeholder.
            next_token_logits[unk_id] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=8, top_p=1)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            token_id = next_token.item()
            if token_id == sep_id:
                break
            generated.append(token_id)
            input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)

    return "".join(tokenizer.convert_ids_to_tokens(generated))
48
+
49
if __name__ == "__main__":
    # One text input, one text output, both wired to generate().
    demo = gr.Interface(fn=generate, inputs="text", outputs="text")
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ transformers