File size: 6,152 Bytes
a3ffd31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import datetime
from pathlib import Path

import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm

from modules import shared
from modules.logging_colors import logger
from modules.models import clear_torch_cache, load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import encode


def load_past_evaluations():
    '''
    Read the stored evaluation results from logs/evaluations.csv.

    Every column is read as a string, after which 'Perplexity' is converted
    to a numeric column. When the file does not exist, an empty DataFrame
    with the expected columns is returned instead.
    '''
    csv_path = Path('logs/evaluations.csv')
    if not csv_path.exists():
        return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])

    table = pd.read_csv(csv_path, dtype=str)
    table['Perplexity'] = pd.to_numeric(table['Perplexity'])
    return table


# In-memory table of evaluation results, populated from disk once when this
# module is imported. The functions below read it and rebind it via `global`.
past_evaluations = load_past_evaluations()


def save_past_evaluations(df):
    '''
    Make `df` the module-level `past_evaluations` table and write it to
    logs/evaluations.csv (without the index), creating the logs directory
    if it is missing.
    '''
    global past_evaluations
    past_evaluations = df

    destination = Path('logs/evaluations.csv')
    destination.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(destination, index=False)


def _load_evaluation_text(input_dataset):
    '''
    Return the evaluation corpus for `input_dataset` as one string.

    'wikitext', 'ptb' and 'ptb_new' are fetched through `load_dataset`; any
    other value is treated as the stem of a UTF-8 text file under
    training/datasets/.
    '''

    # Copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/utils/datautils.py
    if input_dataset == 'wikitext':
        data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        return "\n\n".join(data['text'])

    if input_dataset == 'ptb':
        data = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
        return "\n\n".join(data['sentence'])

    if input_dataset == 'ptb_new':
        data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
        return " ".join(data['sentence'])

    with open(Path(f'training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f:
        return f.read()


def calculate_perplexity(models, input_dataset, stride, _max_length):
    '''
    Based on:
    https://huggingface.co/docs/transformers/perplexity#calculating-ppl-with-fixedlength-models

    Generator that evaluates every entry of `models` on `input_dataset` using
    a sliding window, yielding the progress log (a string) after each step.

    Parameters:
        models: iterable of model names. The literal 'current model' evaluates
            whatever is already in `shared.model` instead of loading a model.
        input_dataset: 'wikitext', 'ptb', 'ptb_new', or the stem of a .txt
            file under training/datasets/.
        stride: step between the start positions of consecutive windows.
        _max_length: window length. When falsy, the model config's
            `max_position_embeddings` is used if present, otherwise 2048.

    Yields:
        The log accumulated so far, at times followed by a transient status
        line (loading / tokenizing / evaluation percentage).

    Side effects:
        For entries other than 'current model', updates `shared.settings`,
        unloads the current model and replaces `shared.model` and
        `shared.tokenizer`. Every computed perplexity is appended to
        `past_evaluations` and written to disk via `save_past_evaluations`.
    '''

    if not shared.args.no_use_fast:
        logger.warning("--no_use_fast is not being used. If tokenizing the input dataset takes a long time, consider loading the model with that option checked.")

    global past_evaluations
    cumulative_log = ''
    cumulative_log += "Loading the input dataset...\n\n"
    yield cumulative_log

    text = _load_evaluation_text(input_dataset)

    for model in models:
        # The lookup uses the name exactly as given, while results are stored
        # under `shared.model_name`.
        if is_in_past_evaluations(model, input_dataset, stride, _max_length):
            cumulative_log += f"`{model}` has already been tested. Ignoring.\n\n"
            yield cumulative_log
            continue

        if model != 'current model':
            # The yield stays outside the try block so that closing the
            # generator at this point is never mistaken for a load failure.
            yield cumulative_log + f"Loading `{model}`...\n\n"
            try:
                model_settings = get_model_metadata(model)
                shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})  # hijacking the interface defaults
                update_model_parameters(model_settings)  # hijacking the command-line arguments
                unload_model()
                shared.model, shared.tokenizer = load_model(model)
            except Exception as exc:
                # Only ordinary errors are treated as "failed to load";
                # KeyboardInterrupt / GeneratorExit / SystemExit propagate.
                logger.warning(f"Failed to load `{model}`: {exc!r}")
                cumulative_log += f"Failed to load `{model}`. Moving on.\n\n"
                yield cumulative_log
                continue

        cumulative_log += f"Processing `{shared.model_name}`...\n\n"
        yield cumulative_log + "Tokenizing the input dataset...\n\n"
        encodings = encode(text, add_special_tokens=False)
        seq_len = encodings.shape[1]
        if _max_length:
            max_length = _max_length
        elif hasattr(shared.model.config, 'max_position_embeddings'):
            max_length = shared.model.config.max_position_embeddings
        else:
            max_length = 2048

        nlls = []
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, seq_len, stride)):
            yield cumulative_log + f"Evaluating... {100*begin_loc/seq_len:.2f}%"
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
            input_ids = encodings[:, begin_loc:end_loc]
            target_ids = input_ids.clone()
            # Only the last trg_len positions are scored; the earlier ones were
            # covered by the previous window and act purely as context.
            target_ids[:, :-trg_len] = -100
            clear_torch_cache()
            with torch.no_grad():
                outputs = shared.model(input_ids=input_ids, labels=target_ids)

                # loss is calculated using CrossEntropyLoss which averages over valid labels
                # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
                # to the left by 1.
                neg_log_likelihood = outputs.loss

            nlls.append(neg_log_likelihood)
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        if not nlls:
            # No window was evaluated (e.g. the dataset tokenized to nothing);
            # torch.stack would raise on an empty list, so skip this model.
            cumulative_log += f"No tokens to evaluate for `{shared.model_name}`. Skipping.\n\n"
            yield cumulative_log
            continue

        ppl = torch.exp(torch.stack(nlls).mean())
        add_entry_to_past_evaluations(float(ppl), shared.model_name, input_dataset, stride, _max_length)
        save_past_evaluations(past_evaluations)
        cumulative_log += f"The perplexity for `{shared.model_name}` is: {float(ppl)}\n\n"
        yield cumulative_log


