license: mit
datasets:
  - ZachW/GPT-BookSum
language:
  - en
metrics:
  - accuracy
base_model:
  - FacebookAI/roberta-base
pipeline_tag: zero-shot-classification
tags:
  - pacing
  - concreteness
  - text-evaluation

Pacing-Judge

[project page]

Overview

This is the concreteness evaluator developed in the paper "Improving Pacing in Long-Form Story Planning" (EMNLP 2023). Given two passages, it predicts which one is more concrete.

Quick Start

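For simple usage, join a pair of texts (text_ex_1, text_ex_2) with " <sep> " as the separator and pass the result to the model. The output indicates whether the first or the second text is more concrete. Note that this scores a single ordering of the pair; the Ranker in the Usage section averages both orderings.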

import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "ZachW/pacing-judge"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_ex_1 = "The Duke then focused on securing his power and looking to future threats. The Duke eventually turned his attention to acquiring Tuscany but struggled."
text_ex_2 = "Lord Bacon mentioned his book \"The History of Henry VII,\" in the conversation noting that King Charles had conquered Naples without resistance, implying that the conquest was like a dream."
# Join the pair with " <sep> " and encode it as a single input sequence.
inputs = tokenizer(text_ex_1 + " <sep> " + text_ex_2, return_tensors="pt")
outputs = model(**inputs)
# Probability of class 0; a value above 0.5 means the second text is more concrete.
prob_class0 = F.softmax(outputs.logits, dim=1)[0, 0].item()
output = int(prob_class0 > 0.5)
print(f"Output Binary = {output}")
if output:
    print("The second text is more concrete.")
else:
    print("The first text is more concrete.")
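The Quick Start thresholds the softmax probability of class 0 at 0.5 and reads a value above 0.5 as "the second text is more concrete". For a two-class head this is the same as asking whether logit 0 exceeds logit 1, because the class-0 probability equals 1 / (1 + exp(l1 - l0)). The minimal sketch below checks that equivalence in plain Python on made-up logit values; no model is loaded, and `class0_probability` and `second_is_more_concrete` are hypothetical helper names, not part of the released code.

```python
import math

def class0_probability(logit0: float, logit1: float) -> float:
    """Two-class softmax probability of class 0."""
    # softmax([l0, l1])[0] == 1 / (1 + exp(l1 - l0))
    return 1.0 / (1.0 + math.exp(logit1 - logit0))

def second_is_more_concrete(logit0: float, logit1: float) -> int:
    """Mirror of the Quick Start rule: 1 if P(class 0) > 0.5, else 0."""
    return int(class0_probability(logit0, logit1) > 0.5)

# Hypothetical logits, not real model outputs.
for l0, l1 in [(1.2, -0.3), (-0.5, 0.7), (0.0, 0.0)]:
    p0 = class0_probability(l0, l1)
    # Thresholding p0 at 0.5 agrees with comparing the two logits directly.
    assert second_is_more_concrete(l0, l1) == int(l0 > l1)
    print(f"logits=({l0}, {l1})  P(class 0)={p0:.4f}  output={second_is_more_concrete(l0, l1)}")
```

In practice this means `outputs.logits.argmax(dim=1)` would give the same label as the threshold, except for an exact tie, which the 0.5 rule resolves as "first text".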

Usage

We provide the Ranker class, which makes pairwise comparison fair (independent of the order of the two inputs, because it scores both orderings and averages them) and supports ranking a list of candidates. We recommend using the model through the Ranker.
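The Ranker's order-independence rests on simple arithmetic. If p_fwd is the class-0 probability for "t1 <sep> t2" and p_bwd is the class-0 probability for "t2 <sep> t1", then `run_model` averages p_fwd with 1 - p_bwd. The minimal sketch below reproduces that step on hypothetical probabilities (no model involved; `order_averaged_score` is our illustrative name) and shows two consequences: swapping the texts turns a score s into 1 - s, and a text compared with itself scores exactly 0.5, up to floating-point rounding.

```python
def order_averaged_score(p_fwd: float, p_bwd: float) -> float:
    """Average the forward estimate with the flipped backward estimate.

    p_fwd: P(class 0) for "t1 <sep> t2", i.e. P(t2 is more concrete).
    p_bwd: P(class 0) for "t2 <sep> t1", i.e. P(t1 is more concrete).
    Returns an order-independent estimate of P(t2 is more concrete than t1).
    """
    return (p_fwd + (1 - p_bwd)) / 2

# Hypothetical probabilities for one pair, scored in both orders.
p_fwd, p_bwd = 0.75, 0.375
s_12 = order_averaged_score(p_fwd, p_bwd)  # scoring (t1, t2)
s_21 = order_averaged_score(p_bwd, p_fwd)  # scoring (t2, t1) swaps the roles
print(s_12, s_21)  # 0.6875 0.3125

# Identical inputs give identical probabilities, so self-comparison yields 0.5.
print(order_averaged_score(0.75, 0.75))  # 0.5
```

Because s_12 + s_21 = 1, `compare(t1, t2)` and `compare(t2, t1)` return opposite labels for the same underlying decision.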

import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class Ranker:
    def __init__(self):
        print("*** Loading Model from Huggingface ***")
        model_name = "ZachW/pacing-judge"
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def compare(self, t1, t2):
        # Hard decision on top of compare_logits (swapping t1 and t2 flips the label).
        score = self.compare_logits(t1, t2)
        if score < 0.5:
            return 0 # first is more concrete
        else:
            return 1 # second is more concrete

    def compare_logits(self, t1, t2):
        # Encode both orderings in one batch so the score is order-independent.
        # Returns a score in [0, 1]: the probability that t2 is more concrete than t1.
        text_pair = [t1 + ' <sep> ' + t2, t2 + ' <sep> ' + t1]
        pair_dataset = self.tokenizer(text_pair, padding=True, truncation=True, return_tensors="pt")
        score = self.run_model(pair_dataset)
        return score

    def run_model(self, dataset):
        outputs = self.model(**dataset)
        # Class-0 probability per row, i.e. P(second text is more concrete).
        scores = F.softmax(outputs.logits, dim=1)[:, 0].squeeze(-1).detach().cpu().numpy()
        # Row 1 has the texts swapped, so flip it before averaging.
        aver_score = (scores[0] + (1 - scores[1]))/2
        return aver_score

    def rank(self, texts_list): # input a list of texts
        # Quicksort with pairwise comparisons; returns texts from most to least concrete.
        def quicksort(arr):
            if len(arr) <= 1:
                return arr
            pivot = arr[0]
            less = []     # less concrete than the pivot
            greater = []  # more concrete than the pivot
            for t in arr[1:]:
                if self.compare(pivot, t) == 0:
                    less.append(t)
                else:
                    greater.append(t)
            return quicksort(greater) + [pivot] + quicksort(less)
        return quicksort(texts_list)

    def rank_idx(self, texts_list): # input a list of texts
        # Same as rank, but returns indices into texts_list (most to least concrete).
        def quicksort(arr):
            if len(arr) <= 1:
                return arr
            pivot = arr[0]
            less = []
            greater = []
            for t in arr[1:]:
                if self.compare(texts_list[pivot], texts_list[t]) == 0:
                    less.append(t)
                else:
                    greater.append(t)
            return quicksort(greater) + [pivot] + quicksort(less)
        return quicksort(list(range(len(texts_list))))

    def rank_idx_conpletely(self, texts_list):
        # Exhaustive ranking; returns indices from most to least concrete.
        sorted_indices, _ = self.rank_idx_conpletely_wlogits(texts_list)
        return sorted_indices

    def rank_idx_conpletely_wlogits(self, texts_list, logger=None):
        # Compares every pair once: n + n*(n-1)/2 calls to compare_logits.
        # scores[i][j] = probability that text j is more concrete than text i.
        # (logger is accepted but not used.)
        n = len(texts_list)
        scores = [[0] * n for _ in range(n)]
        self_score = [0] * n
        for i in range(n):
            scores[i][i] = self.compare_logits(texts_list[i], texts_list[i])
            self_score[i] = scores[i][i]
            for j in range(i + 1, n):
                scores[i][j] = self.compare_logits(texts_list[i], texts_list[j])
                # compare_logits(b, a) equals 1 - compare_logits(a, b) up to rounding.
                scores[j][i] = 1 - scores[i][j]
        # Average score per text: smaller means more concrete.
        average_score = [sum(s) / len(s) for s in scores]
        # Correct each average by the text's self-comparison score.
        output_score = [a + 0.5 - s for a, s in zip(average_score, self_score)]
        # Indices sorted from most to least concrete, plus the per-text scores.
        sorted_indices = sorted(range(n), key=lambda x: output_score[x])
        return sorted_indices, output_score

    def compare_w_neighbors(self, t, cand):
        # Mean compare_logits(t, c) over the candidates; higher means the
        # candidates are, on average, more concrete than t.
        score = 0.0
        for c in cand:
            score += self.compare_logits(t, c)
        score /= len(cand)
        return score

text_ex_1 = "The Duke then focused on securing his power and looking to future threats. The Duke eventually turned his attention to acquiring Tuscany but struggled."
text_ex_2 = "Lord Bacon mentioned his book \"The History of Henry VII,\" in the conversation noting that King Charles had conquered Naples without resistance, implying that the conquest was like a dream."

ranker = Ranker()
output = ranker.compare(text_ex_1, text_ex_2) # order-independent: compare(text_ex_2, text_ex_1) returns the opposite label for the same decision
print(f"Output Binary = {output}")
if output:
    print("The second text is more concrete.")
else:
    print("The first text is more concrete.")

# Order-averaged score in [0, 1]: the probability that text_ex_2 is more concrete.
output_logits = ranker.compare_logits(text_ex_1, text_ex_2)
print(f"Output Logits = {output_logits:.4f}")
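The Ranker provides two ranking strategies. `rank` and `rank_idx` run quicksort on hard 0/1 decisions from `compare`, which needs O(n log n) comparisons on average, while `rank_idx_conpletely` and `rank_idx_conpletely_wlogits` call `compare_logits` on every pair plus each text against itself (n + n(n-1)/2 calls) and sort by the mean score. The sketch below re-implements both on a pluggable scorer so they can be tried without downloading the model. `toy_score`, `TOY_CONCRETENESS`, `rank_quicksort`, and `rank_exhaustive` are hypothetical names for illustration only; passing `ranker.compare_logits` as the scorer should give the model-based behavior.

```python
from typing import Callable, List, Sequence

# Hypothetical stand-in for Ranker.compare_logits(a, b): the probability that b is
# more concrete than a. Each toy "text" simply carries a made-up concreteness value.
TOY_CONCRETENESS = {"vague summary": 0.1, "some detail": 0.5, "vivid scene": 0.9}

def toy_score(a: str, b: str) -> float:
    # 0.5 when a == b, and toy_score(b, a) == 1 - toy_score(a, b) up to rounding.
    return 0.5 + (TOY_CONCRETENESS[b] - TOY_CONCRETENESS[a]) / 2

Scorer = Callable[[str, str], float]

def rank_quicksort(texts: Sequence[str], score: Scorer) -> List[str]:
    """Most concrete first, using hard pairwise decisions (mirrors Ranker.rank)."""
    if len(texts) <= 1:
        return list(texts)
    pivot, rest = texts[0], texts[1:]
    less, greater = [], []  # less / more concrete than the pivot
    for t in rest:
        # Same rule as Ranker.compare: a score below 0.5 means the pivot wins.
        if score(pivot, t) < 0.5:
            less.append(t)
        else:
            greater.append(t)
    return rank_quicksort(greater, score) + [pivot] + rank_quicksort(less, score)

def rank_exhaustive(texts: Sequence[str], score: Scorer) -> List[int]:
    """Indices, most concrete first, from all pairwise scores (mirrors rank_idx_conpletely)."""
    n = len(texts)
    scores = [[0.0] * n for _ in range(n)]
    for i in range(n):
        scores[i][i] = score(texts[i], texts[i])
        for j in range(i + 1, n):
            scores[i][j] = score(texts[i], texts[j])
            scores[j][i] = 1 - scores[i][j]
    # scores[i][j] is P(text j more concrete than text i): a low row mean means
    # text i beats most others. The self-score correction mirrors the Ranker.
    output = [sum(scores[i]) / n + 0.5 - scores[i][i] for i in range(n)]
    return sorted(range(n), key=lambda i: output[i])

texts = ["some detail", "vague summary", "vivid scene"]
print(rank_quicksort(texts, toy_score))   # ['vivid scene', 'some detail', 'vague summary']
print(rank_exhaustive(texts, toy_score))  # [2, 0, 1]
```

The exhaustive variant costs more model calls, but because it aggregates the soft scores of all pairs, a single noisy comparison is likely to affect the final order less than it would inside quicksort.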

For more details on using the evaluator (e.g., for pacing planning and control in generation) and on the training process, please refer to our paper!