# Ezi's picture
# Upload 312 files
# 46df0b6
# raw
# history blame
# No virus
# 1.73 kB
import logging
import pandas as pd
from datasets import load_metric
from os.path import exists
from os.path import join as pjoin
import utils
from utils import dataset_utils as ds_utils
# Logging handle for this module, built by the project's logging helper
# (presumably a logger keyed on this file's path -- confirm in utils).
logs = utils.prepare_logging(__file__)
# Model identifier passed as `model_id` when perplexity is computed.
TOK_MODEL = "gpt2"
# Perplexity metric object. NOTE: this is loaded at import time, so importing
# this module triggers the metric load.
PERPLEXITY = load_metric("perplexity")
# Key/column name under which perplexity scores are stored, and the column
# the results dataframe is sorted by.
PERPLEXITY_FIELD = "perplexity"
class DMTHelper:
    """Compute, cache, and expose per-text perplexity scores for a dataset.

    Scores are produced by the module-level ``PERPLEXITY`` metric, evaluated
    with the ``TOK_MODEL`` model, and kept as a dataframe sorted from the
    highest perplexity to the lowest.
    """

    def __init__(self, dstats, load_only: bool = False) -> None:
        """Store configuration; nothing is computed or loaded here.

        Args:
            dstats: Dataset-statistics object. This class reads its
                ``dataset_cache_dir``, ``use_cache``, ``save`` and
                ``text_dset`` attributes.
            load_only: When True, results are only ever read from the cache
                file and never computed.
        """
        self.dstats = dstats
        self.load_only = load_only
        self.results_dict = {}
        # Where in the Dataset object to find the text for the calculation
        self.text_field = ds_utils.OUR_TEXT_FIELD
        # Results in dataframe form
        self.df = None
        # Cache file
        self.perplexities_df_fid = pjoin(
            self.dstats.dataset_cache_dir, "perplexities_df.json"
        )

    def run_DMT_processing(self) -> None:
        """Populate ``self.df`` from the cache file or by computing it.

        If caching is enabled and the cache file exists, the dataframe is read
        from it. Otherwise, unless ``load_only`` is set, perplexities are
        computed and, when ``dstats.save`` is true, written to the cache file.
        In load-only mode with no usable cache, ``self.df`` is left unchanged.
        """
        if self.dstats.use_cache and exists(self.perplexities_df_fid):
            self.df = ds_utils.read_df(self.perplexities_df_fid)
        elif not self.load_only:
            self.prepare_text_perplexities()
            # Only persist a frame that was just computed: there is nothing
            # to write in load-only mode, and a frame read from the cache is
            # already on disk.
            if self.dstats.save:
                ds_utils.write_df(self.df, self.perplexities_df_fid)

    def prepare_text_perplexities(self) -> None:
        """Compute perplexities for the text column and build ``self.df``.

        Fills ``self.results_dict`` with the perplexity scores and the texts
        they belong to, then stores a dataframe of both sorted by perplexity
        in descending order.
        """
        # Fetch the text column once and reuse it for both the metric input
        # and the results table.
        texts = self.dstats.text_dset[self.text_field]
        eval_results = PERPLEXITY.compute(input_texts=texts, model_id=TOK_MODEL)
        # TODO: What other stuff might be useful to grab?
        self.results_dict = {
            PERPLEXITY_FIELD: eval_results["perplexities"],
            self.text_field: texts,
        }
        self.df = pd.DataFrame(self.results_dict).sort_values(
            by=PERPLEXITY_FIELD, ascending=False
        )

    def get_df(self):
        """Return the results dataframe, or None if none was loaded/computed."""
        return self.df