def add_entry_to_past_evaluations(perplexity, model, dataset, stride, max_length):
    '''
    Append one result row to the module-level `past_evaluations` table.

    The row records the names in `shared.lora_names` joined with ', '
    (or '-' when that is empty), the stride and max_length as strings,
    the current local time and an empty comment. The table is rebound to
    a new DataFrame; nothing is written to disk here.
    '''
    global past_evaluations

    lora_label = ', '.join(shared.lora_names)
    if not lora_label:
        lora_label = '-'

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    new_row = pd.DataFrame([{
        'Model': model,
        'LoRAs': lora_label,
        'Dataset': dataset,
        'Perplexity': perplexity,
        'stride': str(stride),
        'max_length': str(max_length),
        'Date': timestamp,
        'Comment': '',
    }])
    past_evaluations = pd.concat([past_evaluations, new_row], ignore_index=True)


def is_in_past_evaluations(model, dataset, stride, max_length):
    '''
    Tell whether `past_evaluations` already holds a row with this model,
    dataset, stride and max_length.

    stride and max_length are converted with str() before the comparison.
    Returns True when at least one row matches, False otherwise.
    '''
    table = past_evaluations
    same_model = table['Model'] == model
    same_dataset = table['Dataset'] == dataset
    same_max_length = table['max_length'] == str(max_length)
    same_stride = table['stride'] == str(stride)

    matches = table[same_model & same_dataset & same_max_length & same_stride]
    return matches.shape[0] > 0


def generate_markdown_table():
    '''
    Return a copy of `past_evaluations` sorted by Dataset, stride,
    Perplexity and Date, in that order of priority.

    Despite the name, the result is a DataFrame, not a markdown string.
    '''
    sort_keys = ['Dataset', 'stride', 'Perplexity', 'Date']
    return past_evaluations.sort_values(by=sort_keys)