|
import itertools |
|
import re |
|
import spacy |
|
import json |
|
import evaluate |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel |
|
from unlimiformer import Unlimiformer, UnlimiformerArguments |
|
import torch |
|
|
|
from utils import * |
|
from celebbot import CelebBot |
|
|
|
# Model identifiers: the QA generator and the sentence-embedding retriever.
QA_MODEL_ID = "google/flan-t5-xl"
SENTTR_MODEL_ID = "sentence-transformers/all-mpnet-base-v2"

# Celebrities evaluated in this run; each must have an entry in data.json.
celeb_names = ["Cate Blanchett", "David Beckham", "Emma Watson", "Lady Gaga", "Madonna", "Mark Zuckerberg"]

USE_UNLIMIFORMER = True  # wrap the QA model for long-context attention
TOP_K = 8                # number of knowledge sentences retrieved per question

celeb_data = get_celeb_data("data.json")

# Flatten the per-celebrity answer lists into one reference list, in the
# same order the questions are asked below.
references = list(itertools.chain.from_iterable(
    entry["answers"] for person, entry in celeb_data.items() if person in celeb_names
))
predictions = []
|
|
|
# All inference runs on CPU in this script.
device = "cpu"

QA_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_ID)
QA_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_ID)

if USE_UNLIMIFORMER:
    # Convert the seq2seq model in place using the library's default
    # hyper-parameters; only the tokenizer is supplied explicitly.
    defaults = UnlimiformerArguments()
    unlimiformer_kwargs = dict(
        layer_begin=defaults.layer_begin,
        layer_end=defaults.layer_end,
        unlimiformer_head_num=defaults.unlimiformer_head_num,
        exclude_attention=defaults.unlimiformer_exclude,
        chunk_overlap=defaults.unlimiformer_chunk_overlap,
        model_encoder_max_len=defaults.unlimiformer_chunk_size,
        verbose=defaults.unlimiformer_verbose,
        tokenizer=QA_tokenizer,
        unlimiformer_training=defaults.unlimiformer_training,
        use_datastore=defaults.use_datastore,
        flat_index=defaults.flat_index,
        test_datastore=defaults.test_datastore,
        reconstruct_embeddings=defaults.reconstruct_embeddings,
        gpu_datastore=defaults.gpu_datastore,
        gpu_index=defaults.gpu_index,
    )
    QA_model = Unlimiformer.convert_model(QA_model, **unlimiformer_kwargs).to(device)
else:
    QA_model = QA_model.to(device)

# Sentence-transformer encoder used for knowledge-sentence retrieval.
sentTr_tokenizer = AutoTokenizer.from_pretrained(SENTTR_MODEL_ID)
sentTr_model = AutoModel.from_pretrained(SENTTR_MODEL_ID).to(device)
|
|
|
# Load the spaCy pipeline ONCE before the loop. The original code called
# spacy.load("en_core_web_lg") inside the per-celebrity loop, re-loading a
# large model on every iteration for no benefit — the same pipeline object
# serves every celebrity.
spacy_model = spacy.load("en_core_web_lg")

for celeb_name in celeb_names:
    gender = celeb_data[celeb_name]["gender"]

    # Britannica biography slugs are usually "First-Last"; a few pages use
    # a disambiguated slug and are special-cased here.
    if celeb_name == "Madonna":
        name = "Madonna-American-singer-and-actress"
    elif celeb_name == "Anne Hathaway":
        name = "Anne-Hathaway-American-actress"
    else:
        name = "-".join(celeb_name.split(" "))
    knowledge = get_article(f"https://www.britannica.com/biography/{name}")

    # Sentence-split the scraped article; these sentences form the
    # retrieval pool the bot draws on when answering.
    knowledge_sents = [sent.text.strip() for sent in spacy_model(knowledge).sents]

    ai = CelebBot(celeb_name, gender, QA_tokenizer, QA_model, sentTr_tokenizer,
                  sentTr_model, spacy_model, knowledge_sents, top_k=TOP_K)

    # Answer every question for this celebrity; predictions are appended in
    # the same order as the flattened `references` list built above.
    for q in celeb_data[celeb_name]["questions"]:
        ai.text = q
        response = ai.question_answer()
        print("response:", response)
        predictions.append(response)
|
|
|
# Persist the raw predictions, one per line. Use a context manager so the
# file is closed even if a write fails (the original open()/close() pair
# leaked the handle on any exception), and pin the encoding to UTF-8 so the
# output does not depend on the platform default.
with open("predictions.txt", "w", encoding="utf-8") as pred_file:
    for prediction in predictions:
        pred_file.write(prediction + "\n")
|
|
|
# Score the predictions against the references with four standard metrics.
# BLEU: corpus-level n-gram precision, up to 4-grams.
bleu = evaluate.load("bleu")
results = bleu.compute(predictions=predictions, references=references, max_order=4)
print("BLEU:", round(results["bleu"], 2))

# METEOR: unigram matching with stemming/synonym support.
meteor = evaluate.load("meteor")
results = meteor.compute(predictions=predictions, references=references)
print("METEOR:", round(results["meteor"], 2))

# ROUGE-L: longest-common-subsequence overlap.
rouge = evaluate.load("rouge")
results = rouge.compute(predictions=predictions, references=references)
print("ROUGE:", round(results["rougeL"], 2))

# BERTScore: contextual-embedding similarity, rescaled against the
# English baseline; report the mean F1 over all prediction/reference pairs.
bertscore = evaluate.load("bertscore")
results = bertscore.compute(predictions=predictions, references=references,
                            rescale_with_baseline=True, lang="en")
f1_scores = results["f1"]
mean_f1 = sum(f1_scores) / len(f1_scores)
print("F1:", round(mean_f1, 2))