# Spaces: Runtime error
import itertools | |
import re | |
import spacy | |
import json | |
import evaluate | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel | |
from utils import * | |
from celebbot import CelebBot | |
# When False, evaluate_system() also calls CelebBot.speech_to_text() before each
# question and CelebBot.text_to_speech() after each answer.
DEBUG = True

# Hugging Face checkpoint loaded with AutoModelForSeq2SeqLM to answer questions.
QA_MODEL_ID = "google/flan-t5-large"

# Hugging Face checkpoint loaded with AutoModel and handed to CelebBot.
SENTTR_MODEL_ID = "sentence-transformers/all-mpnet-base-v2"
def _whole_word_regex(text):
    """Compile a pattern matching *text* literally, only where it is not part of a longer word."""
    # (?<!\w)/(?!\w) behave like \b next to an alphanumeric character but, unlike
    # \b, still work when *text* starts or ends with punctuation (e.g. "Jones’").
    # re.escape stops characters such as "." or "(" in a name acting as regex syntax.
    return re.compile(rf"(?<!\w)({re.escape(text)})(?!\w)")


def _possessive_regex(text):
    """Compile a pattern for the possessive of *text*: "Obama’s", or "Jones’" when it ends in "s"."""
    # Accept both the typographic (’) and the plain ASCII (') apostrophe.
    suffix = "['’]" if text.endswith("s") else "['’]s"
    return re.compile(rf"(?<!\w)({re.escape(text)}{suffix})(?!\w)")


def _to_first_person(name, gender, knowledge):
    """Rewrite third-person *knowledge* about *name* into the celebrity's own voice.

    For gender "M" the ``he_regex``/``his_regex`` patterns from ``utils`` are
    replaced by "I"/"my"; for "F" the ``she_regex``/``her_regex`` patterns. Any
    other gender value skips the pronoun step. Then the possessive forms of the
    full and last name become "my" and the bare full and last name become "I".

    Raises:
        ValueError: if *name* is blank (an empty last name would match everywhere).
    """
    name_parts = name.split()
    if not name_parts:
        raise ValueError(f"blank celebrity name: {name!r}")
    last_name = name_parts[-1]

    # he_regex / his_regex / she_regex / her_regex come from `utils` (star import);
    # presumably subject / possessive pronoun patterns -- confirm there.
    if gender == "M":
        knowledge = re.sub(he_regex, "I", knowledge)
        knowledge = re.sub(his_regex, "my", knowledge)
    elif gender == "F":
        knowledge = re.sub(she_regex, "I", knowledge)
        knowledge = re.sub(her_regex, "my", knowledge)

    # Order matters: possessives before bare names (else "Obama’s" -> "I’s"), and
    # the full name before the last name (else "Barack Obama" -> "Barack I").
    replacements = (
        (_possessive_regex(name), "my"),
        (_possessive_regex(last_name), "my"),
        (_whole_word_regex(name), "I"),
        (_whole_word_regex(last_name), "I"),
    )
    for pattern, replacement in replacements:
        knowledge = pattern.sub(replacement, knowledge)
    return knowledge


def _report_metrics(predictions, references):
    """Print BLEU (max n-gram order 4), METEOR and ROUGE-L of *predictions* vs *references*."""
    bleu = evaluate.load("bleu").compute(
        predictions=predictions, references=references, max_order=4
    )
    print(f"BLEU: {round(bleu['bleu'], 2)}")
    meteor = evaluate.load("meteor").compute(
        predictions=predictions, references=references
    )
    print(f"METEOR: {round(meteor['meteor'], 2)}")
    rouge = evaluate.load("rouge").compute(
        predictions=predictions, references=references
    )
    print(f"ROUGE: {round(rouge['rougeL'], 2)}")


def evaluate_system():
    """Answer every question in ``data.json`` with CelebBot and score the answers.

    ``data.json`` maps each celebrity name to a record with "gender", "knowledge",
    "questions" and "answers" keys. "answers" is assumed to be parallel to
    "questions" (one reference entry per question) -- TODO confirm against the data.
    Every answer is written to ``predictions.txt`` (one per line), then BLEU,
    METEOR and ROUGE-L scores are printed.

    Raises:
        ValueError: if a record's "questions" and "answers" differ in length, since
            predictions could then no longer be matched with their references.
    """
    with open("data.json", encoding="utf-8") as json_file:
        celeb_data = json.load(json_file)

    qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_ID)
    qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_ID)
    senttr_tokenizer = AutoTokenizer.from_pretrained(SENTTR_MODEL_ID)
    senttr_model = AutoModel.from_pretrained(SENTTR_MODEL_ID)
    # The spaCy pipeline is the same for every celebrity, so load it only once.
    spacy_model = spacy.load("en_core_web_sm")

    # Built side by side so predictions[i] is always scored against references[i].
    predictions = []
    references = []
    for name, record in celeb_data.items():
        knowledge = _to_first_person(name, record["gender"], record["knowledge"])
        knowledge_sents = [sent.text.strip() for sent in spacy_model(knowledge).sents]
        ai = CelebBot(
            name,
            qa_tokenizer,
            qa_model,
            senttr_tokenizer,
            senttr_model,
            spacy_model,
            knowledge_sents,
        )

        questions = record["questions"]
        answers = record["answers"]
        if len(questions) != len(answers):
            raise ValueError(
                f"{name!r}: {len(questions)} questions but {len(answers)} answers"
            )

        for question, answer in zip(questions, answers):
            if not DEBUG:
                # NOTE(review): speech_to_text() presumably sets ai.text, but the
                # dataset question overwrites it just below; confirm it is wanted.
                ai.speech_to_text()
            ai.text = question
            if ai.text == "":
                # Drop the paired reference too so both lists stay aligned.
                continue
            print("me --> ", ai.text)
            predictions.append(ai.question_answer())
            references.append(answer)
            if not DEBUG:
                ai.text_to_speech()
            ai.text = ""

    with open("predictions.txt", "w", encoding="utf-8") as predictions_file:
        predictions_file.writelines(prediction + "\n" for prediction in predictions)

    if not predictions:
        # Nothing to score; skip loading the metrics.
        print("No predictions to score.")
        return
    _report_metrics(predictions, references)
if __name__ == "__main__":
    # Run the whole evaluation only when executed as a script, never on import.
    evaluate_system()