ar-houwei-chou committed on
Commit 6aee98f
1 Parent(s): ed402eb
Files changed (11)
  1. README.md +4 -3
  2. app.py +2 -0
  3. app1.py +198 -0
  4. dataset.py +271 -0
  5. helper.py +79 -0
  6. model.py +131 -0
  7. options.py +47 -0
  8. requirements.txt +6 -0
  9. transformer.py +222 -0
  10. vocab.cond.vocab +31 -0
  11. vocab.py +41 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Test3
- emoji: 📈
+ title: Kobedemo
+ emoji: 🐢
  colorFrom: red
- colorTo: purple
+ colorTo: red
  sdk: streamlit
  sdk_version: 1.17.0
  app_file: app.py
  pinned: false
+ python_version: 3.8.12
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,2 @@
+ import streamlit as st
+ st.write('You have selected:')
app1.py ADDED
@@ -0,0 +1,198 @@
+ """
+ import argparse
+ import random
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ from dataset import KobeDataModule
+ from model import KobeModel
+ from options import add_args, add_options
+ from typing import List
+ import warnings
+ import logging
+ from transformers.models.bert.tokenization_bert import BertTokenizer
+ import sentencepiece as spm
+
+
+ logging.getLogger("lightning").setLevel(logging.ERROR)
+ def fxn():
+     warnings.warn("deprecated", DeprecationWarning)
+
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore")
+     fxn()
+
+
+ parser = argparse.ArgumentParser()
+ add_options(parser)
+ args = parser.parse_args()
+ args.vocab_file = "bert-base-chinese"
+ args.cond_vocab_file = "./vocab.cond.model"
+ add_args(args)
+
+
+ #model = KobeModel(args)
+
+
+ class Example:
+     title_token_ids: List[int]
+     condition_token_ids: List[int]
+     fact_token_ids: List[int]
+
+     def __init__(self, title_token_ids: List[int], condition_token_ids: List[int], fact_token_ids: List[int]):
+         self.title_token_ids = title_token_ids
+         self.condition_token_ids = condition_token_ids
+         self.fact_token_ids = fact_token_ids
+
+ text_tokenizer = BertTokenizer.from_pretrained(args.vocab_file)
+ cond_tokenizer = spm.SentencePieceProcessor()
+ cond_tokenizer.Load(args.cond_vocab_file)
+
+
+ #model = model.load_from_checkpoint("/root/kobe-v2/1ja19m5t/checkpoints/epoch=19-step=66080.ckpt", args=args)
+ #model = model.load_from_checkpoint("/root/kobe-v2/37ht1cvz/checkpoints/epoch=11-step=396384.ckpt", args=args)
+
+ trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=-1)
+ """
+
+
+ import streamlit as st
+
+ st.write("Most appearing words including stopwords")
+
+ """
+ choice = st.selectbox(
+     'Select the items you want?',
+     ('Pen','Pencil','Eraser','Sharpener','Notebook'))
+
+
+ input1 = st.selectbox(
+     'please choose:',
+     ("1.天猫直发百诺碳纤维三脚架单反相机三角架c2690tb1摄影脚架",
+      "2.dacom飞鱼p10运动型跑步蓝牙耳机入耳式头戴挂耳塞式7级防水苹果安卓手机通用可接听电话音乐篮",
+      "3.coach蔻驰贝壳包pvc单肩手提斜挎大号贝壳包女包5828",
+      "4.highcook韩库韩国进口蓝宝石近无烟炒锅家用不粘锅电磁炉炒菜锅",
+      "5.欧式复古亚麻布料沙发面料定做飘窗垫窗台垫榻榻米垫抱枕diy布艺",
+      "6.飞利浦电动剃须刀sp9851充电式带多功能理容配件和智能清洁系统",
+      "7.不锈钢牛排刀叉西餐餐具全套筷子刀叉勺三件套欧式加厚24件礼盒装",
+      "8.香百年汽车香水挂件车载香水香薰悬挂吊坠车用车内装饰挂饰精油",
+      "9.迪士尼小学生书包儿童男孩13一46年级美国队长男童减负12周岁男",
+      "10.半饱良味潮汕猪肉脯宅人食堂潮汕小吃特产碳烤猪肉干120g"))
+
+ st.write('You selected:', input1)
+
+
+ title = ""
+ fact = ""
+
+ if input1==1:
+     title = "天猫直发百诺碳纤维三脚架单反相机三角架c2690tb1摄影脚架"
+     fact = "太字节Terabyte,计算机存储容量单位,也常用TB来表示。百诺公司创建于1996年,早期与日本合作,后通过自身技术创新与努力,逐渐在国内外抑尘设备行业赢得一席单反就是指单镜头反光,即SLRSingleLensReflex,单反相机就是拥有这个功能的相机。技巧拍摄往往都离不开三脚架的帮助,如夜景拍摄、微距拍摄等方面。"
+
+ if input1==2:
+     title = "dacom飞鱼p10运动型跑步蓝牙耳机入耳式头戴挂耳塞式7级防水苹果安卓手机通用可接听电话音乐篮牙"
+     fact = "移动电话,或称为无线电话,通常称为手机,原本只是一种通讯工具,早期又有大哥大的俗称,是可以在较广范围运动型是德国精神病学家克雷奇默提出的身体类型之一。跑步,是指陆生动物使用足部,移动最快捷的方法。蓝牙耳机就是将蓝牙技术应用在免持耳机上,让使用者可以免除恼人电线的牵绊,自在地以各种方式轻松通话。"
+
+ if input1==3:
+     title = "coach蔻驰贝壳包pvc单肩手提斜挎大号贝壳包女包5828"
+     fact = "聚氯乙烯,英文简称PVCPolyvinylchloride,是氯乙烯单体vinylchloridem女包,这个名词是箱包的性别分类衍生词。贝壳包beikebao女士包种类的一种,因为其外形酷似贝壳的外形而得名。蔻驰为美国经典皮件品牌COACH,一像以简洁、耐用的风格特色赢得消费者的喜爱。"
+
+ if input1==4:
+     title = "highcook韩库韩国进口蓝宝石近无烟炒锅家用不粘锅电磁炉炒菜锅"
+     fact = "蓝宝石,是刚玉宝石中除红色的红宝石之外,其它颜色刚玉宝石的通称,主要成分是氧化铝Al2O3。电磁炉又称为电磁灶,1957年第一台家用电磁炉诞生于德国。家用是汉语词汇,出自管子权修,解释为家庭日常使用的。不粘锅即做饭不会粘锅底的锅,是因为锅底采用了不粘涂层,常见的、不粘性能最好的有特氟龙涂层和陶瓷涂层。"
+
+ if input1==5:
+     title = "欧式复古亚麻布料沙发面料定做飘窗垫窗台垫榻榻米垫抱枕diy布艺"
+     fact = "沙发是个外来词,根据英语单词sofa音译而来。面料就是用来制作服装的材料。飘窗垫,就是放在飘窗的台面上的垫子。复古与怀旧,有时候很难区分。"
+
+ if input1==6:
+     title = "飞利浦电动剃须刀sp9851充电式带多功能理容配件和智能清洁系统"
+     fact = "能够完成一种或者几种生理功能的多个器官按照一定的次序组合在一起的结构叫做系统。配件,指装配机械的零件或部件;也指损坏后重新安装上的零件或部件。清洁是由奥利维耶阿萨亚斯执导,张曼玉、尼克诺尔蒂主演的剧情片,于2004年9月1日在法国上映。飞利浦,1891年成立于荷兰,主要生产照明、家庭电器、医疗系统方面的产品。"
+
+ if input1==7:
+     title = "不锈钢牛排刀叉西餐餐具全套筷子刀叉勺三件套欧式加厚24件礼盒装"
+     fact = "西餐餐具具体有大盘子、小盘子、浅碟、深碟、吃沙拉用的叉子、叉肉用的叉子、喝汤用的汤匙、吃甜点用的汤匙三件咳嗽,贫穷与爱情触不到的恋人...不锈钢指耐空气、蒸汽、水等弱腐蚀介质和酸、碱、盐等化学浸蚀性介质腐蚀的钢,又称不锈耐酸钢。2,4D丁酯,无色油状液体。"
+
+ if input1==8:
+     title = "香百年汽车香水挂件车载香水香薰悬挂吊坠车用车内装饰挂饰精油"
+     fact = "悬挂系统是汽车的车架与车桥或车轮之间的一切传力连接装置的总称,其作用是传递作用在车轮和车架之间的力和吊坠,一种首饰,配戴在脖子上的饰品,多为金属制,特别是不锈钢制和银制,也有矿石、水晶、玉石等制的,主汽车香水AutoPerfume是一种混合了香精油、固定剂与酒精的液体,用来让汽车车内拥有持久且悦人的精油是从植物的花、叶、茎、根或果实中,通过水蒸气蒸馏法、挤压法、冷浸法或溶剂提取法提炼萃取的挥发性芳"
+
+ if input1==9:
+     title = "迪士尼小学生书包儿童男孩13一46年级美国队长男童减负12周岁男"
+     fact = "美国队长是每一男孩心中的英雄人物,迪士尼美国队长款的小学生书包,按照美国队长的防护盾牌设计,泛着丝丝银光,帅气有型,而且还有很多英雄款式哦脊背处采用柔软舒适的脊椎防护设计,减轻孩子的背部压力。而且前方盾牌还能拆卸下来,当做斜挎包使用,满足淘气小男孩的英雄梦。"
+
+ if input1==10:
+     title = "半饱良味潮汕猪肉脯宅人食堂潮汕小吃特产碳烤猪肉干120g"
+     fact = "潮汕,不是潮州潮州一词始于隋文帝开皇十年,距今不到两千年。猪肉脯是一种用猪肉经腌制、烘烤的片状肉制品,食用方便、制作考究、美味可口、耐贮藏和便于运输的中式传统发达国家都有全国统一的急救电话号码。特产指某地特有的或特别著名的产品,有文化内涵或历史,亦指只有在某地才生产的一种产品。"
+
+
+ input2 = st.selectbox(
+     'please choose category:',
+     ("1: 家庭主妇",
+      "2: 烹饪达人",
+      "3: 买鞋控",
+      "4: 数码达人",
+      "5: 吃货",
+      "6: 爱包人",
+      "7: 高富帅",
+      ))
+
+ st.write('You selected:', input2)
+ aspect = "<"+str(input2)+">"
+
+
+ input3 = st.selectbox(
+     'please choose aspect:',
+     ("1: appearance",
+      "2: texture",
+      "3: function"))
+
+ st.write('You selected:', input3)
+
+ cond = ""
+ if input3==1:
+     cond = "<a>"
+ if input3==2:
+     cond = "<b>"
+ if input3==3:
+     cond = "<c>"
+
+ #cond = cond+" "+aspect
+ cond = aspect+" "+cond
+ #print(title)
+ #print(fact)
+ #print(cond)
+ tokenizer = text_tokenizer
+ title_token_ids = tokenizer.encode(title, add_special_tokens=False)
+ condition_token_ids = cond_tokenizer.EncodeAsIds(cond)
+ fact_token_ids = tokenizer.encode(fact, add_special_tokens=False)
+
+ e = Example(title_token_ids, condition_token_ids, fact_token_ids)
+
+
+ dm = KobeDataModule(
+     [e],
+     args.text_vocab_path,
+     args.max_seq_len,
+     1,
+     1,
+ )
+ for d in dm.test_dataloader():
+     st.write("result:")
+     st.write(''.join(model.test_step(d, 1)).replace(" ", ""))
+ """
dataset.py ADDED
@@ -0,0 +1,271 @@
+ import glob
+ import sys
+ from dataclasses import dataclass
+ from typing import List
+
+ import pytorch_lightning as pl
+ import sentencepiece as spm
+ import torch
+ from torch.functional import Tensor
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data.dataloader import DataLoader
+
+ import helper as helpers  # local helper.py: provides get_bert_vocab_size used below
+
+
+ @dataclass
+ class Example:
+     title_token_ids: List[int]
+     description_token_ids: List[int]
+     condition_token_ids: List[int]
+     fact_token_ids: List[int]
+     description: str
+     title: str
+
+
+ @dataclass
+ class TensorDict:
+     def detach(self):
+         detached_dict = {
+             field: getattr(self, field).detach()
+             if isinstance(getattr(self, field), torch.Tensor)
+             else getattr(self, field)
+             for field in self.__dataclass_fields__
+         }
+         return self.__class__(**detached_dict)
+
+     def cpu(self):
+         detached_dict = {
+             field: getattr(self, field).cpu()
+             if isinstance(getattr(self, field), torch.Tensor)
+             else getattr(self, field)
+             for field in self.__dataclass_fields__
+         }
+         return self.__class__(**detached_dict)
+
+
+ @dataclass
+ class Batched(TensorDict):
+     # Source
+     title_token_ids: torch.Tensor
+     title_token_ids_mask: torch.Tensor
+     # Attribute Fusion
+     cond_title_token_ids: torch.Tensor
+     cond_title_token_ids_mask: torch.Tensor
+     # Knowledge Incorporation
+     fact_token_ids: torch.Tensor
+     fact_token_ids_mask: torch.Tensor
+     title_fact_token_ids: torch.Tensor
+     title_fact_token_ids_mask: torch.Tensor
+     # Attribute Fusion + Knowledge Incorporation
+     cond_title_fact_token_ids: torch.Tensor
+     cond_title_fact_token_ids_mask: torch.Tensor
+     # Target
+     #description_token_ids: torch.Tensor
+     #description_token_ids_mask: torch.Tensor
+     #descriptions: List[str]
+     #titles: List[str]
+
+
+ @dataclass
+ class EncodedBatch(TensorDict):
+     context_encodings: torch.Tensor
+     context_encodings_mask: torch.Tensor
+
+
+ @dataclass
+ class DecodedBatch:
+     loss: float
+     acc: float
+     generated: List[str]
+     descriptions: List[str]
+     titles: List[str]
+
+
+ def from_processed(url: str, train=False):
+     import webdataset as wds  # only needed when reading the processed .tar shards
+
+     urls = sorted(glob.glob(url))
+     def my_split_by_worker(urls):
+         wi = torch.utils.data.get_worker_info()
+         if wi is None:
+             return urls
+         else:
+             return urls[wi.id::wi.num_workers]
+     def my_split_by_node(urls):
+         node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
+         return urls[node_id::node_count]
+     if train:
+         return (
+             wds.WebDataset(urls)
+             #wds.WebDataset(urls,nodesplitter=my_split_by_node)
+             #wds.WebDataset(urls,nodesplitter=wds.split_by_node)
+             .shuffle(size=10000000, initial=100000)
+             .decode()
+             .map(lambda d: Example(**d["json"]))
+         )
+     else:
+         print(list(wds.WebDataset(url).decode().map(lambda d: Example(**d["json"])))[0])
+         sys.exit()
+         return list(wds.WebDataset(url).decode().map(lambda d: Example(**d["json"])))
+         #return list(wds.WebDataset(urls, nodesplitter=my_split_by_node).decode().map(lambda d: Example(**d["json"])))
+         #return list(wds.WebDataset(urls, nodesplitter=wds.split_by_node).decode().map(lambda d: Example(**d["json"])))
+
+
+ def get_collate_fn(text_vocab_size: int, max_seq_length: int):
+     def collate_fn(examples: List[Example]) -> Batched:
+         from vocab import BOS_ID, EOS_ID  # vocab.py at the repo root
+
+         title_token_ids = pad_sequence(
+             [
+                 torch.tensor(
+                     [BOS_ID] + e.title_token_ids[: max_seq_length - 2] + [EOS_ID]
+                 )
+                 for e in examples
+             ]
+         )
+         fact_token_ids = pad_sequence(
+             [
+                 torch.tensor(
+                     [BOS_ID] + e.fact_token_ids[: max_seq_length - 2] + [EOS_ID]
+                 )
+                 for e in examples
+             ]
+         )
+         """
+         description_token_ids = pad_sequence(
+             [
+                 torch.tensor(
+                     [BOS_ID] + e.description_token_ids[: max_seq_length - 2] + [EOS_ID]
+                 )
+                 for e in examples
+             ]
+         )
+         """
+         cond_title_token_ids = pad_sequence(
+             [
+                 torch.tensor(
+                     (
+                         [BOS_ID]
+                         + [
+                             cond_id + text_vocab_size
+                             for cond_id in e.condition_token_ids
+                         ]
+                         + e.title_token_ids
+                     )[: max_seq_length - 1]
+                     + [EOS_ID]
+                 )
+                 for e in examples
+             ]
+         )
+         title_fact_token_ids = pad_sequence(
+             [
+                 torch.tensor(
+                     ([BOS_ID] + e.title_token_ids + [EOS_ID] + e.fact_token_ids)[
+                         : max_seq_length - 1
+                     ]
+                     + [EOS_ID]
+                 )
+                 for e in examples
+             ]
+         )
+         cond_title_fact_token_ids = pad_sequence(
+             [
+                 torch.tensor(
+                     (
+                         [BOS_ID]
+                         + [
+                             cond_id + text_vocab_size
+                             for cond_id in e.condition_token_ids
+                         ]
+                         + e.title_token_ids
+                         + [EOS_ID]
+                         + e.fact_token_ids
+                     )[: max_seq_length - 1]
+                     + [EOS_ID]
+                 )
+                 for e in examples
+             ]
+         )
+         #descriptions = [e.description for e in examples]
+         #titles = [e.title for e in examples]
+         return Batched(
+             title_token_ids=title_token_ids,
+             title_token_ids_mask=(title_token_ids == 0).T,
+             fact_token_ids=fact_token_ids,
+             fact_token_ids_mask=(fact_token_ids == 0).T,
+             cond_title_token_ids=cond_title_token_ids,
+             cond_title_token_ids_mask=(cond_title_token_ids == 0).T,
+             title_fact_token_ids=title_fact_token_ids,
+             title_fact_token_ids_mask=(title_fact_token_ids == 0).T,
+             cond_title_fact_token_ids=cond_title_fact_token_ids,
+             cond_title_fact_token_ids_mask=(cond_title_fact_token_ids == 0).T,
+             #description_token_ids="",
+             #description_token_ids_mask=(description_token_ids == 0).T,
+             #descriptions="",
+             #titles="",
+         )
+
+     return collate_fn
+
+
+ class KobeDataModule(pl.LightningDataModule):
+     def __init__(
+         self,
+         test_data: str,
+         vocab_path: str,
+         max_seq_length: int,
+         batch_size: int,
+         num_workers: int,
+     ):
+         super().__init__()
+         self.test_data = test_data
+         self.max_seq_length = max_seq_length
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.text_vocab_size = helpers.get_bert_vocab_size(vocab_path)
+
+     """
+     def train_dataloader(self):
+         return DataLoader(
+             self.train,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             self.valid,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
+         )
+     """
+     def test_dataloader(self):
+         return DataLoader(
+             self.test_data,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
+         )
+
+
+ if __name__ == "__main__":
+     dm = KobeDataModule(
+         train_data="saved/processed/train-*.tar",
+         valid_data="saved/processed/valid.tar",
+         test_data="saved/processed/test.tar",
+         vocab_path="bert-base-chinese",
+         max_seq_length=512,
+         batch_size=32,
+         num_workers=8,
+     )
+     dm.setup("test")
+     max_len = 0
+     from tqdm import tqdm
+
+     tqdm_iter = tqdm(dm.test_dataloader())
+     for batch in tqdm_iter:
+         max_len = max(max_len, batch.cond_title_fact_token_ids.shape[0])
+         max_len = max(max_len, batch.description_token_ids.shape[0])
+         tqdm_iter.set_description(f"max len = 74,212")
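
The collate_fn above merges the two vocabularies by shifting every condition token id up by text_vocab_size, so text tokens and condition tokens index a single embedding table in the Encoder. A minimal sketch of that shift, assuming the bert-base-chinese vocabulary size of 21128 and hypothetical condition ids:

    text_vocab_size = 21128          # BertTokenizer("bert-base-chinese").vocab_size
    condition_token_ids = [5, 9]     # hypothetical ids from vocab.cond.model
    shifted = [cond_id + text_vocab_size for cond_id in condition_token_ids]
    print(shifted)                   # [21133, 21137]: these land in the tail of the joint
                                     # embedding of size text_vocab_size + cond_vocab_size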
helper.py ADDED
@@ -0,0 +1,79 @@
+ import sentencepiece as spm
+ import torch
+ import torch.nn.functional as F
+ from transformers.models.bert.tokenization_bert import BertTokenizer
+
+ BASELINE = "baseline"
+ KOBE_ATTRIBUTE = "kobe-attr"
+ KOBE_KNOWLEDGE = "kobe-know"
+ KOBE_FULL = "kobe-full"
+
+
+ def get_bert_vocab_size(vocab_path: str) -> int:
+     tokenizer = BertTokenizer.from_pretrained(vocab_path)
+     return tokenizer.vocab_size
+
+
+ def get_vocab_size(vocab_path: str) -> int:
+     tokenizer = spm.SentencePieceProcessor()
+     tokenizer.Load(vocab_path)
+     return len(tokenizer)
+
+
+ # Metrics
+ def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
+     assert logits.dim() == 2
+     assert targets.dim() == 1
+     pred = logits.argmax(dim=1)
+     return (pred == targets).sum().item() / targets.shape[0]
+
+
+ def top_k_top_p_sampling(
+     logits, top_k=0, top_p=0.0, temperature=1, filter_value=-float("Inf")
+ ) -> int:
+     """Sample from a filtered distribution of logits using top-k and/or nucleus (top-p) filtering
+     Args:
+         logits: logits distribution shape (vocabulary size)
+         top_k >0: keep only top k tokens with highest probability (top-k filtering).
+         top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+             Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+     """
+     logits /= temperature
+     assert (
+         logits.dim() == 1
+     )  # batch size 1 for now - could be updated for more but the code would be less clear
+     top_k = min(top_k, logits.size(-1))  # Safety check
+     if top_k > 0:
+         # Remove all tokens with a probability less than the last token of the top-k
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = filter_value
+
+     if top_p > 0.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold
+         sorted_indices_to_remove = cumulative_probs > top_p
+         # Shift the indices to the right to keep also the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         indices_to_remove = sorted_indices[sorted_indices_to_remove]
+         logits[indices_to_remove] = filter_value
+
+     # Sample from the filtered distribution
+     probabilities = F.softmax(logits, dim=-1)
+     next_token = torch.multinomial(probabilities, 1)
+
+     return int(next_token.item())
+
+
+ def diversity(tokenized_lines, n=4) -> int:
+     """Defined as the unique number of ngrams generated on the test set."""
+     n_grams_all = []
+     for line in tokenized_lines:
+         n_grams = list(zip(*[line[i:] for i in range(n)]))
+         n_grams_all += n_grams
+
+     return len(set(n_grams_all))
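
A minimal usage sketch of the two helpers above, with made-up inputs (top_k_top_p_sampling modifies its argument in place and expects a 1-D logits tensor, hence the clone):

    import torch
    from helper import diversity, top_k_top_p_sampling

    logits = torch.tensor([2.0, 1.5, 0.3, -1.0, -2.0, -5.0])   # hypothetical next-token logits
    next_id = top_k_top_p_sampling(logits.clone(), top_k=3, top_p=0.95)
    print(next_id)                                              # an index in {0, 1, 2} after filtering

    lines = [["a", "b", "c", "d", "e"], ["a", "b", "c", "d", "f"]]
    print(diversity(lines))                                     # 3 unique 4-grams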
model.py ADDED
@@ -0,0 +1,131 @@
+ from typing import List, Tuple
+
+ import numpy as np
+ import pytorch_lightning as pl
+ import sentencepiece as spm
+ import torch
+ import torch.nn as nn
+
+ #from sacrebleu.metrics.bleu import BLEU, _get_tokenizer
+ from torch import optim
+ from torch.nn.init import xavier_uniform_
+ from transformers.models.bert.tokenization_bert import BertTokenizer
+
+ #import wandb
+ from dataset import Batched, DecodedBatch
+ #from models.scheduler import WarmupDecayLR
+ from transformer import Decoder, Encoder
+ #from kobe.utils import helpers
+
+
+ class KobeModel(pl.LightningModule):
+     def __init__(self, args):
+         super(KobeModel, self).__init__()
+
+         self.encoder = Encoder(
+             vocab_size=args.text_vocab_size + args.cond_vocab_size,
+             max_seq_len=args.max_seq_len,
+             d_model=args.d_model,
+             nhead=args.nhead,
+             num_layers=args.num_encoder_layers,
+             dropout=args.dropout,
+             mode=args.mode,
+         )
+         self.decoder = Decoder(
+             vocab_size=args.text_vocab_size,
+             max_seq_len=args.max_seq_len,
+             d_model=args.d_model,
+             nhead=args.nhead,
+             num_layers=args.num_decoder_layers,
+             dropout=args.dropout,
+         )
+         self.lr = args.lr
+         self.d_model = args.d_model
+         self.loss = nn.CrossEntropyLoss(
+             reduction="mean", ignore_index=0, label_smoothing=0.1
+         )
+         self._reset_parameters()
+
+         self.decoding_strategy = args.decoding_strategy
+         self.vocab = BertTokenizer.from_pretrained(args.text_vocab_path)
+         #self.bleu = BLEU(tokenize=args.tokenize)
+         #self.sacre_tokenizer = _get_tokenizer(args.tokenize)()
+         #self.bert_scorer = BERTScorer(lang=args.tokenize, rescale_with_baseline=True)
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 xavier_uniform_(p)
+
+     def _tokenwise_loss_acc(
+         self, logits: torch.Tensor, batch: Batched
+     ) -> Tuple[torch.Tensor, float]:
+         unmask = ~batch.description_token_ids_mask.T[1:]
+         unmasked_logits = logits[unmask]
+         unmasked_targets = batch.description_token_ids[1:][unmask]
+         #acc = helpers.accuracy(unmasked_logits, unmasked_targets)
+         return self.loss(logits.transpose(1, 2), batch.description_token_ids[1:]), 1
+
+     def training_step(self, batch: Batched, batch_idx: int):
+         encoded = self.encoder.forward(batch)
+         logits = self.decoder.forward(batch, encoded)
+         loss, acc = self._tokenwise_loss_acc(logits, batch)
+         self.lr_schedulers().step()
+         self.log("train/loss", loss.item())
+         self.log("train/acc", acc)
+         return loss
+
+     def _shared_eval_step(self, batch: Batched, batch_idx: int) -> DecodedBatch:
+         encoded = self.encoder.forward(batch)
+         #logits = self.decoder.forward(batch, encoded)
+         #loss, acc = self._tokenwise_loss_acc(logits, batch)
+
+         preds = self.decoder.predict(
+             encoded_batch=encoded, decoding_strategy=self.decoding_strategy
+         )
+         generated = self.vocab.batch_decode(preds.T.tolist(), skip_special_tokens=True)
+         #print(generated)
+
+         return generated  # demo shortcut: return the decoded strings; the DecodedBatch below is unreachable
+         return DecodedBatch(
+             loss=loss.item(),
+             acc=acc,
+             generated=generated,
+             descriptions=batch.descriptions,
+             titles=batch.titles,
+         )
+
+     def validation_step(self, batch, batch_idx):
+         return self._shared_eval_step(batch, batch_idx)
+
+     def test_step(self, batch, batch_idx, dataloader_idx=0):
+         return self._shared_eval_step(batch, batch_idx)
+
+     def _shared_epoch_end(self, outputs: List[DecodedBatch], prefix):
+         loss = np.mean([o.loss for o in outputs])
+         acc = np.mean([o.acc for o in outputs])
+         self.log(f"{prefix}/loss", loss)
+         self.log(f"{prefix}/acc", acc)
+         print(outputs)
+
+         generated = [g for o in outputs for g in o.generated]
+         references = [r for o in outputs for r in o.descriptions]
+         titles = [r for o in outputs for r in o.titles]
+
+         # Examples
+         columns = ["Generated", "Reference"]
+         data = list(zip(generated[:256:16], references[:256:16]))
+         table = wandb.Table(data=data, columns=columns)
+         self.logger.experiment.log({f"examples/{prefix}": table})
+
+     def validation_epoch_end(self, outputs):
+         self._shared_epoch_end(outputs, "val")
+
+     def test_epoch_end(self, outputs):
+         self._shared_epoch_end(outputs, "test")
+
+     def configure_optimizers(self):
+         optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.98))
+         #scheduler = WarmupDecayLR(optimizer, warmup_steps=10000, d_model=self.d_model)
+         return [optimizer]
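
For reference, a minimal sketch of how KobeModel is assembled from the options below in options.py; it mirrors the disabled block in app1.py and assumes the vocab.cond.model file referenced by the defaults is present, with the checkpoint line left as a placeholder rather than a real path:

    import argparse
    from model import KobeModel
    from options import add_args, add_options

    parser = argparse.ArgumentParser()
    add_options(parser)
    args = parser.parse_args([])   # take the defaults from options.py
    add_args(args)                 # fills in text_vocab_size / cond_vocab_size

    model = KobeModel(args)
    # model = KobeModel.load_from_checkpoint("path/to/checkpoint.ckpt", args=args)  # placeholder path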
options.py ADDED
@@ -0,0 +1,47 @@
+ from argparse import ArgumentParser, Namespace
+
+ import helper as helpers  # helper.py at the repo root
+
+
+ def add_options(parser: ArgumentParser):
+     # fmt: off
+     # Dataset
+     parser.add_argument("--train-data", default="saved/processed/train-*.tar", type=str)
+     parser.add_argument("--valid-data", default="saved/processed/valid.tar", type=str)
+     parser.add_argument("--test-data", default="saved/processed/test.tar", type=str)
+     parser.add_argument("--text-vocab-path", default="bert-base-chinese", type=str, help="BertTokenizer used to preprocess the corpus")
+     parser.add_argument("--cond-vocab-path", default="./vocab.cond.model", type=str)
+     parser.add_argument("--num-workers", default=8, help="Number of data loaders", type=int)
+     parser.add_argument("--tokenize", default="zh", help="Tokenization method used to compute sacrebleu, diversity, and BERTScore, defaulted to Chinese", type=str)
+
+     # Model
+     parser.add_argument("--d-model", default=512, type=int)
+     parser.add_argument("--nhead", default=8, type=int)
+     parser.add_argument("--num-encoder-layers", default=6, type=int)
+     parser.add_argument("--num-decoder-layers", default=6, type=int)
+     parser.add_argument("--max-seq-len", default=256, type=int)
+     parser.add_argument("--mode", default="baseline", type=str, choices=[
+         helpers.BASELINE, helpers.KOBE_ATTRIBUTE, helpers.KOBE_KNOWLEDGE, helpers.KOBE_FULL])
+
+     # Training
+     parser.add_argument("--name", default="exp", type=str, help="experiment name")
+     parser.add_argument("--gpu", default=1, type=int)
+     parser.add_argument("--grad-clip", default=1.0, type=float, help="clip threshold of gradients")
+     parser.add_argument("--epochs", default=30, type=int, help="number of epochs to train")
+     parser.add_argument("--patience", default=10, type=int, help="early stopping patience")
+     parser.add_argument("--lr", default=1, type=float, help="learning rate")
+     parser.add_argument("--dropout", default=0.1, type=float, help="dropout rate")
+     parser.add_argument("--batch-size", default=64, type=int)
+     parser.add_argument("--seed", default=42, type=int)
+
+     # Evaluation
+     parser.add_argument("--test", action="store_true", help="only do evaluation")
+     parser.add_argument("--load-file", required=False, type=str, help="path to the checkpoint (.ckpt) for evaluation")
+     parser.add_argument("--decoding-strategy", default="greedy", type=str, choices=["greedy", "nucleus"], help="Whether to use greedy decoding or nucleus sampling (https://arxiv.org/abs/1904.09751)")
+
+     # fmt: on
+
+
+ def add_args(args: Namespace):
+     args.text_vocab_size = helpers.get_bert_vocab_size(args.text_vocab_path)
+     args.cond_vocab_size = helpers.get_vocab_size(args.cond_vocab_path)
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch==1.10.0
+ transformers==4.25.1
+ sentencepiece
+ pytorch-lightning==1.6.4
+
+
transformer.py ADDED
@@ -0,0 +1,222 @@
+ import math
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+ from cached_property import cached_property
+ from torch.nn.modules.transformer import (
+     TransformerDecoder,
+     TransformerDecoderLayer,
+     TransformerEncoder,
+     TransformerEncoderLayer,
+ )
+
+ from dataset import Batched, EncodedBatch
+ from vocab import BOS_ID, EOS_ID, PAD_ID
+ import helper as helpers  # helper.py: mode constants and nucleus sampling used below
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, dropout, dim, max_len=5000):
+         """
+         initialization of required variables and functions
+         :param dropout: dropout probability
+         :param dim: hidden size
+         :param max_len: maximum length
+         """
+         super(PositionalEncoding, self).__init__()
+         # positional encoding initialization
+         pe = torch.zeros(max_len, dim)
+         position = torch.arange(0, max_len).unsqueeze(1)
+         # term to divide
+         div_term = torch.exp(
+             (torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
+         )
+         # sinusoidal positional encoding
+         pe[:, 0::2] = torch.sin(position.float() * div_term)
+         pe[:, 1::2] = torch.cos(position.float() * div_term)
+         pe = pe.unsqueeze(1)
+         self.register_buffer("pe", pe)
+         self.dropout = nn.Dropout(p=dropout)
+         self.dim = dim
+
+     def forward(self, emb):
+         """
+         create positional encoding
+         :param emb: word embedding
+         :param step: step for decoding in inference
+         :return: positional encoding representation
+         """
+         emb *= math.sqrt(self.dim)
+         emb = emb + self.pe[: emb.size(0)]  # [len, batch, size]
+         emb = self.dropout(emb)
+         return emb
+
+
+ class Encoder(nn.Module):
+     @staticmethod
+     def from_args(args) -> "Encoder":
+         return Encoder(
+             args.text_vocab_size + args.cond_vocab_size,
+             args.max_seq_len,
+             args.d_model,
+             args.nhead,
+             args.num_encoder_layers,
+             args.dropout,
+             args.mode,
+         )
+
+     def __init__(
+         self,
+         vocab_size: int,
+         max_seq_len: int,
+         d_model: int,
+         nhead: int,
+         num_layers: int,
+         dropout: float,
+         mode: str,
+     ):
+         super().__init__()
+         self.d_model = d_model
+         self.max_seq_len = max_seq_len
+         self.input_embedding = nn.Embedding(vocab_size, d_model)
+         self.pos_encoder = PositionalEncoding(dropout, d_model)
+         encoder_layer = TransformerEncoderLayer(
+             d_model, nhead, d_model * 4, dropout, norm_first=True
+         )
+         self.encoder = TransformerEncoder(
+             encoder_layer, num_layers, nn.LayerNorm(d_model)
+         )
+         self.mode = mode
+
+     @cached_property
+     def device(self):
+         return list(self.parameters())[0].device
+
+     def forward(self, batched: Batched) -> EncodedBatch:
+         src, src_key_padding_mask = Encoder._get_input(batched, self.mode)
+         src = self.input_embedding(src)
+         src = self.pos_encoder(src)
+         token_encodings = self.encoder.forward(
+             src=src, src_key_padding_mask=src_key_padding_mask
+         )
+         return EncodedBatch(
+             context_encodings=token_encodings,
+             context_encodings_mask=src_key_padding_mask,
+         )
+
+     @staticmethod
+     def _get_input(batched: Batched, mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
+         return {
+             helpers.BASELINE: (batched.title_token_ids, batched.title_token_ids_mask),
+             helpers.KOBE_ATTRIBUTE: (
+                 batched.cond_title_token_ids,
+                 batched.cond_title_token_ids_mask,
+             ),
+             helpers.KOBE_KNOWLEDGE: (
+                 batched.title_fact_token_ids,
+                 batched.title_fact_token_ids_mask,
+             ),
+             helpers.KOBE_FULL: (
+                 batched.cond_title_fact_token_ids,
+                 batched.cond_title_fact_token_ids_mask,
+             ),
+         }[mode]
+
+
+ class Decoder(nn.Module):
+     @staticmethod
+     def from_args(args) -> "Decoder":
+         return Decoder(
+             args.text_vocab_size,
+             args.max_seq_len,
+             args.d_model,
+             args.nhead,
+             args.num_encoder_layers,
+             args.dropout,
+         )
+
+     def __init__(
+         self,
+         vocab_size: int,
+         max_seq_len: int,
+         d_model: int,
+         nhead: int,
+         num_layers: int,
+         dropout: float,
+     ):
+         super(Decoder, self).__init__()
+         self.max_seq_len = max_seq_len
+         self.embedding = nn.Embedding(vocab_size, d_model)
+         self.pos_encoder = PositionalEncoding(dropout, d_model)
+         decoder_layer = TransformerDecoderLayer(
+             d_model, nhead, 4 * d_model, dropout, norm_first=True
+         )
+         self.decoder = TransformerDecoder(
+             decoder_layer, num_layers, nn.LayerNorm(d_model)
+         )
+         self.output = nn.Linear(d_model, vocab_size)
+
+     def forward(self, batch: Batched, encoded_batch: EncodedBatch) -> torch.Tensor:
+         tgt = self.embedding(batch.description_token_ids[:-1])
+         tgt = self.pos_encoder(tgt)
+         tgt_mask = Decoder.generate_square_subsequent_mask(tgt.shape[0], tgt.device)
+         outputs = self.decoder(
+             tgt=tgt,
+             tgt_mask=tgt_mask,
+             tgt_key_padding_mask=batch.description_token_ids_mask[:, :-1],
+             memory=encoded_batch.context_encodings,
+             memory_key_padding_mask=encoded_batch.context_encodings_mask,
+         )
+         return self.output(outputs)
+
+     def predict(self, encoded_batch: EncodedBatch, decoding_strategy: str):
+         batch_size = encoded_batch.context_encodings.shape[1]
+         tgt = torch.tensor(
+             [BOS_ID] * batch_size, device=encoded_batch.context_encodings.device
+         ).unsqueeze(dim=0)
+         tgt_mask = Decoder.generate_square_subsequent_mask(self.max_seq_len, tgt.device)
+         pred_all = []
+         for idx in range(self.max_seq_len):
+             tgt_emb = self.pos_encoder(self.embedding(tgt))
+             outputs = self.decoder(
+                 tgt_emb,
+                 tgt_mask=tgt_mask[: idx + 1, : idx + 1],
+                 memory=encoded_batch.context_encodings,
+                 memory_key_padding_mask=encoded_batch.context_encodings_mask,
+             )
+             logits = self.output(outputs[-1])
+
+             if decoding_strategy == "greedy":
+                 pred_step = logits.argmax(dim=1).tolist()
+             elif decoding_strategy == "nucleus":
+                 pred_step = [
+                     helpers.top_k_top_p_sampling(logits[i], top_p=0.95)
+                     for i in range(batch_size)
+                 ]
+             else:
+                 raise NotImplementedError
+             for b in range(batch_size):
+                 if pred_all and pred_all[-1][b].item() in [EOS_ID, PAD_ID]:
+                     pred_step[b] = PAD_ID
+             if all([pred == PAD_ID for pred in pred_step]):
+                 break
+             pred_step = torch.tensor(pred_step, device=tgt.device)
+             pred_all.append(pred_step)
+
+             if idx < self.max_seq_len - 1:
+                 tgt_step = pred_step.unsqueeze(dim=0)
+                 tgt = torch.cat([tgt, tgt_step], dim=0)
+
+         preds = torch.stack(pred_all)
+         return preds
+
+     @staticmethod
+     def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
+         r"""
+         Generate a square mask for the sequence. The masked positions are filled with
+         float('-inf').
+         Unmasked positions are filled with float(0.0).
+         """
+         return torch.triu(
+             torch.full((sz, sz), float("-inf"), device=device), diagonal=1
+         )
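
A quick illustration of the causal mask returned by generate_square_subsequent_mask, which predict slices per decoding step via tgt_mask[: idx + 1, : idx + 1]; this evaluates the same torch.triu expression directly:

    import torch

    sz = 3
    mask = torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)
    print(mask)
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])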
vocab.cond.vocab ADDED
@@ -0,0 +1,31 @@
+ <pad>	0
+ <s>	0
+ </s>	0
+ <unk>	0
+ ▁<a>	-1.261
+ ▁<c>	-1.82603
+ <3>	-2.17158
+ <0>	-2.26491
+ <1>	-2.34
+ <2>	-2.36126
+ ▁<b>	-2.89
+ <4>	-3.31157
+ <5>	-3.54753
+ <6>	-5.02554
+ <7>	-5.3972
+ <11>	-5.51923
+ <8>	-6.03597
+ <9>	-6.14342
+ <10>	-6.17248
+ <13>	-6.28137
+ <12>	-6.84099
+ <17>	-7.69198
+ <14>	-7.72356
+ <15>	-8.15065
+ <20>	-9.37115
+ <18>	-9.52068
+ <19>	-9.52068
+ <16>	-9.55347
+ <23>	-9.58737
+ <25>	-9.95894
+ <24>	-12.9547
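
vocab.cond.vocab is the human-readable (piece, log-probability) listing that SentencePiece writes alongside its binary model; options.py and app1.py load the companion vocab.cond.model file, not this listing. A minimal sketch, assuming that .model file is available next to it:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load("./vocab.cond.model")
    print(len(sp))                        # 31, matching --vocab-size in vocab.py
    print(sp.EncodeAsPieces("<2> <a>"))   # e.g. ['<2>', '▁<a>'] for a category + aspect pair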
vocab.py ADDED
@@ -0,0 +1,41 @@
+ import tempfile
+ from argparse import ArgumentParser
+
+ import sentencepiece as spm
+ from transformers.models.bert.tokenization_bert import BertTokenizer
+
+ # Load the text tokenizer
+ tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
+
+ BOS_TOKEN = tokenizer.cls_token
+ EOS_TOKEN = tokenizer.sep_token
+ UNK_TOKEN = tokenizer.unk_token
+ PAD_ID = tokenizer.pad_token_id
+ BOS_ID = tokenizer.cls_token_id
+ EOS_ID = tokenizer.sep_token_id
+ UNK_ID = tokenizer.unk_token_id
+
+ # Build the condition (attribute) tokenizer
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     # fmt: off
+     parser.add_argument("--input", nargs="+", required=True)
+     parser.add_argument("--vocab-file", type=str, required=True)
+     parser.add_argument("--vocab-size", type=int, default=31)
+     parser.add_argument("--algo", type=str, default="bpe", choices=["bpe", "word"])
+     # fmt: on
+     args = parser.parse_args()
+     print("Building token vocabulary")
+     with tempfile.NamedTemporaryFile("w") as f:
+         # concatenate input files
+         for input_fname in args.input:
+             with open(input_fname) as input_f:
+                 f.write(input_f.read() + "\n")
+         # run sentence piece with bpe
+         spm.SentencePieceTrainer.Train(
+             f"--add_dummy_prefix=false --pad_id=0 --bos_id=1 --eos_id=2 --unk_id=3 "
+             f"--vocab_size={args.vocab_size} "
+             f"--model_prefix={args.vocab_file} --model_type={args.algo} "
+             f"--input={f.name}"
+         )