"""
Get data and adapt it for training
-----------
- nettoyage de l'encodage
- Ajout de token <START> et <END>
TO DO :
- Nettoyage des contractions
- enlever les \xad
- enlever ponctuation et () []
- s'occuper des noms propres (mots commençant par une majuscule qui se suivent)
Création d'un Vectoriserà partir du vocabulaire :
"""
import pickle
import string
from collections import Counter
import pandas as pd
import torch
class Data(torch.utils.data.Dataset):
"""
A class used to get data from file
...
Attributes
----------
path : str
the path to the file containing the data
data : pd.DataFrame
the dataframe read from the jsonl file
transform : callable, optional
a transform (e.g. a Vectoriser) applied to each sample
Methods
-------
open()
open the jsonl file with pandas
clean_data(text_type)
clean the data read from the file and add <start> and
<end> tokens depending on the text_type
get_words()
get the dataset vocabulary
"""
def __init__(self, path: str, transform=None) -> None:
self.path = path
self.data = pd.read_json(path_or_buf=self.path, lines=True)
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
row = self.data.iloc[idx]
text = row["text"].translate(
str.maketrans(
"", "", string.punctuation)).split()
summary = (
row["summary"].translate(
str.maketrans(
"",
"",
string.punctuation)).split())
summary = ["<start>", *summary, "<end>"]
sample = {"text": text, "summary": summary}
if self.transform:
sample = self.transform(sample)
return sample
def open(self) -> pd.DataFrame:
"""
Open the file containing the data
"""
return pd.read_json(path_or_buf=self.path, lines=True)
def clean_data(self, text_type: str) -> list:
"""
Clean the data: fix encoding errors, strip punctuation, etc.
To do:
# clean the data further
Parameters
----------
text_type : str
differentiates between 'text' and 'summary' so that <start>
and <end> tokens are only added to summaries
Returns
----------
list of list
list of tokenised texts
"""
dataset = self.open()
texts = dataset[text_type]
texts = texts.str.encode("cp1252", "ignore")
texts = texts.str.decode("utf-8", "ignore")
tokenized_texts = []
# TO DO:
# - clean up contractions
# - remove the \xad characters
# text.translate(str.maketrans('', '', string.punctuation))
# - remove punctuation and () []
# - handle proper nouns (consecutive words starting with a capital letter)
for text in texts:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.split()
tokenized_texts.append(text)
if text_type == "summary":
return [["<start>", *summary, "<end>"]
for summary in tokenized_texts]
return tokenized_texts
def get_words(self) -> list:
"""
Return every token from the texts and summaries; the Vectoriser
builds its vocabulary from this list
"""
texts, summaries = self.clean_data("text"), self.clean_data("summary")
text_words = [word for text in texts for word in text]
summary_words = [word for text in summaries for word in text]
return text_words + summary_words
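# A minimal usage sketch for Data (the path "data/train.jsonl" is a
# placeholder; any JSONL file whose lines have "text" and "summary" fields works):
#
#     dataset = Data("data/train.jsonl")
#     sample = dataset[0]            # {"text": [...], "summary": ["<start>", ..., "<end>"]}
#     vocab = dataset.get_words()    # flat token list, fed to Vectoriser(vocab)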
def pad_collate(data):
"""
Collate function for a DataLoader: right-pads every text and summary
tensor in the batch with -100 (the default ignore_index of
torch.nn.CrossEntropyLoss) up to the length of the longest sequence
in the batch, texts and summaries considered together.
"""
text_batch = [element[0] for element in data]
summary_batch = [element[1] for element in data]
max_len = max([len(element) for element in summary_batch + text_batch])
text_batch = [
torch.nn.functional.pad(element, (0, max_len - len(element)), value=-100)
for element in text_batch
]
summary_batch = [
torch.nn.functional.pad(element, (0, max_len - len(element)), value=-100)
for element in summary_batch
]
return text_batch, summary_batch
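# Sketch of what pad_collate returns (values are illustrative):
#
#     batch = [(torch.tensor([4, 2, 7]), torch.tensor([0, 5, 1])),
#              (torch.tensor([3]), torch.tensor([0, 8, 6, 1]))]
#     texts, summaries = pad_collate(batch)
#     # every tensor is right-padded with -100 to the longest length in the
#     # batch (4 here), e.g. texts[1] is tensor([3, -100, -100, -100])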
class Vectoriser:
"""
A class used to vectorise data
...
Attributes
----------
vocab : list
list of the vocab
Methods
-------
encode(tokens)
transforms a list of tokens into their corresponding indices
in the form of a torch tensor
decode(word_idx_tensor)
converts a tensor to a list of tokens
vectorize(row)
encode an entire row from the dataset
"""
def __init__(self, vocab=None) -> None:
# an empty Vectoriser can be built and populated later via load()
self.vocab = vocab if vocab is not None else []
self.word_count = Counter(word.lower().strip(",.\\-")
for word in self.vocab)
# only tokens seen more than once are kept in the vocabulary;
# everything else is mapped to the out-of-vocabulary index in encode()
self.idx_to_token = sorted(
[t for t, c in self.word_count.items() if c > 1])
self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
def load(self, path):
with open(path, "rb") as file:
self.vocab = pickle.load(file)
self.word_count = Counter(
word.lower().strip(",.\\-") for word in self.vocab
)
self.idx_to_token = sorted(
[t for t, c in self.word_count.items() if c > 1])
self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
def save(self, path):
with open(path, "wb") as file:
pickle.dump(self.vocab, file)
def encode(self, tokens) -> torch.Tensor:
"""
Encode a sentence according to the words it contains
and the words present in the vocabulary.
NOTE:
If a word is not in the vocabulary, it is mapped to a fixed
index (len(token_to_idx)) that is ignored when decoding.
---------
:params: tokens : list
the words of the sentence, as a list
:return: words_idx : tensor
a tensor containing the indices of the sentence's words
"""
if isinstance(tokens, list):
words_idx = torch.tensor(
[
self.token_to_idx.get(t.lower(), len(self.token_to_idx))
for t in tokens
],
dtype=torch.long,
)
# also allows encoding a single word at a time
elif isinstance(tokens, str):
words_idx = torch.tensor(
self.token_to_idx.get(tokens.lower(), len(self.token_to_idx)))
return words_idx
def decode(self, words_idx_tensor) -> list:
"""
Decode a sentence: the inverse operation of encode.
Out-of-vocabulary indices (len(idx_to_token)) are skipped.
"""
idxs = words_idx_tensor.tolist()
if isinstance(idxs, int):
words = [self.idx_to_token[idxs]]
else:
words = []
for idx in idxs:
if idx != len(self.idx_to_token):
words.append(self.idx_to_token[idx])
return words
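# Round-trip sketch (the exact indices depend on the vocabulary supplied;
# only words seen more than once are kept):
#
#     vec = Vectoriser(["the", "the", "cat", "cat", "sat", "sat"])
#     idx = vec.encode(["the", "cat", "sat"])   # tensor([2, 0, 1])
#     vec.decode(idx)                           # ["the", "cat", "sat"]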
def __call__(self, row) -> tuple:
"""
Encode the data of one row of the dataframe
----------
:params: row : dataframe row
one row of the dataframe (a text-summary pair)
:returns: text_idx : tensor
the tensor corresponding to the words of the text
:returns: summary_idx : tensor
the tensor corresponding to the words of the summary
"""
text_idx = self.encode(row["text"])
summary_idx = self.encode(row["summary"])
return (text_idx, summary_idx)
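# Minimal end-to-end sketch; only runs when the module is executed directly.
# The path "data/train.jsonl" is a placeholder for a JSONL file whose lines
# contain "text" and "summary" fields.
if __name__ == "__main__":
    corpus = Data("data/train.jsonl")
    vectoriser = Vectoriser(corpus.get_words())

    # With the Vectoriser as transform, each item becomes a
    # (text_idx, summary_idx) pair of tensors, which pad_collate can batch.
    corpus.transform = vectoriser
    loader = torch.utils.data.DataLoader(
        corpus, batch_size=2, shuffle=True, collate_fn=pad_collate
    )

    text_idx, summary_idx = corpus[0]
    print(vectoriser.decode(summary_idx))

    text_batch, summary_batch = next(iter(loader))
    print(len(text_batch), text_batch[0].shape)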