"""Builds data/training_data.json: cleans question/answer/page text from
Jeopardy, Wikipedia, and miscellaneous QA sources, and tags each entry with
categories predicted by a pre-trained classifier."""

import json
import re

import numpy as np
from bs4 import BeautifulSoup
from tqdm import tqdm

import question_categorizer as qc
# Not referenced directly, but kept: load_model may need the class
# importable when it deserializes the saved model.
from question_categorizer import TextClassificationModel

qc_model = qc.TextClassificationModel.load_model("models/categorizer")
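
# Assumption (inferred from usage below, not from the model's docs):
# qc_model.predict returns a (1, len(categories)) score array, and a
# score >= 1.5 marks that category as predicted.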

# Label names in the index order used by the categorizer's predictions.
categories = ['Geography', 'Religion', 'Philosophy', 'Trash', 'Mythology',
              'Literature', 'Science', 'Social Science', 'History',
              'Current Events', 'Fine Arts']


def remove_newline(string):
    """Collapse runs of newlines into single spaces."""
    return re.sub(r'\n+', ' ', string)


def clean_text(text, answer):
    """Strip markup and stray characters, and blank out the answer so it
    cannot leak into the training text."""
    # Drop HTML tags, turn questions into statements, and keep only
    # letters, periods, whitespace, and hyphens.
    text = re.sub(r'<.*?>', '', text)
    text = text.replace('?', '.')
    text = re.sub(r'[^a-zA-Z.\s-]', '', text)

    try:
        # Normalize the answer (e.g. "Foo_(bar)" -> "Foo ") before removal.
        processed_answer = answer.replace('_', ' ')
        processed_answer = re.sub(r'\([^)]*\)', '', processed_answer)
        text = re.sub(re.escape(processed_answer), '', text, flags=re.IGNORECASE)
    except Exception as e:
        print("An error occurred during text cleaning:", e)
        print("Text:", text)
        print("Answer:", answer)

    # Collapse the whitespace left behind by the removals above.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
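
# Illustrative example (comment only, not executed):
#   clean_text("George Washington was the first president?", "George_Washington")
#   -> "was the first president."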


def process_data():
    # No Jeopardy source files are loaded here; this list stays empty
    # unless populated elsewhere.
    jeopardy_data = []

    # Input files under data/ to merge into the training set.
    wiki_files = []
    question_files = ["qadata.json"]

    wiki_data = []
    question_data = []

    for file_path in wiki_files:
        with open('data/' + file_path, "r") as f:
            wiki_data.extend(json.load(f))

    for file_path in question_files:
        with open('data/' + file_path, "r") as f:
            question_data.extend(json.load(f))
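
    # Expected record shapes (inferred from the loops below):
    #   jeopardy entries: {"question": ..., "answer": ...}
    #   wiki entries:     {"page": ..., "text": ...}
    #   question entries: {"text": ..., "answer": ...}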

    # The output file is opened up front; the JSON is written once at the end.
    with open("data/training_data.json", "w") as f:
        training_data = []

        print("Processing Jeopardy data...")
        for entry in tqdm(jeopardy_data):
            question = entry["question"]
            answer = str(entry["answer"])

            # Keep only the top-level text nodes of the question markup.
            soup = BeautifulSoup(question, 'html.parser')
            clean_question = ''.join(soup.find_all(string=True, recursive=False))

            question_category = []

            prediction = qc_model.predict(question)
            # argwhere returns [row, column] pairs; the column indices are
            # the predicted category ids, so select them with [:, 1].
            predictions = np.argwhere(prediction >= 1.5)[:, 1]

            for prediction_ind in predictions:
                question_category.append(categories[prediction_ind])

            # Every entry also belongs to the catch-all 'ALL' bucket.
            question_category.append('ALL')

            training_entry = {
                "text": clean_question,
                "answer": answer,
                "category": question_category
            }

            training_data.append(training_entry)

        print("Processing Wikipedia data...")
        for entry in tqdm(wiki_data):
            page = str(entry["page"])
            text = entry["text"]

            if text == "":
                continue

            text = remove_newline(text)
            text = clean_text(text, page)

            question_category = []

            prediction = qc_model.predict(text)
            predictions = np.argwhere(prediction >= 1.5)[:, 1]

            for prediction_ind in predictions:
                question_category.append(categories[prediction_ind])

            question_category.append('ALL')

            training_entry = {
                "text": text,
                "answer": page,
                "category": question_category
            }

            training_data.append(training_entry)

        print("Processing Misc data...")
        for entry in tqdm(question_data):
            answer = str(entry["answer"])
            text = entry["text"]

            if text == "" or answer == "":
                continue

            text = remove_newline(text)
            text = clean_text(text, answer)

            question_category = []

            try:
                prediction = qc_model.predict(text)
                predictions = np.argwhere(prediction >= 1.5)[:, 1]
            except Exception as e:
                # Log the offending entry and skip it rather than abort.
                print("Prediction failed:", e)
                print("answer: " + str(answer))
                print("text: " + str(text))
                continue

            for prediction_ind in predictions:
                question_category.append(categories[prediction_ind])

            question_category.append('ALL')

            training_entry = {
                "text": text,
                "answer": answer,
                "category": question_category
            }

            training_data.append(training_entry)

        json.dump(training_data, f, indent=4)


if __name__ == "__main__":
    process_data()