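# rm_byt5_base / README.md
# Uploaded by voidful -- commit 65f0dc5 ("Create README.md").
#
# Evaluation script: scores candidate answers with the model and reports
# ranking metrics plus a score-distribution plot.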
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
from ranx import evaluate
from tqdm.auto import tqdm

from rm_model import humanPreferenceModel

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Models to evaluate. Each entry carries:
#   name   - label used in printed output and in the saved plot's file name
#   config - argument handed to the humanPreferenceModel constructor
#   path   - checkpoint location handed to torch.load
model_configs = [
    dict(
        name='rm_byt5_base',
        config="google/byt5-base",
        path='voidful/rm_byt5_base',
    ),
]

# JSON-lines evaluation file; one example per line. A single example looks like:
# {
#   "question": "Screenshot Software recommendation - free, Windows XP/7",
#   "answers": [
#     "My favourite: FSCapture 5.3 (last free version)\nPortable, lightweight, free.\n\n",
#     "Use Irfan View, is is faster than XnView and allows to set up a capture hotkey, or start capturing with a delay (possible via hotkey too).\n",
#     "I know you are looking for a free solution; this is more of an FYI, in case you have Microsoft OneNote...\nYou can press Win - S to take a screenshot that is pasted inside the OneNote program...Then right-click the image (while it is selected), and click \"Save As\".  You can then save the image anywhere you like...\n"
#   ],
#   "accepted_answer": [
#     "Windows 7 comes with the snipping tool, which can be activated via hotkey with a little tweaking.\nSome nifty third party tools include Cropper:\n\nGreenshot:\n\nand of course, Gadwin.\n"
#   ]
# }
eval_dataset = "test_rm.jsonl"

# Tokenizer truncation limit applied to every "question ... answer ..." prompt.
maxlen = 512
# NOTE(review): batch_size is not referenced by any code visible in this file.
batch_size = 3


def rank_answers(model, question, answers):
    """Score every candidate answer for ``question`` with ``model``.

    Args:
        model: Object exposing ``tokenizer``, ``transformer_model._shift_right``
            and a forward call that accepts ``input_ids``, ``attention_mask``
            and ``decoder_input_ids``.
        question: Question text placed in front of each answer.
        answers: Candidate answer strings, scored together as one batch.

    Returns:
        A list of ``(answer, score)`` pairs in the same order as ``answers``.
        Each score is one element of the model output after it has been moved
        to the CPU (presumably one score per answer -- confirm against the
        model's forward implementation).
    """
    model.eval()
    with torch.inference_mode():
        prompts = []
        for candidate in answers:
            prompts.append(f"question: {question} answer: {candidate}")

        # Pad to the longest prompt and cut anything beyond ``maxlen`` tokens.
        encoded = model.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=maxlen,
        )
        encoded = encoded.to(device)

        input_ids = encoded["input_ids"]
        # The decoder is fed the encoder input ids shifted one position right.
        shifted_ids = model.transformer_model._shift_right(input_ids)
        raw_scores = model(
            input_ids=input_ids,
            attention_mask=encoded["attention_mask"],
            decoder_input_ids=shifted_ids,
        )
        cpu_scores = raw_scores.cpu()
        return [pair for pair in zip(answers, cpu_scores)]


def create_test_data(path=None):
    """Load the evaluation examples from a JSON-lines file.

    Args:
        path: File to read. When ``None`` (the default) the module-level
            ``eval_dataset`` path is used, which keeps existing zero-argument
            callers working.

    Returns:
        A list with one decoded JSON object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    source = eval_dataset if path is None else path
    with open(source, "r", encoding="utf8") as f:
        # Blank lines (e.g. a trailing newline at end of file) are skipped
        # instead of being passed to json.loads, which would raise on them.
        return [json.loads(line) for line in f if line.strip()]


def create_qrels_and_run(test_data, model):
    """Score every example and build the relevance/run dictionaries for ranx.

    For each example the accepted answers are appended after the other
    answers, all of them are scored with ``rank_answers``, and only answers
    whose text equals the first accepted answer are marked relevant.

    Args:
        test_data: Iterable of dicts with ``"question"``, ``"answers"`` and
            ``"accepted_answer"`` keys.
        model: Model forwarded to ``rank_answers``.

    Returns:
        A tuple ``(qrels, run, selected_scores, nonselected_scores)``:
        ``qrels`` maps query id -> {doc id -> 0/1 relevance}, ``run`` maps
        query id -> {doc id -> float score}, and the two score lists hold the
        NumPy-converted scores of accepted and non-accepted answers.
    """
    qrels = {}
    run = {}
    selected_scores = []
    nonselected_scores = []

    for query_idx, example in enumerate(tqdm(test_data)):
        question = example["question"]
        correct_answer = example["accepted_answer"][0]
        answers = example["answers"] + example["accepted_answer"]
        ranked_answers = rank_answers(model, question, answers)

        # ranx documents qrels/run dictionaries as Dict[str, Dict[str, Number]],
        # so ids are strings and scores plain Python floats rather than
        # integers and torch tensors.
        query_id = str(query_idx)
        qrels[query_id] = {
            str(doc_idx): int(answer == correct_answer)
            for doc_idx, answer in enumerate(answers)
        }
        run[query_id] = {
            str(doc_idx): float(score)
            for doc_idx, (_, score) in enumerate(ranked_answers)
        }

        # Keep the raw scores split by group for the histogram/mean report.
        for answer, score in ranked_answers:
            score_array = score.cpu().detach().numpy()
            if answer == correct_answer:
                selected_scores.append(score_array)
            else:
                nonselected_scores.append(score_array)

    return qrels, run, selected_scores, nonselected_scores


# Evaluate a single model configuration end to end: load it, score the test set, report metrics and plot.
def evaluate_model(model_config, model_name, model_path):
    """Evaluate one model on the test set, print metrics and plot the scores.

    Args:
        model_config: Argument passed to the ``humanPreferenceModel``
            constructor.
        model_name: Label used in the saved plot's file name.
        model_path: Checkpoint handed to ``torch.load``; a falsy value skips
            loading and evaluates the freshly constructed model.

    Returns:
        A tuple ``(results, selected_scores, nonselected_scores)`` where
        ``results`` is the mapping returned by ranx ``evaluate`` and the two
        lists hold the per-answer scores for accepted / other answers.

    Side effects:
        Prints summary statistics, writes
        ``score_distribution_answers_{model_name}.png`` and shows the figure.
    """
    model = humanPreferenceModel(model_config)
    if model_path:
        # Map the checkpoint onto the globally selected device so that
        # CPU-only hosts can load it too (previously hard-coded to 'cuda:0').
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        model.load_state_dict(torch.load(model_path, map_location=device))

    # rank_answers moves the tokenized inputs to ``device``; the model has to
    # live on the same device for the forward pass.
    model.to(device)
    model.eval()

    test_data = create_test_data()
    qrels, run, selected_scores, nonselected_scores = create_qrels_and_run(test_data, model)

    # Compute mean score for selected and non-selected answers; both fall back
    # to 0 when their list is empty.
    mean_selected_score = np.mean(selected_scores) if len(selected_scores) > 0 else 0
    mean_nonselected_score = np.mean(nonselected_scores) if len(nonselected_scores) > 0 else 0
    print(f"Mean score for selected answers: {mean_selected_score:.4f}")
    print(f"Mean score for non-selected answers: {mean_nonselected_score:.4f}")
    print("Selected scores:", len(selected_scores), selected_scores[:5])
    print("Non-selected scores:", len(nonselected_scores), nonselected_scores[:5])

    # Evaluate and print results
    metrics_to_compute = ["hits@5", "hit_rate@5", "precision@5", "recall@5", "f1@5", "r-precision", "bpref", "rbp.95",
                          "mrr@5", "map@5", "ndcg@5", "ndcg_burges@5"]
    results = evaluate(qrels, run, metrics_to_compute)
    print(results)
    # NOTE(review): every metric is scaled by 100 and later rendered with a
    # '%' suffix; confirm this is intended for count-style metrics like hits@5.
    results_perc = {metric: result * 100 for metric, result in results.items()}

    # Convert the stored score arrays to plain Python numbers for plotting.
    selected_scores_flat = [score.item() for score in selected_scores]
    nonselected_scores_flat = [score.item() for score in nonselected_scores]

    # Statistics drawn as vertical marker lines on top of the histograms.
    statistics = {'mean': np.mean}

    plt.hist(nonselected_scores_flat, bins=100, alpha=0.3, label='Non-selected answers')
    plt.hist(selected_scores_flat, bins=100, alpha=0.3, label='Selected answers')

    colors = {'selected': 'peru', 'non-selected': 'steelblue'}
    linestyles = ['dashed', 'dashed', 'dotted', 'dotted', 'dotted']

    for idx, (stat_name, stat_func) in enumerate(statistics.items()):
        for group_idx, group in enumerate(['non-selected', 'selected']):
            scores = selected_scores_flat if group == 'selected' else nonselected_scores_flat
            stat_value = stat_func(scores)
            plt.axvline(stat_value, color=colors[group], linestyle=linestyles[idx], linewidth=1)
            # Stagger the labels vertically so the two groups do not overlap.
            y_pos = plt.ylim()[1] * (0.9 - (idx * 2 + group_idx) * 0.05)
            x_offset = plt.xlim()[1] * 0.01
            plt.text(stat_value + x_offset, y_pos, f"{stat_name}: {stat_value:.2f}", color=colors[group], ha='left',
                     fontsize=8)

    plt.legend(loc='best', bbox_to_anchor=(1, 1))
    ax = plt.gca()
    legend = ax.get_legend()
    # Metric summary is drawn just right of the axes; the subplot is shrunk
    # below to leave room for it.
    result_str = '\n'.join([f"{metric}: {result:.2f}%" for metric, result in results_perc.items()])
    plt.text(plt.xlim()[1] * 1.05, plt.ylim()[0] + (plt.ylim()[1] - plt.ylim()[0]) * 0.05, result_str, fontsize=8)
    plt.subplots_adjust(right=0.8)
    legend.set_bbox_to_anchor((1, 1))
    plt.title('Score distribution for selected and non-selected answers')
    plt.xlabel('Score')
    plt.ylabel('Frequency')
    plt.savefig(f'score_distribution_answers_{model_name}.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
    return results, selected_scores, nonselected_scores


# Iterate over model configurations
for config in model_configs:
    # Each entry supplies the constructor argument, a display name and the
    # checkpoint location for one evaluation run.
    results, selected_scores, nonselected_scores = evaluate_model(
        model_config=config['config'],
        model_name=config['name'],
        model_path=config['path'],
    )
    print(f"Results for {config['name']}: {results}")