# NarrativeFactScore / src/fact/narrativefactscore.py
# (JihyukKim — commit eaa3d8a, "Initial commit")
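# Suppress warnings that cannot be fixed from this code: the multiple-OpenMP-runtimes
# warning (https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md) and
# noise from the transformers package. Note that filterwarnings("ignore") silences
# *all* Python warnings process-wide.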
import warnings
warnings.filterwarnings("ignore")
import re
import torch
import torch.nn as nn
import traceback
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
from nltk import sent_tokenize
import logging
import openai
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
stop_after_delay, wait_random_exponential, stop_after_attempt)
from .utils import break_down2scenes
from .prompt import build_fact_prompt
from .openai_api import openai_api_response
logger = logging.getLogger(__name__)
class OpenAIEmbedding:
    """Thin wrapper around the OpenAI embeddings endpoint.

    Exposes an ``encode`` method taking a string or a list of strings, so it
    can be used where a SentenceTransformer-style encoder is expected.
    """

    def __init__(self, api_key, model="text-embedding-3-large"):
        """
        Args:
            api_key: OpenAI API key. Also assigned to the module-global
                ``openai.api_key`` (a process-wide side effect).
            model: Name of the embedding model to request.
        """
        self.api_key = api_key
        self.model = model
        openai.api_key = api_key

    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError, APIConnectionError)),
           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def encode(self, texts, **kwargs):
        """Embed one string or a list of strings.

        Args:
            texts: A single string (wrapped into a one-element list) or a list
                of strings.
            **kwargs: Accepted for call-compatibility; ignored.

        Returns:
            numpy array built from the ``"embedding"`` of each item in the
            response's ``"data"`` list (presumably one row per input, in input
            order — TODO confirm against the API docs).

        Raises:
            Errors from the OpenAI client. The transient error types listed in
            the ``retry`` decorator are retried (up to 10 attempts with random
            exponential backoff); by default tenacity raises ``RetryError``
            once attempts are exhausted.
        """
        if isinstance(texts, str):
            texts = [texts]
        try:
            response = openai.Embedding.create(
                model=self.model,
                input=texts,
            )
        except Exception as e:
            # Log and re-raise. The previous version returned None here, which
            # made the @retry decorator dead code (transient errors were never
            # retried). It also surfaced failures later as confusing errors
            # when None reached cos_sim.
            logger.error(f"Embedding API failed: {str(e)}")
            raise
        embeddings = [item["embedding"] for item in response["data"]]
        return np.array(embeddings)
class NarrativeFactScore():
    """Score how factually consistent summaries are with long narratives.

    Each hypothesis summary is split into units (see ``split_sent``). For each
    unit, the most similar source scene(s) and knowledge-graph triple(s) are
    retrieved by cosine similarity of OpenAI embeddings. ``GPTScore`` then
    judges the unit against them.
    """

    def __init__(self, model="gptscore", split_type="fast", checkpoint=None, api_key=None, model_id="gpt-4"):
        """Configure the scorer.

        Args:
            model: Metric backend; only ``"gptscore"`` is supported. (The old
                default ``"gpt-4o-mini"`` could never pass the check below, so
                constructing with defaults always raised.)
            split_type: Hypothesis-splitting strategy, see ``split_sent``.
            checkpoint: Stored as ``self.checkpoint``; not read in this class.
            api_key: OpenAI API key; also assigned to ``openai.api_key``.
            model_id: Chat model name handed to ``GPTScore`` and used by the
                ``"gpt"`` splitting strategy.

        Raises:
            ValueError: if ``model`` is not ``"gptscore"``.
        """
        # Validate before creating any clients so a bad argument has no side effects.
        if model != "gptscore":
            raise ValueError("NarrativeFactScore currently only supports GPTScore")
        self.sent_model = OpenAIEmbedding(api_key=api_key)
        self.split_type = split_type
        self.checkpoint = checkpoint
        self.api_key = api_key
        self.model_id = model_id
        openai.api_key = api_key
        self.metric = GPTScore(model=self.model_id, api_key=self.api_key)
        self.metric_function = self.metric.gpt_score

    def get_surrounding_sentences(self, sentence_array, ii):
        """Return sentence ``ii`` joined (with spaces) to its direct neighbours.

        The window is ``ii - 1 .. ii + 1``, clipped to the array bounds. The
        first sentence is therefore joined with the second, and the last with
        the one before it.

        Raises:
            IndexError: if ``ii`` is not a valid index into ``sentence_array``.
        """
        sentences = list(sentence_array)
        if not 0 <= ii < len(sentences):
            raise IndexError(f"sentence index {ii} out of range for {len(sentences)} sentences")
        # Slice end is exclusive: ii + 2 includes the following sentence (the
        # old middle-case slice [ii - 1 : ii + 1] silently dropped it).
        return " ".join(sentences[max(ii - 1, 0):ii + 2])

    def group_into_sections(self, sentence_array, num_sent):
        """Join consecutive runs of ``num_sent`` sentences into section strings.

        The final section may hold fewer than ``num_sent`` sentences.

        Raises:
            ValueError: if ``num_sent`` is smaller than 1.
        """
        if num_sent < 1:
            raise ValueError("num_sent must be at least 1")
        sentences = list(sentence_array)
        # Slice the sentence list, not the joined string (the old code sliced
        # characters out of one big string).
        return [" ".join(sentences[start:start + num_sent])
                for start in range(0, len(sentences), num_sent)]

    def split_sent(self, text):
        """Split ``text`` into units that are fact-checked individually.

        ``self.split_type`` selects the strategy:

        * ``"fast"``: split on ``.``
        * ``"fast_comma"``: split on ``.`` and ``,``
        * ``"gpt"``: ask the chat model (prompt ``./templates/atomic_fact.txt``)
          for atomic facts, one per reply line; lines are stripped.

        Fragments from the regex strategies are returned unstripped. Empty and
        whitespace-only fragments are dropped in every strategy, so no blank
        unit is sent to the LLM judge.

        Returns:
            list of strings, or None for an unrecognised ``split_type``.
        """
        if self.split_type == "fast":
            pieces = text.split('.')
        elif self.split_type == "fast_comma":
            pieces = re.split(r'[.,]', text)
        elif self.split_type == "gpt":
            prompt = build_fact_prompt(
                prompt_template='./templates/atomic_fact.txt',
                input_text_list=[text],
            )
            response = openai_api_response(prompt, model=self.model_id, api_key=self.api_key)
            pieces = [line.strip() for line in response.split('\n')]
        else:
            return None
        return [piece for piece in pieces if piece.strip()]

    @staticmethod
    def _top_indices(similarities, k):
        """Return indices of the ``k`` largest similarity values, best first."""
        similarities = torch.as_tensor(similarities)
        k = min(k, similarities.shape[0])
        return torch.topk(similarities, k).indices.tolist()

    def score_src_hyp_long(self, srcs, hyps, kgs, num_scenes=1, num_triples=1):
        """Score each hypothesis summary against its source document.

        Args:
            srcs: Source documents; each is split into scenes by
                ``break_down2scenes``.
            hyps: Hypothesis summaries, paired with ``srcs`` by position
                (``zip`` stops at the shorter of the two).
            kgs: Flat list of knowledge-graph triples, shared by all documents.
            num_scenes: How many most-similar scenes are sent to the metric
                per hypothesis unit; the best-scoring one is kept.
            num_triples: How many most-similar triples are joined with
                ``", "`` into the metric's knowledge context.

        Returns:
            Tuple ``(all_scores, all_scores_per_sent, all_relevant_scenes,
            all_summary_chunks, all_feedback_list)`` with one entry per
            document:
            - the mean unit score (NaN when the hypothesis yields no units);
            - the list of int unit scores;
            - the chosen scene per unit;
            - the hypothesis units;
            - the metric feedback per unit.

        Raises:
            ValueError: on an unrecognised ``split_type`` or invalid
                ``num_scenes`` / ``num_triples``.
        """
        if num_scenes < 1:
            raise ValueError("num_scenes must be at least 1")
        if num_triples < 0:
            raise ValueError("num_triples must be non-negative")

        all_scores = []
        all_scores_per_sent = []
        all_relevant_scenes = []
        all_summary_chunks = []
        all_feedback_list = []
        total_score = 0
        # kgs does not depend on the document, so embed it only once. Do it
        # lazily, so an empty ``srcs`` makes no API call.
        kg_embeddings = None

        for global_idx, (src, hyp) in enumerate(zip(tqdm(srcs), hyps)):
            # Split first so a bad split_type fails before any embedding call.
            hyp_units = self.split_sent(hyp)
            if hyp_units is None:
                raise ValueError(f"Unsupported split_type: {self.split_type!r}")

            src_scenes = break_down2scenes(src)
            scene_embeddings = self.sent_model.encode(src_scenes)
            if kg_embeddings is None:
                kg_embeddings = self.sent_model.encode(kgs)

            doc_scores = []
            relevant_scenes = []
            feedbacks = []
            for hyp_unit in hyp_units:
                hyp_embedding = self.sent_model.encode(hyp_unit)
                scene_sims = util.cos_sim(hyp_embedding, scene_embeddings)[0]
                triple_sims = util.cos_sim(hyp_embedding, kg_embeddings)[0]

                triple = ', '.join(f'{kgs[i]}' for i in self._top_indices(triple_sims, num_triples))
                candidates = [src_scenes[i] for i in self._top_indices(scene_sims, num_scenes)]

                metric_scores, metric_feedback = self.metric_function(
                    candidates, [hyp_unit] * len(candidates), triple)
                # Keep the candidate scene the metric scored highest.
                best = int(np.argmax(metric_scores))
                doc_scores.append(int(metric_scores[best]))
                relevant_scenes.append(candidates[best])
                feedbacks.append(metric_feedback[best])

            # np.mean of an empty list is NaN (numpy also emits a RuntimeWarning).
            doc_score = np.mean(doc_scores)
            all_scores_per_sent.append(doc_scores)
            all_scores.append(doc_score)
            all_relevant_scenes.append(relevant_scenes)
            all_summary_chunks.append(hyp_units)
            all_feedback_list.append(feedbacks)
            total_score += doc_score
            if global_idx % 100 == 99:
                print(f"Document mean {global_idx+1} Score: {total_score/(global_idx+1)} Score")
        return all_scores, all_scores_per_sent, all_relevant_scenes, all_summary_chunks, all_feedback_list
class GPTScore():
    """LLM judge that gives a binary support score to (source, target) pairs.

    Each pair is rendered with ``build_fact_prompt`` into the template at
    ``self.prompt``, together with shared knowledge-graph context. The prompt
    is sent to a chat model and the reply is mapped to 1.0 or 0.0.
    """

    def __init__(self, model="gpt-4o", api_key=None, prompt='./templates/fact_score_kg.txt'):
        """
        Args:
            model: Chat model name used for every request.
            api_key: OpenAI API key; also assigned to ``openai.api_key``.
            prompt: Prompt-template path given to ``build_fact_prompt``. The
                default is relative; presumably it resolves against the
                current working directory — TODO confirm in ``build_fact_prompt``.
        """
        self.max_length = 1024  # stored but not read anywhere in this class
        self.model = model
        self.api_key = api_key
        self.prompt = prompt
        openai.api_key = api_key

    # InvalidRequestError is deliberately absent from the retry list. It is
    # handled inside the method (so it never reached the decorator anyway),
    # and re-sending a rejected request cannot succeed.
    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError, APIConnectionError)),
           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def gpt_inference(self, prompt):
        """Send ``prompt`` as a single user message at temperature 0.

        Returns:
            The reply text of the first choice. Returns the string ``"1"``
            when the API raises ``InvalidRequestError``, so ``gpt_score``
            treats the pair as supported instead of failing.
        """
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=prompt_messages,
                temperature=0,
                api_key=self.api_key,
            )
        except InvalidRequestError:
            # Must be the *string* "1": gpt_score does a substring test
            # ('1' in reply), which raised TypeError when this returned int 1.
            logger.warning("InvalidRequestError from %s; treating reply as '1'", self.model)
            return "1"
        return response.choices[0].message.content

    def gpt_score(self, srcs, tgts, kgs, batch_size=4):
        """Judge each (source, target) pair with the chat model.

        Args:
            srcs: Source texts.
            tgts: Target texts, paired with ``srcs`` by position.
            kgs: Knowledge-graph context, inserted unchanged into every prompt.
            batch_size: Accepted for interface compatibility; unused (one
                request is sent per pair).

        Returns:
            ``(score_list, feedback_list)``. A pair scores ``1.0`` with empty
            feedback when the reply contains the character ``'1'``; otherwise
            it scores ``0.0`` and the full reply is kept as feedback.

        Raises:
            ValueError: if ``srcs`` and ``tgts`` have different lengths.
            RuntimeError: re-raised from inference, after logging the
                offending pair.
        """
        if len(srcs) != len(tgts):
            raise ValueError(f"srcs and tgts differ in length: {len(srcs)} != {len(tgts)}")
        score_list = []
        feedback_list = []
        for src, tgt in zip(srcs, tgts):
            prompt = build_fact_prompt(
                prompt_template=self.prompt,
                input_text_list=[src, kgs, tgt],
            )
            try:
                reply = self.gpt_inference(prompt)
            except RuntimeError:
                # The old handler printed undefined names (NameError) and then
                # called exit(0), reporting success on failure. Log the
                # context and propagate instead.
                logger.exception("GPT scoring failed\nsource: %s\ntarget: %s", src, tgt)
                raise
            # NOTE(review): any '1' anywhere in the reply counts as supported;
            # presumably the template asks for a bare 0/1 verdict — confirm.
            if '1' in reply:
                score_list.append(1.0)
                feedback_list.append('')
            else:
                score_list.append(0.0)
                feedback_list.append(reply)
        return score_list, feedback_list