File size: 1,731 Bytes
46df0b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import logging
import pandas as pd
from datasets import load_metric
from os.path import exists
from os.path import join as pjoin
import utils
from utils import dataset_utils as ds_utils

# Module logger obtained from the project's `utils` helper.
# NOTE(review): `logs` is not referenced in the visible code of this module.
logs = utils.prepare_logging(__file__)

# Model id passed to the perplexity metric as `model_id`.
TOK_MODEL = "gpt2"
# Perplexity metric object, created once when this module is imported.
# NOTE(review): `datasets.load_metric` is deprecated (metrics moved to the
# `evaluate` library) — confirm the pinned `datasets` version still ships it.
PERPLEXITY = load_metric("perplexity")
# Column name holding the per-text perplexity scores in the results DataFrame.
PERPLEXITY_FIELD = "perplexity"


class DMTHelper:
    """Compute, cache, and expose per-text perplexity scores for a dataset.

    Texts are read from ``dstats.text_dset``, scored with the module-level
    perplexity metric, and stored in a DataFrame sorted from highest to
    lowest perplexity. Results can be saved to, and loaded back from, a cache
    file inside ``dstats.dataset_cache_dir``.

    Args:
        dstats: Dataset-statistics object. The attributes read here are
            ``dataset_cache_dir``, ``use_cache``, ``save`` and ``text_dset``.
        load_only: If True, only load cached results and never run the
            perplexity computation.
    """

    def __init__(self, dstats, load_only=False):
        self.dstats = dstats
        self.load_only = load_only
        # Unsorted results of the most recent computation. Stays empty when
        # the DataFrame is loaded from the cache instead of computed.
        self.results_dict = {}
        # Where in the Dataset object to find the text for the calculation
        self.text_field = ds_utils.OUR_TEXT_FIELD
        # Results in dataframe form (None until loaded or computed)
        self.df = None
        # Cache file
        self.perplexities_df_fid = pjoin(self.dstats.dataset_cache_dir,
                                         "perplexities_df.json")

    def run_DMT_processing(self):
        """Populate ``self.df`` from the cache, or compute it if allowed.

        The cache file is used when ``dstats.use_cache`` is set and the file
        exists. Otherwise, unless ``load_only`` is set, perplexities are
        computed. If ``dstats.save`` is also set, the result is written to the
        cache file. When ``load_only`` is set and no usable cache exists,
        ``self.df`` is left unchanged.
        """
        if self.dstats.use_cache and exists(self.perplexities_df_fid):
            self.df = ds_utils.read_df(self.perplexities_df_fid)
        elif not self.load_only:
            self.prepare_text_perplexities()
            if self.dstats.save:
                ds_utils.write_df(self.df, self.perplexities_df_fid)

    def prepare_text_perplexities(self):
        """Score every text with the perplexity metric and build ``self.df``.

        Sets ``self.results_dict`` to the unsorted per-text results (keyed by
        the perplexity field and the text field). Sets ``self.df`` to the same
        data sorted by perplexity, highest first. An empty text column
        produces an empty DataFrame and does not call the metric.
        """
        # Read the text column once and reuse it. Column access on a dataset
        # object can materialize the whole column, so repeating it is wasteful.
        texts = self.dstats.text_dset[self.text_field]
        if len(texts) == 0:
            # Nothing to score; skip the metric call entirely.
            perplexities = []
        else:
            eval_results = PERPLEXITY.compute(input_texts=texts,
                                              model_id=TOK_MODEL)
            perplexities = eval_results["perplexities"]
        # TODO: What other stuff might be useful to grab?
        self.results_dict = {PERPLEXITY_FIELD: perplexities,
                             self.text_field: texts}
        self.df = pd.DataFrame(self.results_dict).sort_values(
            by=PERPLEXITY_FIELD, ascending=False)

    def get_df(self):
        """Return the perplexity DataFrame, or None if not yet loaded/computed."""
        return self.df