Sam Passaglia committed on
Commit f29cbf8
1 Parent(s): ac462f6
config/heteronyms_Sato2022.json DELETED
@@ -1,211 +0,0 @@
- {
-     "heteronyms_in_bert": {
-         "表": 2,
-         "角": 4,
-         "大分": 2,
-         "国立": 2,
-         "人気": 3,
-         "市場": 2,
-         "気質": 2,
-         "役所": 2,
-         "上方": 2,
-         "上手": 3,
-         "下手": 3,
-         "人事": 2,
-         "金星": 2,
-         "仮名": 2,
-         "内面": 2,
-         "礼拝": 2,
-         "遺言": 3,
-         "口腔": 2,
-         "後世": 2,
-         "骨": 2,
-         "一途": 2,
-         "一言": 3,
-         "最中": 3,
-         "一目": 2,
-         "係": 3,
-         "足跡": 2,
-         "今日": 2,
-         "明日": 3,
-         "生物": 3,
-         "変化": 2,
-         "大事": 2,
-         "水車": 2,
-         "一見": 2,
-         "一端": 2,
-         "大家": 3,
-         "心中": 2,
-         "書物": 2,
-         "一角": 2,
-         "一行": 3,
-         "一時": 3,
-         "一定": 2,
-         "一方": 2,
-         "一夜": 2,
-         "下野": 3,
-         "化学": 2,
-         "火口": 2,
-         "花弁": 2,
-         "玩具": 2,
-         "強力": 3,
-         "金色": 2,
-         "経緯": 2,
-         "故郷": 2,
-         "紅葉": 2,
-         "行方": 3,
-         "根本": 2,
-         "左右": 3,
-         "山陰": 2,
-         "十分": 2,
-         "上下": 5,
-         "身体": 2,
-         "水面": 2,
-         "世論": 2,
-         "清水": 3,
-         "大手": 2,
-         "大人": 4,
-         "大勢": 3,
-         "中間": 5,
-         "日向": 42,
-         "日時": 3,
-         "夫婦": 2,
-         "牧場": 2,
-         "末期": 2,
-         "利益": 2,
-         "工夫": 2,
-         "一味": 2,
-         "魚": 3,
-         "区分": 2,
-         "施行": 4,
-         "施工": 2,
-         "転生": 2,
-         "博士": 2,
-         "法華": 2,
-         "真面目": 3,
-         "眼鏡": 2,
-         "文字": 2,
-         "文書": 3,
-         "律令": 2,
-         "現世": 2,
-         "日中": 2,
-         "夜中": 3,
-         "前世": 2,
-         "二人": 2,
-         "立像": 2
-     },
-     "heteronyms_not_in_bert": {
-         "教化": 3,
-         "見物": 2,
-         "清浄": 2,
-         "谷間": 2,
-         "追従": 2,
-         "墓石": 2,
-         "大文字": 2,
-         "漢書": 2,
-         "作法": 2,
-         "兵法": 2,
-         "大人気": 2,
-         "半月": 2,
-         "黒子": 2,
-         "外面": 2,
-         "競売": 2,
-         "開眼": 2,
-         "求道": 2,
-         "血脈": 2,
-         "施業": 2,
-         "借家": 2,
-         "頭蓋骨": 2,
-         "法衣": 2,
-         "昨日": 2,
-         "氷柱": 2,
-         "風車": 2,
-         "寒気": 2,
-         "背筋": 2,
-         "逆手": 2,
-         "色紙": 2,
-         "生花": 3,
-         "白髪": 2,
-         "貼付": 2,
-         "一回": 2,
-         "一期": 2,
-         "一月": 3,
-         "一所": 2,
-         "一寸": 2,
-         "一声": 2,
-         "一石": 2,
-         "一日": 4,
-         "一分": 3,
-         "一文": 3,
-         "一片": 3,
-         "何時": 3,
-         "何分": 2,
-         "火煙": 2,
-         "火傷": 2,
-         "火床": 3,
-         "火先": 2,
-         "火筒": 2,
-         "芥子": 3,
-         "気骨": 2,
-         "銀杏": 3,
-         "元金": 2,
-         "五分": 2,
-         "後々": 2,
-         "後生": 2,
-         "御供": 4,
-         "細々": 3,
-         "細目": 2,
-         "三位": 2,
-         "疾風": 3,
-         "菖蒲": 2,
-         "世人": 2,
-         "世路": 2,
-         "船底": 2,
-         "早急": 2,
-         "相乗": 2,
-         "造作": 2,
-         "他言": 2,
-         "東雲": 2,
-         "頭数": 2,
-         "二重": 2,
-         "日供": 2,
-         "日次": 4,
-         "日暮": 3,
-         "日来": 3,
-         "梅雨": 2,
-         "風穴": 2,
-         "仏語": 3,
-         "分別": 2,
-         "面子": 2,
-         "木目": 2,
-         "目下": 2,
-         "夜直": 2,
-         "夜来": 2,
-         "夜話": 2,
-         "野兎": 2,
-         "野馬": 3,
-         "野分": 2,
-         "野辺": 2,
-         "野面": 3,
-         "野立": 3,
-         "冷水": 2,
-         "連中": 2,
-         "飛沫": 2,
-         "翡翠": 2,
-         "餃子": 2,
-         "一足": 2,
-         "意気地": 2,
-         "一昨日": 3,
-         "一昨年": 2,
-         "十八番": 2,
-         "十六夜": 2,
-         "明後日": 2,
-         "石綿": 2,
-         "公文": 2,
-         "読本": 3,
-         "仏国": 3,
-         "古本": 2,
-         "町家": 2,
-         "遊行": 2
-     }
- }
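The values in this deleted config appear to count the distinct readings recorded for each heteronym surface form, grouped by whether the surface is in the BERT tokenizer vocabulary. A minimal sketch of how such a file can be loaded (not part of the commit; the path assumes the pre-deletion layout):

import json
from pathlib import Path

# Hypothetical consumer of the deleted config file.
with open(Path("config/heteronyms_Sato2022.json"), encoding="utf-8") as f:
    heteronyms = json.load(f)

in_bert = heteronyms["heteronyms_in_bert"]
not_in_bert = heteronyms["heteronyms_not_in_bert"]
print(len(in_bert), len(not_in_bert))  # size of each group
print(in_bert["上下"])  # 5: five readings are recorded for 上下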
yomikata/dataset/__init__.py DELETED
File without changes
yomikata/dataset/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (163 Bytes)
yomikata/dataset/__pycache__/aozora.cpython-310.pyc DELETED
Binary file (2.99 kB)
yomikata/dataset/__pycache__/bccwj.cpython-310.pyc DELETED
Binary file (5.31 kB)
yomikata/dataset/__pycache__/kwdlc.cpython-310.pyc DELETED
Binary file (2.47 kB)
yomikata/dataset/__pycache__/ndlbib.cpython-310.pyc DELETED
Binary file (1.3 kB)
yomikata/dataset/__pycache__/pronunciations.cpython-310.pyc DELETED
Binary file (1.44 kB)
yomikata/dataset/__pycache__/repair_long_vowels.cpython-310.pyc DELETED
Binary file (2.13 kB)
yomikata/dataset/__pycache__/split.cpython-310.pyc DELETED
Binary file (8.08 kB)
yomikata/dataset/__pycache__/sudachi.cpython-310.pyc DELETED
Binary file (1.15 kB)
yomikata/dataset/__pycache__/unidic.cpython-310.pyc DELETED
Binary file (1.27 kB)
yomikata/dataset/aozora.py DELETED
@@ -1,117 +0,0 @@
- """aozora.py
- Data processing script for aozora bunko file from https://github.com/ndl-lab/huriganacorpus-aozora
- """
-
- import warnings
- from pathlib import Path
-
- import pandas as pd
- from pandas.errors import ParserError
- from speach import ttlig
-
- from config import config
- from config.config import logger
- from yomikata import utils
- from yomikata.dataset.repair_long_vowels import repair_long_vowels
-
- warnings.filterwarnings("ignore")
-
-
- def read_file(file: str):
-     # logger.info("reading file")
-     with open(file) as f:
-         rows = [
-             line.rstrip("\n").rstrip("\r").split("\t")[0:3] for line in f.readlines()
-         ]
-     df = pd.DataFrame(rows, columns=["word", "furigana", "type"])
-
-     # logger.info("removing unused rows")
-     # remove unused rows
-     df = df[~df["type"].isin(["[入力 読み]", "分かち書き"])]
-     df = df[~pd.isna(df["word"])]
-     df = df[~pd.isnull(df["word"])]
-     df = df[df["word"] != ""]
-
-     # logger.info("organizing into sentences")
-     # now organize remaining rows into sentences
-     gyou_df = pd.DataFrame(columns=["sentence", "furigana", "sentenceid"])
-     sentence = ""
-     furigana = ""
-     sentenceid = None
-     gyous = []
-     for row in df.itertuples():
-         if row.type in ["[入力文]"]:
-             sentence = row.word
-         elif row.type in ["漢字"]:
-             furigana += ttlig.RubyToken.from_furi(
-                 row.word, repair_long_vowels(row.furigana, row.word)
-             ).to_code()
-         elif row.word.split(":")[0] in ["行番号"]:
-             if sentenceid:  # this handles the first row
-                 gyous.append([sentence, furigana, sentenceid])
-             sentenceid = file.name + "_" + row.word.split(":")[1].strip()
-             sentence = None
-             furigana = ""
-         else:
-             furigana += row.word
-
-     # last row handling
-     gyous.append([sentence, furigana, sentenceid])
-
-     # make dataframe
-     gyou_df = pd.DataFrame(gyous, columns=["sentence", "furigana", "sentenceid"])
-     gyou_df = gyou_df[~pd.isna(gyou_df.sentence)]
-
-     # logger.info("cleaning rows")
-     # clean rows
-     gyou_df["furigana"] = gyou_df["furigana"].apply(utils.standardize_text)
-     gyou_df["sentence"] = gyou_df["sentence"].apply(
-         lambda s: utils.standardize_text(
-             s.replace("|", "").replace(" ", "").replace("※", "")
-         )
-     )
-
-     # logger.info("removing errors")
-     # remove non-matching rows
-     gyou_df = gyou_df[
-         gyou_df["sentence"] == gyou_df["furigana"].apply(utils.remove_furigana)
-     ]
-
-     # remove known errors
-     error_ids = []
-     gyou_df = gyou_df[~gyou_df["sentenceid"].isin(error_ids)]
-
-     # remove duplicates
-     gyou_df = gyou_df.drop_duplicates()
-
-     return gyou_df
-
-
- def aozora_data():
-     """Extract, load and transform the aozora data"""
-
-     # Extract sentences from the data files
-     files = list(Path(config.RAW_DATA_DIR, "aozora").glob("*/*/*.txt"))
-
-     with open(Path(config.SENTENCE_DATA_DIR, "aozora.csv"), "w") as f:
-         f.write("sentence,furigana,sentenceid\n")
-
-     for i, file in enumerate(files):
-         logger.info(f"{i+1}/{len(files)} {file.name}")
-         try:
-             df = read_file(file)
-         except ParserError:
-             logger.error(f"Parser error on {file}")
-
-         df.to_csv(
-             Path(config.SENTENCE_DATA_DIR, "aozora.csv"),
-             mode="a",
-             index=False,
-             header=False,
-         )
-
-     logger.info("✅ Saved all aozora data!")
-
-
- if __name__ == "__main__":
-     aozora_data()
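The aozora script stores readings as speach ttlig ruby codes, which wrap each annotated span as {surface/reading} (the same notation the split script below searches for). A minimal sketch, not part of the commit, of the encoding step it relies on:

from speach import ttlig

# Encode a surface form and its kana reading as a furigana ruby code.
token = ttlig.RubyToken.from_furi("大人", "おとな")
print(token.to_code())  # expected to yield {大人/おとな}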
yomikata/dataset/bccwj.py DELETED
@@ -1,206 +0,0 @@
- """bccwj.py
- Data processing script for files downloaded from Chuunagon search
- Chuunagon URL: https://chunagon.ninjal.ac.jp/
-
- Download with the settings
- 文脈中の区切り記号 |
- 文脈中の文区切り記号 #
- 前後文脈の語数 10
- 検索対象(固定長・可変長) 両方
- 共起条件の範囲 文境界をまたがない
-
- ダウンロードオプション
- システム Linux
- 文字コード UTF-8
- 改行コード LF
- 出力ファイルが一つの場合は Zip 圧縮を行わない 検索条件式ごとに出力ファイルを分割する
- インラインタグを使用 CHECK BOTH 語彙素読み AND 発音形出現形語種 BOX
- (発音形出現形 is the actual pronounced one, but displays e.g. よう れい as よー れー)
- タグの区切り記号 :
- """
-
- import warnings
- from pathlib import Path
-
- import jaconv
- import pandas as pd
- from speach.ttlig import RubyToken
-
- from config import config
- from config.config import logger
- from yomikata import utils
-
- warnings.filterwarnings("ignore")
-
- SENTENCE_SPLIT_CHAR = "#"
- WORD_SPLIT_CHAR = "|"
- READING_SEP_CHAR = ":"
-
-
- def read_bccwj_file(filename: str):
-     """ """
-
-     df = pd.read_csv(filename, sep="\t")
-
-     df["前文脈"] = df["前文脈"].fillna("")
-     df["後文脈"] = df["後文脈"].fillna("")
-     df["full_text"] = (
-         df["前文脈"] + df["キー"] + "[" + df["語彙素読み"] + ":" + df["発音形出現形"] + "]" + df["後文脈"]
-     )
-
-     def get_sentences(row):
-         sentences = row["full_text"].split(SENTENCE_SPLIT_CHAR)
-         furigana_sentences = []
-         for sentence in sentences:
-             words_with_readings = sentence.split(WORD_SPLIT_CHAR)
-             furigana_sentence = ""
-             for word_with_reading in words_with_readings:
-                 word = word_with_reading.split("[")[0]
-                 form, reading = jaconv.kata2hira(
-                     word_with_reading.split("[")[1].split("]")[0]
-                 ).split(READING_SEP_CHAR)
-
-                 if (
-                     not utils.has_kanji(word)
-                     or reading == jaconv.kata2hira(word)
-                     or form == ""
-                     or reading == ""
-                 ):
-                     furigana_sentence += word
-                 else:
-                     if ("ー" in reading) and ("ー" not in form):
-                         indexes_of_dash = [
-                             pos for pos, char in enumerate(reading) if char == "ー"
-                         ]
-                         for index_of_dash in indexes_of_dash:
-                             if len(reading) == len(form):
-                                 dash_reading = form[index_of_dash]
-                             else:
-                                 char_before_dash = reading[index_of_dash - 1]
-                                 if char_before_dash in "ねめせぜれてでけげへべぺ":
-                                     digraphA = char_before_dash + "え"
-                                     digraphB = char_before_dash + "い"
-                                     if digraphA in form and digraphB not in form:
-                                         dash_reading = "え"
-                                     elif digraphB in form and digraphA not in form:
-                                         dash_reading = "い"
-                                     else:
-                                         logger.warning(
-                                             f"Leaving dash in {word} {form} {reading}"
-                                         )
-                                         dash_reading = "ー"
-                                 elif char_before_dash in "ぬつづむるくぐすずゆゅふぶぷ":
-                                     dash_reading = "う"
-                                 elif char_before_dash in "しじみいきぎひびち":
-                                     dash_reading = "い"
-                                 elif char_before_dash in "そぞのこごもろとどよょおほぼぽ":
-                                     digraphA = char_before_dash + "お"
-                                     digraphB = char_before_dash + "う"
-                                     if digraphA in form and digraphB not in form:
-                                         dash_reading = "お"
-                                     elif digraphB in form and digraphA not in form:
-                                         dash_reading = "う"
-                                     else:
-                                         if digraphA in word and digraphB not in word:
-                                             dash_reading = "お"
-                                         elif digraphB in word and digraphA not in word:
-                                             dash_reading = "う"
-                                         else:
-                                             logger.warning(
-                                                 f"Leaving dash in {word} {form} {reading}"
-                                             )
-                                             dash_reading = "ー"
-                                 else:
-                                     logger.warning(
-                                         f"Leaving dash in {word} {form} {reading}"
-                                     )
-                                     dash_reading = "ー"
-                             reading = (
-                                 reading[:index_of_dash]
-                                 + dash_reading
-                                 + reading[index_of_dash + 1 :]
-                             )
-                     furigana_sentence += RubyToken.from_furi(word, reading).to_code()
-
-             furigana_sentences.append(furigana_sentence)
-
-         furigana_sentences = [
-             utils.standardize_text(sentence) for sentence in furigana_sentences
-         ]
-         sentences = [utils.remove_furigana(sentence) for sentence in furigana_sentences]
-         try:
-             rowid = row["サンプル ID"]
-         except KeyError:
-             rowid = row["講演 ID"]
-         if len(furigana_sentences) == 1:
-             ids = [rowid]
-         else:
-             ids = [rowid + "_" + str(i) for i in range(len(furigana_sentences))]
-
-         sub_df = pd.DataFrame(
-             {"sentence": sentences, "furigana": furigana_sentences, "sentenceid": ids}
-         )
-
-         sub_df = sub_df[sub_df["sentence"] != sub_df["furigana"]]
-
-         return sub_df
-
-     output_df = pd.DataFrame()
-     for i, row in df.iterrows():
-         output_df = output_df.append(get_sentences(row))
-
-     return output_df
-
-
- def bccwj_data():
-     """Extract, load and transform the bccwj data"""
-
-     # Extract sentences from the data files
-     bccwj_files = list(Path(config.RAW_DATA_DIR, "bccwj").glob("*.txt"))
-
-     df = pd.DataFrame()
-
-     for bccwj_file in bccwj_files:
-         logger.info(bccwj_file.name)
-         df = pd.concat([df, read_bccwj_file(bccwj_file)])
-
-     # remove known errors
-     error_ids = []
-
-     df = df[~df["sentenceid"].isin(error_ids)]
-     df = df[df["sentence"] != ""]
-     df = df.drop_duplicates()
-     df["furigana"] = df["furigana"].apply(utils.standardize_text)
-     df["sentence"] = df["sentence"].apply(utils.standardize_text)
-     assert (df["sentence"] == df["furigana"].apply(utils.remove_furigana)).all()
-
-     # Output
-     df.to_csv(Path(config.SENTENCE_DATA_DIR, "bccwj.csv"), index=False)
-
-     logger.info("✅ Saved bccwj data!")
-
-
- def bccwj_subset(bccwj_file):
-     """Extract, load and transform a subset of the bccwj data"""
-
-     df = read_bccwj_file(bccwj_file)
-
-     # remove known errors
-     error_ids = []
-
-     df = df[~df["sentenceid"].isin(error_ids)]
-     df = df.drop_duplicates()
-     df["furigana"] = df["furigana"].apply(utils.standardize_text)
-     df["sentence"] = df["sentence"].apply(utils.standardize_text)
-
-     # Output
-     df.to_csv(
-         Path(config.SENTENCE_DATA_DIR, bccwj_file.name.split(".")[0] + ".csv"),
-         index=False,
-     )
-
-     logger.info("✅ Saved bccwj " + bccwj_file.name.split(".")[0] + " data!")
-
-
- if __name__ == "__main__":
-     bccwj_data()
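For reference, a minimal, self-contained sketch (not part of the commit) of the Chuunagon inline-tag format the parser above expects: sentences separated by #, words by |, and each word carrying [語彙素読み:発音形出現形] in brackets. The example token is illustrative, not real corpus data.

import jaconv

word_with_reading = "今日[キョウ:キョー]"  # illustrative token
word = word_with_reading.split("[")[0]
form, reading = jaconv.kata2hira(word_with_reading.split("[")[1].split("]")[0]).split(":")
print(word, form, reading)  # 今日 きょう きょー; the ー is then resolved to う by comparison with the form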
yomikata/dataset/kwdlc.py DELETED
@@ -1,109 +0,0 @@
- """kwdlc.py
- Data processing script for KWDLC files directly in the repository format
- KWDLC repository: https://github.com/ku-nlp/KWDLC
- """
-
- import warnings
- from pathlib import Path
-
- import pandas as pd
- from speach import ttlig
-
- from config import config
- from config.config import logger
- from yomikata import utils
-
- warnings.filterwarnings("ignore")
-
-
- def read_knp_file(filename: str):
-     with open(filename) as f:
-         contents = f.readlines()
-
-     ids = []
-     sentences = []
-     furiganas = []
-     sentence = ""
-     furigana = ""
-     for row in contents:
-         first_word = row.split(" ")[0]
-         if first_word in ["*", "+"]:
-             pass
-         elif first_word == "#":
-             sentence_id = row.split(" ")[1].split("S-ID:")[1]
-         elif first_word == "EOS\n":
-             sentence = utils.standardize_text(sentence)
-             furigana = utils.standardize_text(furigana)
-             if sentence == utils.remove_furigana(furigana):
-                 sentences.append(sentence)
-                 furiganas.append(furigana)
-                 ids.append(sentence_id)
-             else:
-                 logger.warning(
-                     f"Dropping mismatched line \n Sentence: {sentence} \n Furigana: {furigana}"
-                 )
-             sentence = ""
-             furigana = ""
-         else:
-             words = row.split(" ")
-             sentence += words[0]
-             if words[0] == words[1]:
-                 furigana += words[0]
-             else:
-                 furigana += ttlig.RubyToken.from_furi(words[0], words[1]).to_code()
-
-     assert len(ids) == len(sentences)
-     assert len(sentences) == len(furiganas)
-     return ids, sentences, furiganas  # readings
-
-
- def kwdlc_data():
-     """Extract, load and transform the kwdlc data"""
-
-     # Extract sentences from the data files
-     knp_files = list(Path(config.RAW_DATA_DIR, "kwdlc").glob("**/*.knp"))
-
-     all_ids = []
-     all_sentences = []
-     all_furiganas = []
-     for knp_file in knp_files:
-         ids, sentences, furiganas = read_knp_file(knp_file)
-         all_ids += ids
-         all_sentences += sentences
-         all_furiganas += furiganas
-
-     # construct dataframe
-     df = pd.DataFrame(
-         list(
-             zip(all_sentences, all_furiganas, all_ids)
-         ),  # all_readings, all_furiganas)),
-         columns=["sentence", "furigana", "sentenceid"],
-     )
-
-     # remove known errors
-     error_ids = [
-         "w201106-0000547376-1",
-         "w201106-0001768070-1-01",
-         "w201106-0000785999-1",
-         "w201106-0001500842-1",
-         "w201106-0000704257-1",
-         "w201106-0002300346-3",
-         "w201106-0001779669-3",
-         "w201106-0000259203-1",
-     ]
-
-     df = df[~df["sentenceid"].isin(error_ids)]
-     df = df.drop_duplicates()
-     df["furigana"] = df["furigana"].apply(utils.standardize_text)
-     df["sentence"] = df["sentence"].apply(utils.standardize_text)
-     # Test
-     assert (df["sentence"] == df["furigana"].apply(utils.remove_furigana)).all()
-
-     # Output
-     df.to_csv(Path(config.SENTENCE_DATA_DIR, "kwdlc.csv"), index=False)
-
-     logger.info("✅ Saved kwdlc data!")
-
-
- if __name__ == "__main__":
-     kwdlc_data()
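As a reference for the column handling above, a small sketch (not part of the commit) of a KNP-style morpheme line: the first space-separated field is taken as the surface form and the second as its reading. The line itself is illustrative, not real corpus data.

# Illustrative KNP-style morpheme line (surface, reading, lemma, POS, ...).
line = "故郷 こきょう 故郷 名詞"
words = line.split(" ")
surface, reading = words[0], words[1]
print(surface, reading)  # 故郷 こきょう -> contributes {故郷/こきょう} to the furigana string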
yomikata/dataset/ndlbib.py DELETED
@@ -1,46 +0,0 @@
- """ndlbib.py
- Data processing script for ndlbib sentence file from https://github.com/ndl-lab/huriganacorpus-ndlbib
- """
-
- import warnings
- from pathlib import Path
-
- from pandas.errors import ParserError
-
- from config import config
- from config.config import logger
- from yomikata.dataset.aozora import read_file
-
- # ndlbib and aozora use same file structure
-
- warnings.filterwarnings("ignore")
-
-
- def ndlbib_data():
-     """Extract, load and transform the ndlbib data"""
-
-     # Extract sentences from the data files
-     files = list(Path(config.RAW_DATA_DIR, "shosi").glob("*.txt"))
-
-     with open(Path(config.SENTENCE_DATA_DIR, "ndlbib.csv"), "w") as f:
-         f.write("sentence,furigana,sentenceid\n")
-
-     for i, file in enumerate(files):
-         logger.info(f"{i+1}/{len(files)} {file.name}")
-         try:
-             df = read_file(file)
-         except ParserError:
-             logger.error(f"Parser error on {file}")
-
-         df.to_csv(
-             Path(config.SENTENCE_DATA_DIR, "ndlbib.csv"),
-             mode="a",
-             index=False,
-             header=False,
-         )
-
-     logger.info("✅ Saved ndlbib data!")
-
-
- if __name__ == "__main__":
-     ndlbib_data()
yomikata/dataset/pronunciations.py DELETED
@@ -1,57 +0,0 @@
- from pathlib import Path
-
- import jaconv
- import pandas as pd
- from tqdm import tqdm
-
- from config import config
- from config.config import logger
- from yomikata import utils
-
-
- def pronunciation_data():
-
-     data_files = list(Path(config.READING_DATA_DIR).glob("*.csv"))
-
-     df = pd.DataFrame()
-
-     for file in data_files:
-         if (file.name == "all.csv") or (file.name == "ambiguous.csv"):
-             continue
-         output_df = pd.read_csv(file)
-         df = pd.concat([df, output_df])
-
-     df["surface"] = df["surface"].astype(str).str.strip()
-     df["kana"] = df["kana"].astype(str).str.strip()
-
-     tqdm.pandas()
-
-     df["kana"] = df["kana"].progress_apply(utils.standardize_text)
-     df["surface"] = df["surface"].progress_apply(utils.standardize_text)
-     df["kana"] = df.progress_apply(lambda row: jaconv.kata2hira(row["kana"]), axis=1)
-     df = df[df["surface"] != df["kana"]]
-     df = df[df["kana"] != ""]
-
-     df = df[df["surface"].progress_apply(utils.has_kanji)]
-
-     df = df.loc[~df["surface"].str.contains(r"[〜〜()\)\(\*]\.")]
-
-     df = df[["surface", "kana"]]
-     df = df.drop_duplicates()
-
-     df.to_csv(Path(config.READING_DATA_DIR, "all.csv"), index=False)
-
-     logger.info("✅ Merged all the pronunciation data!")
-
-     # merged_df = (
-     #     df.groupby("surface")["kana"]
-     #     .apply(list)
-     #     .reset_index(name="pronunciations")
-     # )
-
-     # ambiguous_df = merged_df[merged_df["pronunciations"].apply(len) > 1]
-     # ambiguous_df.to_csv(Path(config.READING_DATA_DIR, "ambiguous.csv"), index=False)
-
-
- if __name__ == "__main__":
-     pronunciation_data()
yomikata/dataset/repair_long_vowels.py DELETED
@@ -1,62 +0,0 @@
- from pathlib import Path
-
- import pandas as pd
-
- from config import config
- from config.config import logger
-
- pronunciation_df = pd.read_csv(Path(config.READING_DATA_DIR, "all.csv"))
- pronunciation_df = pronunciation_df.groupby("surface")["kana"].apply(list)
-
-
- def repair_long_vowels(kana: str, kanji: str = None) -> str:
-     """Clean and normalize text
-
-     Args:
-         kana (str): input string
-         kanji (str): input string, optional
-
-     Returns:
-         str: a cleaned string
-     """
-
-     reading = kana
-     indices_of_dash = [pos for pos, char in enumerate(reading) if char == "ー"]
-
-     # get rid of non-ambiguous dashes
-     for index_of_dash in indices_of_dash:
-         char_before_dash = reading[index_of_dash - 1]
-         if char_before_dash in "ぬつづむるくぐすずゆゅふぶぷ":
-             reading = reading[:index_of_dash] + "う" + reading[index_of_dash + 1 :]
-         elif char_before_dash in "しじみいきぎひびちぢぃ":
-             reading = reading[:index_of_dash] + "い" + reading[index_of_dash + 1 :]
-
-     indices_of_not_dash = [pos for pos, char in enumerate(reading) if char != "ー"]
-     if len(indices_of_not_dash) != len(reading):
-         if not kanji:
-             logger.info("Disambiguating this dash requires kanji")
-             logger.info(f"Left dash in {reading}")
-         else:
-             try:
-                 candidate_pronunciations = list(pronunciation_df[kanji])
-             except KeyError:
-                 candidate_pronunciations = []
-
-             candidate_pronunciations = list(set(candidate_pronunciations))
-
-             candidate_pronunciations = [
-                 x for x in candidate_pronunciations if len(x) == len(reading)
-             ]
-             candidate_pronunciations = [
-                 x
-                 for x in candidate_pronunciations
-                 if all([x[i] == reading[i] for i in indices_of_not_dash])
-             ]
-
-             if len(candidate_pronunciations) == 1:
-                 reading = candidate_pronunciations[0]
-             else:
-                 pass
-                 # logger.warning(f"Left dashes in {kanji} {reading}")
-
-     return reading
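A standalone sketch (not part of the commit) of the unambiguous half of this rule, for readers without the repository's pronunciation table: a ー following an う-row kana is read う, and one following an い-row kana is read い; remaining dashes need the kanji and the candidate-pronunciation lookup shown above.

U_ROW = "ぬつづむるくぐすずゆゅふぶぷ"
I_ROW = "しじみいきぎひびちぢぃ"

def repair_unambiguous_dashes(kana: str) -> str:
    # Replace only the dashes whose reading is determined by the preceding kana.
    chars = list(kana)
    for i, ch in enumerate(chars):
        if ch == "ー" and i > 0:
            if chars[i - 1] in U_ROW:
                chars[i] = "う"
            elif chars[i - 1] in I_ROW:
                chars[i] = "い"
    return "".join(chars)

print(repair_unambiguous_dashes("くーき"))  # くうき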
yomikata/dataset/split.py DELETED
@@ -1,271 +0,0 @@
- from pathlib import Path
-
- import pandas as pd
- from sklearn.model_selection import train_test_split
- from speach.ttlig import RubyFrag, RubyToken
-
- from config import config
- from config.config import logger
- from yomikata import utils
- from yomikata.dictionary import Dictionary
-
-
- def train_val_test_split(X, y, train_size, val_size, test_size):
-     """Split dataset into data splits."""
-     assert (train_size + val_size + test_size) == 1
-     X_train, X_, y_train, y_ = train_test_split(X, y, train_size=train_size)
-     X_val, X_test, y_val, y_test = train_test_split(
-         X_, y_, train_size=val_size / (test_size + val_size)
-     )
-     return X_train, X_val, X_test, y_train, y_val, y_test
-
-
- def filter_simple(input_file, output_file, heteronyms) -> None:
-     """This filters out sentences which don't contain any heteronyms"""
-
-     df = pd.read_csv(input_file)  # load
-     logger.info(f"Prefilter size: {len(df)}")
-
-     df = df[df["sentence"].str.contains(r"|".join(heteronyms))]
-     logger.info(f"Postfilter size: {len(df)}")
-
-     df.to_csv(output_file, index=False)
-
-
- def filter_dictionary(input_file, output_file, heteronyms, dictionary) -> None:
-     """This filters out sentences which contain heteronyms only as part of a compound which is known to the dictionary"""
-     df = pd.read_csv(input_file)  # load
-     logger.info(f"Prefilter size: {len(df)}")
-
-     df["contains_heteronym"] = df["sentence"].apply(
-         lambda s: not set(
-             [dictionary.token_to_surface(m) for m in dictionary.tagger(s)]
-         ).isdisjoint(heteronyms)
-     )
-
-     df = df[df["contains_heteronym"]]
-     logger.info(f"Postfilter size: {len(df)}")
-
-     df.to_csv(output_file, index=False)
-
-
- def regroup_furigana(s, heteronym, heteronym_dict, dictionary, verbose=False):
-     rubytokens = utils.parse_furigana(s)
-     output_tokens = []
-     for token in rubytokens.groups:
-         if isinstance(token, RubyFrag):
-             # this is a token with furigana
-             if heteronym in token.text and token.text != heteronym:
-                 # it includes the heteronym but is not exactly the heteronym
-                 # if len(dictionary.tagger(token.text)) > 1:
-                 # it is not in the dictionary, so we try to regroup it
-                 # note this dictionary check is not foolproof: sometimes words are in the dictionary and found here,
-                 # but in a parse of the whole sentence the word will be split in two.
-                 # commented this out since actually even if it is part of dictionary, it will go through the training and so we might as well try to regroup it to avoid it being an <OTHER>
-                 viable_regroupings = []
-                 for reading in heteronym_dict[heteronym]:
-                     regrouped_tokens = regroup_furigana_tokens(
-                         [token], heteronym, reading, verbose=verbose
-                     )
-                     if regrouped_tokens != [token]:
-                         if verbose:
-                             print("viable regrouping found")
-                         viable_regroupings.append(regrouped_tokens)
-                 if len(viable_regroupings) == 1:
-                     output_tokens += viable_regroupings[0]
-                     continue
-                 else:
-                     if verbose:
-                         print("multiple viable readings found, cannot regroup")
-                     pass
-         output_tokens.append(token)
-
-     output_string = RubyToken(groups=output_tokens).to_code()
-     assert utils.furigana_to_kana(output_string) == utils.furigana_to_kana(s)
-     assert utils.remove_furigana(output_string) == utils.remove_furigana(s)
-     return output_string
-
-
- def regroup_furigana_tokens(ruby_tokens, heteronym, reading, verbose=False):
-     if not len(ruby_tokens) == 1:
-         raise ValueError("regroup failed, no support yet for token merging")
-     ruby_token = ruby_tokens[0]
-     text = ruby_token.text
-     furi = ruby_token.furi
-     try:
-         split_text = [
-             text[0 : text.index(heteronym)],
-             heteronym,
-             text[text.index(heteronym) + len(heteronym) :],
-         ]
-         split_text = [text for text in split_text if text != ""]
-     except ValueError:
-         if verbose:
-             print("regroup failed, heteronym not in token text")
-         return ruby_tokens
-
-     try:
-         split_furi = [
-             furi[0 : furi.index(reading)],
-             reading,
-             furi[furi.index(reading) + len(reading) :],
-         ]
-         split_furi = [furi for furi in split_furi if furi != ""]
-     except ValueError:
-         if verbose:
-             print("regroup failed, reading not in token furi")
-         return ruby_tokens
-
-     if not len(split_text) == len(split_furi):
-         if verbose:
-             print(
-                 "regroup failed, failed to find heteronym and its reading in the same place in the inputs"
-             )
-         return ruby_tokens
-
-     regrouped_tokens = [
-         RubyFrag(text=split_text[i], furi=split_furi[i]) for i in range(len(split_text))
-     ]
-
-     if not "".join([token.furi for token in ruby_tokens]) == "".join(
-         [token.furi for token in regrouped_tokens]
-     ):
-         if verbose:
-             print(
-                 "regroup failed, reading of produced result does not agree with reading of input"
-             )
-         return ruby_tokens
-     if not [token.furi for token in regrouped_tokens if token.text == heteronym] == [
-         reading
-     ]:
-         if verbose:
-             print("regroup failed, the heteronym did not get assigned the reading")
-         return ruby_tokens
-     return regrouped_tokens
-
-
- def optimize_furigana(input_file, output_file, heteronym_dict, dictionary) -> None:
-     df = pd.read_csv(input_file)  # load
-     logger.info("Optimizing furigana using heteronym list and dictionary")
-     for heteronym in heteronym_dict.keys():
-         logger.info(f"Heteronym {heteronym} {heteronym_dict[heteronym]}")
-         n_with_het = sum(df["sentence"].str.contains(heteronym))
-         rows_to_rearrange = df["sentence"].str.contains(heteronym)
-         optimized_rows = df.loc[rows_to_rearrange, "furigana"].apply(
-             lambda s: regroup_furigana(s, heteronym, heteronym_dict, dictionary)
-         )
-         n_rearranged = sum(df.loc[rows_to_rearrange, "furigana"] != optimized_rows)
-         logger.info(f"{n_rearranged}/{n_with_het} sentences were optimized")
-         df.loc[rows_to_rearrange, "furigana"] = optimized_rows
-     df.to_csv(output_file, index=False)
-
-
- def remove_other_readings(input_file, output_file, heteronym_dict):
-     df = pd.read_csv(input_file)  # load
-     logger.info(f"Prefilter size: {len(df)}")
-     df["keep_row"] = False
-     for heteronym in heteronym_dict.keys():
-         logger.info(heteronym)
-         n_with_het = sum(df["sentence"].str.contains(heteronym))
-         keep_for_het = df["furigana"].str.contains(
-             r"|".join(
-                 [f"{{{heteronym}/{reading}}}" for reading in heteronym_dict[heteronym]]
-             )
-         )
-         df["keep_row"] = df["keep_row"] | keep_for_het
-         logger.info(
-             f"Dropped {n_with_het-sum(keep_for_het)}/{n_with_het} sentences which have different readings"
-         )  # TODO reword
-     df = df.loc[df["keep_row"]]
-     df = df.drop("keep_row", axis=1)
-     df.to_csv(output_file, index=False)
-
-
- def check_data(input_file) -> bool:
-
-     df = pd.read_csv(input_file)  # load
-     df["furigana-test"] = df["sentence"] == df["furigana"].apply(utils.remove_furigana)
-     assert df["furigana-test"].all()
-     df["sentence-standardize-test"] = df["sentence"] == df["sentence"].apply(
-         utils.standardize_text
-     )
-     assert df["sentence-standardize-test"].all()
-
-     return True
-
-
- def split_data(data_file) -> None:
-
-     df = pd.read_csv(data_file)  # load
-
-     X = df["sentence"].values
-     y = df["furigana"].values
-
-     (X_train, X_val, X_test, y_train, y_val, y_test) = train_val_test_split(
-         X=X,
-         y=y,
-         train_size=config.TRAIN_SIZE,
-         val_size=config.VAL_SIZE,
-         test_size=config.TEST_SIZE,
-     )
-
-     train_df = pd.DataFrame({"sentence": X_train, "furigana": y_train})
-     val_df = pd.DataFrame({"sentence": X_val, "furigana": y_val})
-     test_df = pd.DataFrame({"sentence": X_test, "furigana": y_test})
-
-     train_df.to_csv(Path(config.TRAIN_DATA_DIR, "train_" + data_file.name), index=False)
-     val_df.to_csv(Path(config.VAL_DATA_DIR, "val_" + data_file.name), index=False)
-     test_df.to_csv(Path(config.TEST_DATA_DIR, "test_" + data_file.name), index=False)
-
-
- if __name__ == "__main__":
-
-     input_files = [
-         Path(config.SENTENCE_DATA_DIR, "aozora.csv"),
-         Path(config.SENTENCE_DATA_DIR, "kwdlc.csv"),
-         Path(config.SENTENCE_DATA_DIR, "bccwj.csv"),
-         Path(config.SENTENCE_DATA_DIR, "ndlbib.csv"),
-     ]
-
-     logger.info("Merging sentence data")
-     utils.merge_csvs(input_files, Path(config.SENTENCE_DATA_DIR, "all.csv"), n_header=1)
-
-     logger.info("Rough filtering for sentences with heteronyms")
-     filter_simple(
-         Path(config.SENTENCE_DATA_DIR, "all.csv"),
-         Path(config.SENTENCE_DATA_DIR, "have_heteronyms_simple.csv"),
-         config.HETERONYMS.keys(),
-     )
-
-     logger.info("Sudachidict filtering for out heteronyms in known compounds")
-     filter_dictionary(
-         Path(config.SENTENCE_DATA_DIR, "have_heteronyms_simple.csv"),
-         Path(config.SENTENCE_DATA_DIR, "have_heteronyms.csv"),
-         config.HETERONYMS.keys(),
-         Dictionary("sudachi"),
-     )
-
-     logger.info("Optimizing furigana")
-     optimize_furigana(
-         Path(config.SENTENCE_DATA_DIR, "have_heteronyms.csv"),
-         Path(config.SENTENCE_DATA_DIR, "optimized_heteronyms.csv"),
-         config.HETERONYMS,
-         Dictionary("sudachi"),
-     )
-
-     logger.info("Removing heteronyms with unexpected readings")
-     remove_other_readings(
-         Path(config.SENTENCE_DATA_DIR, "optimized_heteronyms.csv"),
-         Path(config.SENTENCE_DATA_DIR, "optimized_strict_heteronyms.csv"),
-         config.HETERONYMS,
-     )
-
-     logger.info("Running checks on data")
-     test_result = check_data(
-         Path(config.SENTENCE_DATA_DIR, "optimized_strict_heteronyms.csv")
-     )
-
-     logger.info("Performing train/test/split")
-     split_data(Path(config.SENTENCE_DATA_DIR, "optimized_strict_heteronyms.csv"))
-
-     logger.info("Data splits successfully generated!")
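A minimal sketch (not part of the commit) of what the regrouping step does to a single ruby fragment, assuming 清水/きよみず is one of the readings configured for the heteronym 清水: the fragment {清水寺/きよみずでら} is split so that the heteronym carries exactly its own reading.

from speach.ttlig import RubyFrag, RubyToken

token = RubyFrag(text="清水寺", furi="きよみずでら")  # illustrative input
heteronym, reading = "清水", "きよみず"

# Split surface and furigana around the heteronym and its candidate reading.
before, after = token.text.split(heteronym, 1)
furi_before, furi_after = token.furi.split(reading, 1)
pieces = [
    RubyFrag(text=t, furi=f)
    for t, f in [(before, furi_before), (heteronym, reading), (after, furi_after)]
    if t != ""
]
print(RubyToken(groups=pieces).to_code())  # expected: {清水/きよみず}{寺/でら}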
yomikata/dataset/sudachi.py DELETED
@@ -1,50 +0,0 @@
- """sudachi.py
- Data processing script for sudachi dictionary
- """
-
- import warnings
- from pathlib import Path
-
- import pandas as pd
-
- from config import config
- from config.config import logger
-
- warnings.filterwarnings("ignore")
-
-
- def sudachi_data():
-
-     sudachi_file = list(Path(config.RAW_DATA_DIR, "sudachi").glob("*.csv"))
-
-     df = pd.DataFrame()
-
-     for file in sudachi_file:
-         logger.info(file.name)
-         # Load file
-         df = pd.concat(
-             [
-                 df,
-                 pd.read_csv(
-                     file,
-                     header=None,
-                 ),
-             ]
-         )
-
-     df["surface"] = df[0].astype(str).str.strip()
-     df["kana"] = df[11].astype(str).str.strip()
-     df["type"] = df[5].astype(str).str.strip()
-     df = df[df["kana"] != "*"]
-     df = df[df["surface"] != df["kana"]]
-     df = df[df["type"] != "補助記号"]
-
-     df = df[["surface", "kana"]]
-
-     df.to_csv(Path(config.READING_DATA_DIR, "sudachi.csv"), index=False)
-
-     logger.info("✅ Processed sudachi data!")
-
-
- if __name__ == "__main__":
-     sudachi_data()
yomikata/dataset/unidic.py DELETED
@@ -1,44 +0,0 @@
- """unidic.py
- Data processing script for unidic dictionary
- """
-
- import warnings
- from pathlib import Path
-
- import pandas as pd
-
- from config import config
- from config.config import logger
-
- warnings.filterwarnings("ignore")
-
-
- def unidic_data():
-     """Extract, load and transform the unidic data"""
-
-     # Extract sentences from the data files
-     unidic_file = list(Path(config.RAW_DATA_DIR, "unidic").glob("*.csv"))[0]
-
-     # Load file
-     df = pd.read_csv(
-         unidic_file,
-         header=None,
-         names="surface id1 id2 id3 pos1 pos2 pos3 pos4 cType "
-         "cForm lForm lemma orth orthBase pron pronBase goshu iType iForm fType "
-         "fForm iConType fConType type kana kanaBase form formBase aType aConType "
-         "aModType lid lemma_id".split(" "),
-     )
-
-     df["surface"] = df["surface"].astype(str).str.strip()
-     df["kana"] = df["kana"].astype(str).str.strip()
-     df = df[df["kana"] != "*"]
-     df = df[df["surface"] != df["kana"]]
-     df = df[["surface", "kana"]]
-
-     df.to_csv(Path(config.READING_DATA_DIR, "unidic.csv"), index=False)
-
-     logger.info("✅ Processed unidic data!")
-
-
- if __name__ == "__main__":
-     unidic_data()
yomikata/main.py DELETED
@@ -1,123 +0,0 @@
- """main.py
- Main entry point for training
- """
-
- import sys
- import tempfile
- import warnings
- from argparse import Namespace
- from pathlib import Path
-
- import mlflow
- from datasets import load_dataset
-
- from config import config
- from config.config import logger
- from yomikata import utils
- from yomikata.dbert import dBert
-
-
- # MLFlow model registry
- mlflow.set_tracking_uri("file://" + str(config.RUN_REGISTRY.absolute()))
-
-
- warnings.filterwarnings("ignore")
-
-
- def train_model(
-     model_name: "dBert",
-     dataset_name: str = "",
-     experiment_name: str = "baselines",
-     run_name: str = "dbert-default",
-     training_args: dict = {},
- ) -> None:
-     """Train a model given arguments.
-
-     Args:
-         dataset_name (str): name of the dataset to be trained on. Defaults to the full dataset.
-         args_fp (str): location of args.
-         experiment_name (str): name of experiment.
-         run_name (str): name of specific run in experiment.
-     """
-
-     mlflow.set_experiment(experiment_name=experiment_name)
-     with mlflow.start_run(run_name=run_name):
-
-         run_id = mlflow.active_run().info.run_id
-         logger.info(f"Run ID: {run_id}")
-
-         experiment_id = mlflow.get_run(run_id=run_id).info.experiment_id
-         artifacts_dir = Path(config.RUN_REGISTRY, experiment_id, run_id, "artifacts")
-
-         # Initialize the model
-         if model_name == "dBert":
-             reader = dBert(reinitialize=True, artifacts_dir=artifacts_dir)
-         else:
-             raise ValueError("model_name must be dBert for now")
-
-         # Load train val test data
-         dataset = load_dataset(
-             "csv",
-             data_files={
-                 "train": str(
-                     Path(config.TRAIN_DATA_DIR, "train_" + dataset_name + ".csv")
-                 ),
-                 "val": str(Path(config.VAL_DATA_DIR, "val_" + dataset_name + ".csv")),
-                 "test": str(
-                     Path(config.TEST_DATA_DIR, "test_" + dataset_name + ".csv")
-                 ),
-             },
-         )
-
-         # Train
-         training_performance = reader.train(dataset, training_args=training_args)
-
-         # general_performance = evaluate.evaluate(reader, max_evals=20)
-
-         with tempfile.TemporaryDirectory() as dp:
-             # reader.save(dp)
-             # utils.save_dict(general_performance, Path(dp, "general_performance.json"))
-             utils.save_dict(training_performance, Path(dp, "training_performance.json"))
-             mlflow.log_artifacts(dp)
-
-
- def get_artifacts_dir_from_run(run_id: str):
-     """Load artifacts directory for a given run_id.
-
-     Args:
-         run_id (str): id of run to load artifacts from.
-
-     Returns:
-         Path: path to artifacts directory.
-
-     """
-
-     # Locate specifics artifacts directory
-     experiment_id = mlflow.get_run(run_id=run_id).info.experiment_id
-     artifacts_dir = Path(config.RUN_REGISTRY, experiment_id, run_id, "artifacts")
-
-     return artifacts_dir
-
-
- if __name__ == "__main__":
-
-     # get args filepath from input
-     args_fp = sys.argv[1]
-
-     # load the args_file
-     args = Namespace(**utils.load_dict(filepath=args_fp)).__dict__
-
-     # pop meta variables
-     model_name = args.pop("model")
-     dataset_name = args.pop("dataset")
-     experiment_name = args.pop("experiment")
-     run_name = args.pop("run")
-
-     # Perform training
-     train_model(
-         model_name=model_name,
-         dataset_name=dataset_name,
-         experiment_name=experiment_name,
-         run_name=run_name,
-         training_args=args,
-     )
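main.py reads its arguments from a file passed on the command line, pops the model, dataset, experiment and run keys, and forwards everything else to reader.train() as training_args. A sketch of such an args file (not part of the commit; it assumes the args file is JSON, and the two training keys are hypothetical):

import json

args = {
    "model": "dBert",
    "dataset": "optimized_strict_heteronyms",
    "experiment": "baselines",
    "run": "dbert-default",
    "num_train_epochs": 10,  # hypothetical training argument passed through to reader.train()
    "learning_rate": 2e-5,   # hypothetical training argument passed through to reader.train()
}
with open("args.json", "w") as f:
    json.dump(args, f, indent=2, ensure_ascii=False)
# then: python yomikata/main.py args.json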