"""
Author:
"""
import copy
from typing import Any
from ckonlpy.tag import Twitter
from tqdm import tqdm
import re
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
# ์‚ฌ์šฉ์ž๊ฐ€ ์‚ฌ์ „์— ๋‹จ์–ด ์ถ”๊ฐ€๊ฐ€ ๊ฐ€๋Šฅํ•œ ํ˜•ํƒœ์†Œ ๋ถ„์„๊ธฐ๋ฅผ ์ด์šฉ(์ถ”ํ›„์— name_list์— ๋“ฑ์žฌ๋œ ์ด๋ฆ„์„ ๋“ฑ๋กํ•˜์—ฌ ์ธ์‹ ๋ฐ ๋ถ„๋ฆฌํ•˜๊ธฐ ์œ„ํ•จ)
twitter = Twitter()
def load_data(filename) -> Any:
"""
์ง€์ •๋œ ํŒŒ์ผ์—์„œ ๋ฐ์ดํ„ฐ๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
"""
return torch.load(filename)
def NML(seg_sents, mention_positions, ws):
"""
Nearest Mention Location (ํŠน์ • ํ›„๋ณด ๋ฐœํ™”์ž๊ฐ€ ์–ธ๊ธ‰๋œ ์œ„์น˜์ค‘, ์ธ์šฉ๋ฌธ์œผ๋กœ๋ถ€ํ„ฐ ๊ฐ€์žฅ ๊ฐ€๊นŒ์šด ์–ธ๊ธ‰ ์œ„์น˜๋ฅผ ์ฐพ๋Š” ํ•จ์ˆ˜)
Parameters:
- seg_sents: ๋ฌธ์žฅ์„ ๋ถ„ํ• ํ•œ ๋ฆฌ์ŠคํŠธ
- mention_positions: ํŠน์ • ํ›„๋ณด ๋ฐœํ™”์ž๊ฐ€ ์–ธ๊ธ‰๋œ ์œ„์น˜๋ฅผ ๋ชจ๋‘ ๋‹ด์€ ๋ฆฌ์ŠคํŠธ [(sentence_index, word_index), ...]
- ws: ์ธ์šฉ๋ฌธ ์•ž/๋’ค๋กœ ๊ณ ๋ คํ•  ๋ฌธ์žฅ์˜ ์ˆ˜
Returns:
- ๊ฐ€์žฅ ๊ฐ€๊นŒ์šด ์–ธ๊ธ‰ ์œ„์น˜์˜ (sentence_index, word_index)
"""
def word_dist(pos):
"""
๋ฐœํ™” ํ›„๋ณด์ž ์ด๋ฆ„์ด ์–ธ๊ธ‰๋œ ์œ„์น˜์™€ ์ธ์šฉ๋ฌธ ์‚ฌ์ด์˜ ๊ฑฐ๋ฆฌ๋ฅผ ๋‹จ์–ด ์ˆ˜์ค€(word level)์—์„œ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
Parameters:
- pos: ๋ฐœํ™” ํ›„๋ณด์ž๊ฐ€ ์–ธ๊ธ‰๋œ ์œ„์น˜ (sentence_index, word_index)
Returns:
- ๋ฐœํ™” ํ›„๋ณด์ž์™€ ์–ธ๊ธ‰๋œ ์œ„์น˜ ์‚ฌ์ด์˜ ๊ฑฐ๋ฆฌ (๋‹จ์–ด ์ˆ˜์ค€)
"""
if pos[0] == ws:
w_d = ws * 2
elif pos[0] < ws:
w_d = sum(len(
sent) for sent in seg_sents[pos[0] + 1:ws]) + len(seg_sents[pos[0]][pos[1] + 1:])
else:
w_d = sum(
len(sent) for sent in seg_sents[ws + 1:pos[0]]) + len(seg_sents[pos[0]][:pos[1]])
return w_d
# ์–ธ๊ธ‰๋œ ์œ„์น˜๋“ค๊ณผ ์ธ์šฉ๋ฌธ ์‚ฌ์ด์˜ ๊ฑฐ๋ฆฌ๋ฅผ ๊ฐ€๊นŒ์šด ์ˆœ์œผ๋กœ ์ •๋ ฌ
sorted_positions = sorted(mention_positions, key=lambda x: word_dist(x))
    # Return the Nearest Mention Location
return sorted_positions[0]
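
# Illustrative sketch (not from the original source): with ws=1 the quote sits at
# index ws of a 3-sentence window, and NML keeps the mention with the smallest
# word-level distance to it, e.g.
#   seg_sents = [['&C01&', 'quietly', 'said'], ['"hello"'], ['then', '&C01&', 'left']]
#   NML(seg_sents, [[0, 0], [2, 1]], ws=1)  # -> [2, 1] (distance 1 vs. 2 for [0, 0])
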
def max_len_cut(seg_sents, mention_pos, max_len):
"""
์ฃผ์–ด์ง„ ๋ฌธ์žฅ์„ ๋ชจ๋ธ์— ์ž…๋ ฅ ๊ฐ€๋Šฅํ•œ ์ตœ๋Œ€ ๊ธธ์ด(max_len)๋กœ ์ž๋ฅด๋Š” ํ•จ์ˆ˜
Parameters:
- seg_sents: ๋ฌธ์žฅ์„ ๋ถ„ํ• ํ•œ ๋ฆฌ์ŠคํŠธ
- mention_pos: ๋ฐœํ™” ํ›„๋ณด์ž๊ฐ€ ์–ธ๊ธ‰๋œ ์œ„์น˜ (sentence_index, word_index)
- max_len: ์ž…๋ ฅ ๊ฐ€๋Šฅํ•œ ์ตœ๋Œ€ ๊ธธ์ด
Returns:
- seg_sents : ์ž๋ฅด๊ณ  ๋‚จ์€ ๋ฌธ์žฅ ๋ฆฌ์ŠคํŠธ
- mention_pos : ์กฐ์ •๋œ ์–ธ๊ธ‰๋œ ์œ„์น˜
"""
# ๊ฐ ๋ฌธ์žฅ์˜ ๊ธธ์ด๋ฅผ ๋ฌธ์ž ๋‹จ์œ„๋กœ ๊ณ„์‚ฐํ•œ ๋ฆฌ์ŠคํŠธ ์ƒ์„ฑ
sent_char_lens = [sum(len(word) for word in sent) for sent in seg_sents]
# ์ „์ฒด ๋ฌธ์ž์˜ ๊ธธ์ด ํ•ฉ
sum_char_len = sum(sent_char_lens)
# ๊ฐ ๋ฌธ์žฅ์—์„œ, cut์„ ์‹คํ–‰ํ•  ๋ฌธ์ž์˜ ์œ„์น˜(๋งจ ๋งˆ์ง€๋ง‰ ๋ฌธ์ž)
running_cut_idx = [len(sent) - 1 for sent in seg_sents]
while sum_char_len > max_len:
max_len_sent_idx = max(list(enumerate(sent_char_lens)), key=lambda x: x[1])[0]
if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] == mention_pos[1]:
running_cut_idx[max_len_sent_idx] -= 1
if max_len_sent_idx == mention_pos[0] and running_cut_idx[max_len_sent_idx] < mention_pos[1]:
mention_pos[1] -= 1
reduced_char_len = len(
seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]])
sent_char_lens[max_len_sent_idx] -= reduced_char_len
sum_char_len -= reduced_char_len
# ์ž๋ฅผ ์œ„์น˜ ์‚ญ์ œ
del seg_sents[max_len_sent_idx][running_cut_idx[max_len_sent_idx]]
# ์ž๋ฅผ ์œ„์น˜ ์—…๋ฐ์ดํŠธ
running_cut_idx[max_len_sent_idx] -= 1
return seg_sents, mention_pos
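
# Illustrative sketch (not from the original source): words are dropped from the end
# of the currently longest sentence until the total character count fits max_len,
# while the mention word itself is protected, e.g.
#   max_len_cut([['abc', 'de'], ['fg']], [1, 0], max_len=5)
#   # -> ([['abc'], ['fg']], [1, 0]): 'de' is dropped from the longer first sentence
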
def seg_and_mention_location(raw_sents_in_list, alias2id):
"""
์ฃผ์–ด์ง„ ๋ฌธ์žฅ์„ ๋ถ„ํ• ํ•˜๊ณ  ๋ฐœํ™”์ž ์ด๋ฆ„์ด ์–ธ๊ธ‰๋œ ์œ„์น˜๋ฅผ ์ฐพ๋Š” ํ•จ์ˆ˜
Parameters:
- raw_sents_in_list: ๋ถ„ํ• ํ•  ์›๋ณธ ๋ฌธ์žฅ ๋ฆฌ์ŠคํŠธ
- alias2id: ์บ๋ฆญํ„ฐ ๋ณ„ ์ด๋ฆ„(๋ฐ ๋ณ„์นญ)๊ณผ ID๋ฅผ ๋งคํ•‘ํ•œ ๋”•์…”๋„ˆ๋ฆฌ
Returns:
- seg_sents: ๋ฌธ์žฅ์„ ๋‹จ์–ด๋กœ ๋ถ„ํ• ํ•œ ๋ฆฌ์ŠคํŠธ
- character_mention_poses: ์บ๋ฆญํ„ฐ๋ณ„๋กœ, ์ด๋ฆ„์ด ์–ธ๊ธ‰๋œ ์œ„์น˜๋ฅผ ๋ชจ๋‘ ์ €์žฅํ•œ ๋”•์…”๋„ˆ๋ฆฌ {character1_id: [[sent_idx, word_idx], ...]}
- name_list_index: ์–ธ๊ธ‰๋œ ์บ๋ฆญํ„ฐ ์ด๋ฆ„ ๋ฆฌ์ŠคํŠธ
"""
character_mention_poses = {}
seg_sents = []
    id_pattern = ['&C{:02d}&'.format(i) for i in range(51)]  # name markers &C00&-&C50& (unused; matching below uses a regex)
for sent_idx, sent in enumerate(raw_sents_in_list):
raw_sent_with_split = sent.split()
for word_idx, word in enumerate(raw_sent_with_split):
match = re.search(r'&C\d{1,2}&', word)
# &C00& ํ˜•์‹์œผ๋กœ ๋œ ์ด๋ฆ„์ด ์žˆ์„ ๊ฒฝ์šฐ, result ๋ณ€์ˆ˜๋กœ ์ง€์ •
if match:
result = match.group(0)
if alias2id[result] in character_mention_poses:
character_mention_poses[alias2id[result]].append([sent_idx, word_idx])
else:
character_mention_poses[alias2id[result]] = [[sent_idx, word_idx]]
seg_sents.append(raw_sent_with_split)
name_list_index = list(character_mention_poses.keys())
return seg_sents, character_mention_poses, name_list_index
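
# Illustrative sketch (not from the original source, with a hypothetical alias2id):
#   alias2id = {'&C01&': 0, '&C02&': 1}
#   seg_and_mention_location(['&C01& said hello', '"Hi" replied &C02&'], alias2id)
#   # -> ([['&C01&', 'said', 'hello'], ['"Hi"', 'replied', '&C02&']],
#   #     {0: [[0, 0]], 1: [[1, 2]]},
#   #     [0, 1])
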
def create_CSS(seg_sents, candidate_mention_poses, args):
"""
๊ฐ ์ธ์Šคํ„ด์Šค ๋‚ด ๊ฐ ๋ฐœํ™”์ž ํ›„๋ณด(candidate)์— ๋Œ€ํ•˜์—ฌ candidate-specific segments(CSS)๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
parameters:
seg_sents: 2ws + 1 ๊ฐœ์˜ ๋ฌธ์žฅ(๊ฐ ๋ฌธ์žฅ์€ ๋ถ„ํ• ๋จ)๋“ค์„ ๋‹ด์€ ๋ฆฌ์ŠคํŠธ
candidate_mention_poses: ๋ฐœํ™”์ž๋ณ„๋กœ ์ด๋ฆ„์ด ์–ธ๊ธ‰๋œ ์œ„์น˜๋ฅผ ๋‹ด๊ณ  ์žˆ๋Š” ๋”•์…”๋„ˆ๋ฆฌ์ด๋ฉฐ, ํ˜•ํƒœ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Œ.
{character index: [[sentence index, word index in sentence] of mention 1,...]...}.
args : ์‹คํ–‰ ์ธ์ˆ˜๋ฅผ ๋‹ด์€ ๊ฐ์ฒด
return:
Returned contents are in lists, in which each element corresponds to a candidate.
The order of candidate is consistent with that in list(candidate_mention_poses.keys()).
many_css: ๊ฐ ๋ฐœํ™”์ž ํ›„๋ณด์— ๋Œ€ํ•œ candidate-specific segments(CSS).
many_sent_char_len: ๊ฐ CSS์˜ ๋ฌธ์ž ๊ธธ์ด ์ •๋ณด
[[character-level length of sentence 1,...] of the CSS of candidate 1,...].
many_mention_pos: CSS ๋‚ด์—์„œ, ์ธ์šฉ๋ฌธ๊ณผ ๊ฐ€์žฅ ๊ฐ€๊นŒ์šด ์ด๋ฆ„์ด ์–ธ๊ธ‰๋œ ์œ„์น˜ ์ •๋ณด
[(sentence-level index of nearest mention in CSS,
character-level index of the leftmost character of nearest mention in CSS,
character-level index of the rightmost character + 1) of candidate 1,...].
many_quote_idx: CSS ๋‚ด์˜ ์ธ์šฉ๋ฌธ์˜ ๋ฌธ์žฅ ์ธ๋ฑ์Šค
many_cut_css : ์ตœ๋Œ€ ๊ธธ์ด ์ œํ•œ์ด ์ ์šฉ๋œ CSS
"""
ws = args.ws
max_len = args.length_limit
model_name = args.model_name
# assert len(seg_sents) == ws * 2 + 1
many_css = []
many_sent_char_lens = []
many_mention_poses = []
many_quote_idxes = []
many_cut_css = []
for candidate_idx in candidate_mention_poses.keys():
nearest_pos = NML(seg_sents, candidate_mention_poses[candidate_idx], ws)
if nearest_pos[0] <= ws:
CSS = copy.deepcopy(seg_sents[nearest_pos[0]:ws + 1])
mention_pos = [0, nearest_pos[1]]
quote_idx = ws - nearest_pos[0]
else:
CSS = copy.deepcopy(seg_sents[ws:nearest_pos[0] + 1])
mention_pos = [nearest_pos[0] - ws, nearest_pos[1]]
quote_idx = 0
cut_CSS, mention_pos = max_len_cut(CSS, mention_pos, max_len)
sent_char_lens = [sum(len(word) for word in sent) for sent in cut_CSS]
mention_pos_left = sum(sent_char_lens[:mention_pos[0]]) + sum(
len(x) for x in cut_CSS[mention_pos[0]][:mention_pos[1]])
mention_pos_right = mention_pos_left + len(cut_CSS[mention_pos[0]][mention_pos[1]])
if model_name == 'CSN':
mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right)
cat_CSS = ''.join([''.join(sent) for sent in cut_CSS])
elif model_name == 'KCSN':
mention_pos = (mention_pos[0], mention_pos_left, mention_pos_right, mention_pos[1])
cat_CSS = ' '.join([' '.join(sent) for sent in cut_CSS])
many_css.append(cat_CSS)
many_sent_char_lens.append(sent_char_lens)
many_mention_poses.append(mention_pos)
many_quote_idxes.append(quote_idx)
many_cut_css.append(cut_CSS)
return many_css, many_sent_char_lens, many_mention_poses, many_quote_idxes, many_cut_css
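
# Illustrative sketch (not from the original source, with a hypothetical args object):
#   args = argparse.Namespace(ws=1, length_limit=510, model_name='KCSN')
#   seg_sents = [['&C01&', 'said'], ['"hello"'], ['He', 'left']]   # 2 * ws + 1 sentences
#   create_CSS(seg_sents, {0: [[0, 0]]}, args)
#   # -> (['&C01& said "hello"'],   # CSS runs from the nearest mention to the quote
#   #     [[9, 7]],                 # character lengths of the two CSS sentences
#   #     [(0, 0, 5, 0)],           # mention spans characters 0..5 of CSS sentence 0
#   #     [1],                      # the quote is sentence 1 of the CSS
#   #     [[['&C01&', 'said'], ['"hello"']]])
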
class ISDataset(Dataset):
"""
๋ฐœํ™”์ž ์‹๋ณ„์„ ์œ„ํ•œ ๋ฐ์ดํ„ฐ์…‹ ์„œ๋ธŒํด๋ž˜์Šค
"""
def __init__(self, data_list):
super(ISDataset, self).__init__()
self.data = data_list
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def build_data_loader(data_file, alias2id, args, save_name=None) -> DataLoader:
"""
ํ•™์Šต์„ ์œ„ํ•œ ๋ฐ์ดํ„ฐ๋กœ๋”๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
"""
# ์‚ฌ์ „์— ์ด๋ฆ„์„ ์ถ”๊ฐ€
for alias in alias2id:
twitter.add_dictionary(alias, 'Noun')
# ํŒŒ์ผ์„ ์ค„๋ณ„๋กœ ๋ถˆ๋Ÿฌ๋“ค์ž„
with open(data_file, 'r', encoding='utf-8') as fin:
data_lines = fin.readlines()
    # Preprocessing
data_list = []
for i, line in enumerate(tqdm(data_lines)):
offset = i % 31
if offset == 0:
instance_index = line.strip().split()[-1]
raw_sents_in_list = []
continue
if offset < 22:
raw_sents_in_list.append(line.strip())
if offset == 22:
speaker_name = line.strip().split()[-1]
            # Drop empty entries
filtered_list = [li for li in raw_sents_in_list if li]
            # Split sentences and extract character mention positions
seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
filtered_list, alias2id)
# CSS ์ƒ์„ฑ
css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_CSS(
seg_sents, candidate_mention_poses, args)
# ํ›„๋ณด์ž ๋ฆฌ์ŠคํŠธ
candidates_list = list(candidate_mention_poses.keys())
# ์›ํ•ซ ๋ ˆ์ด๋ธ” ์ƒ์„ฑ
one_hot_label = [0 if character_idx != alias2id[speaker_name]
else 1 for character_idx in candidate_mention_poses.keys()]
true_index = one_hot_label.index(1) if 1 in one_hot_label else 0
if offset == 24:
category = line.strip().split()[-1]
if offset == 25:
name = ' '.join(line.strip().split()[1:])
if offset == 26:
scene = line.strip().split()[-1]
if offset == 27:
place = line.strip().split()[-1]
if offset == 28:
time = line.strip().split()[-1]
if offset == 29:
cut_position = line.strip().split()[-1]
data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
cut_css, one_hot_label, true_index, category, name_list_index,
name, scene, place, time, cut_position, candidates_list,
instance_index))
# ๋ฐ์ดํ„ฐ๋กœ๋” ์ƒ์„ฑ
data_loader = DataLoader(ISDataset(data_list), batch_size=1, collate_fn=lambda x: x[0])
# ์ €์žฅํ•  ์ด๋ฆ„์ด ์ฃผ์–ด์ง„ ๊ฒฝ์šฐ ๋ฐ์ดํ„ฐ ๋ฆฌ์ŠคํŠธ ์ €์žฅ
if save_name is not None:
torch.save(data_list, save_name)
return data_loader
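
# Illustrative usage (hypothetical file names): the input file is expected to hold
# one instance per 31 lines (offset 0: instance index, offsets 1-21: context sentences,
# offset 22: speaker name, offsets 24-29: category / name / scene / place / time / cut position).
#   loader = build_data_loader('train_instances.txt', alias2id, args, save_name='train.pt')
#   seg_sents, css, sent_char_lens, mention_poses, quote_idxes, cut_css, \
#       one_hot_label, true_index, *rest = next(iter(loader))
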
def load_data_loader(saved_filename: str) -> DataLoader:
"""
์ €์žฅ๋œ ํŒŒ์ผ์—์„œ ๋ฐ์ดํ„ฐ๋ฅผ ๋กœ๋“œํ•˜๊ณ  DataLoader ๊ฐ์ฒด๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
"""
# ์ €์žฅ๋œ ๋ฐ์ดํ„ฐ ๋ฆฌ์ŠคํŠธ ๋กœ๋“œ
data_list = load_data(saved_filename)
return DataLoader(ISDataset(data_list), batch_size=1, collate_fn=lambda x: x[0])
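
# Illustrative usage (hypothetical file name):
#   loader = load_data_loader('train.pt')   # rebuilds the DataLoader from a list saved above
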
def split_train_val_test(data_file, alias2id, args, save_name=None, test_size=0.2, val_size=0.1, random_state=13):
"""
๊ธฐ์กด ๊ฒ€์ฆ ๋ฐฉ์‹์„ ์ ์šฉํ•˜์—ฌ ๋ฐ์ดํ„ฐ ๋กœ๋”๋ฅผ ๋นŒ๋“œํ•ฉ๋‹ˆ๋‹ค.
์ฃผ์–ด์ง„ ๋ฐ์ดํ„ฐ ํŒŒ์ผ์„ ํ›ˆ๋ จ, ๊ฒ€์ฆ, ํ…Œ์ŠคํŠธ ์„ธํŠธ๋กœ ๋ถ„ํ• ํ•˜๊ณ  ๊ฐ๊ฐ์˜ DataLoader๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
Parameters:
- data_file: ๋ถ„ํ• ํ•  ๋ฐ์ดํ„ฐ ํŒŒ์ผ ๊ฒฝ๋กœ
- alias2id: ๋“ฑ์žฅ์ธ๋ฌผ ์ด๋ฆ„๊ณผ ID๋ฅผ ๋งคํ•‘ํ•œ ๋”•์…”๋„ˆ๋ฆฌ
- args: ์‹คํ–‰ ์ธ์ž๋ฅผ ๋‹ด์€ ๊ฐ์ฒด
- save_name: ๋ถ„ํ• ๋œ ๋ฐ์ดํ„ฐ๋ฅผ ์ €์žฅํ•  ํŒŒ์ผ ์ด๋ฆ„
- test_size: ํ…Œ์ŠคํŠธ ์„ธํŠธ์˜ ๋น„์œจ (๊ธฐ๋ณธ๊ฐ’: 0.2)
- val_size: ๊ฒ€์ฆ ์„ธํŠธ์˜ ๋น„์œจ (๊ธฐ๋ณธ๊ฐ’: 0.1)
- random_state: ๋žœ๋ค ์‹œ๋“œ (๊ธฐ๋ณธ๊ฐ’: 13)
Returns:
- train_loader: ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ๋กœ๋”
- val_loader: ๊ฒ€์ฆ ๋ฐ์ดํ„ฐ๋กœ๋”
- test_loader: ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ๋”
"""
# ์‚ฌ์ „์— ์ด๋ฆ„ ์ถ”๊ฐ€
for alias in alias2id:
twitter.add_dictionary(alias, 'Noun')
# ํŒŒ์ผ์—์„œ ์ธ์Šคํ„ด์Šค ๋กœ๋“œ
with open(data_file, 'r', encoding='utf-8') as fin:
data_lines = fin.readlines()
    # Preprocessing
data_list = []
for i, line in enumerate(tqdm(data_lines)):
offset = i % 31
if offset == 0:
instance_index = line.strip().split()[-1]
raw_sents_in_list = []
continue
if offset < 22:
raw_sents_in_list.append(line.strip())
if offset == 22:
speaker_name = line.strip().split()[-1]
            # Drop empty entries
filtered_list = [li for li in raw_sents_in_list if li]
            # Split sentences and extract character mention positions
seg_sents, candidate_mention_poses, name_list_index = seg_and_mention_location(
filtered_list, alias2id)
# CSS ์ƒ์„ฑ
css, sent_char_lens, mention_poses, quote_idxes, cut_css = create_CSS(
seg_sents, candidate_mention_poses, args)
# ํ›„๋ณด์ž ๋ฆฌ์ŠคํŠธ
candidates_list = list(candidate_mention_poses.keys())
# ์›ํ•ซ ๋ ˆ์ด๋ธ” ์ƒ์„ฑ
one_hot_label = [0 if character_idx != alias2id[speaker_name]
else 1 for character_idx in candidate_mention_poses.keys()]
true_index = one_hot_label.index(1) if 1 in one_hot_label else 0
if offset == 24:
category = line.strip().split()[-1]
if offset == 25:
name = ' '.join(line.strip().split()[1:])
if offset == 26:
scene = line.strip().split()[-1]
if offset == 27:
place = line.strip().split()[-1]
if offset == 28:
time = line.strip().split()[-1]
if offset == 29:
cut_position = line.strip().split()[-1]
data_list.append((seg_sents, css, sent_char_lens, mention_poses, quote_idxes,
cut_css, one_hot_label, true_index, category, name_list_index,
name, scene, place, time, cut_position, candidates_list,
instance_index))
    # Split the data into train / validation / test sets
train_data, test_data = train_test_split(
data_list, test_size=test_size, random_state=random_state)
train_data, val_data = train_test_split(
train_data, test_size=val_size, random_state=random_state)
# train DataLoader ์ƒ์„ฑ
train_loader = DataLoader(ISDataset(train_data), batch_size=1, collate_fn=lambda x: x[0])
# validation DataLoader ์ƒ์„ฑ
val_loader = DataLoader(ISDataset(val_data), batch_size=1, collate_fn=lambda x: x[0])
# test DataLoader ์ƒ์„ฑ
test_loader = DataLoader(ISDataset(test_data), batch_size=1, collate_fn=lambda x: x[0])
if save_name is not None:
# ๊ฐ๊ฐ์˜ ๋ฐ์ดํ„ฐ๋ฅผ ์ €์žฅ
torch.save(train_data, save_name.replace(".pt", "_train.pt"))
torch.save(val_data, save_name.replace(".pt", "_val.pt"))
torch.save(test_data, save_name.replace(".pt", "_test.pt"))
return train_loader, val_loader, test_loader
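
# Illustrative usage (hypothetical file names): with the defaults, roughly 72% / 8% / 20%
# of the instances end up in train / validation / test (val_size is taken from the
# remaining training portion), and the splits are additionally written to
# data_train.pt, data_val.pt and data_test.pt.
#   train_loader, val_loader, test_loader = split_train_val_test(
#       'all_instances.txt', alias2id, args, save_name='data.pt')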