import os
from flask import Flask, request, jsonify
import spacy
# Set environment variables for writable directories
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'
# Initialize Flask app
app = Flask(__name__)
# Load the spaCy models once
nlp = spacy.load("en_core_web_sm")
nlp_coref = spacy.load("en_coreference_web_trf")
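# NOTE: "en_coreference_web_trf" is spaCy's experimental coreference pipeline and is
# assumed to be installed separately (it depends on spacy-experimental).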
REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}
def extract_core_name(mention_text, main_characters):
    """Return the main character whose name appears in the mention, else its last word."""
    words = mention_text.split()
    for character in main_characters:
        if character.lower() in mention_text.lower():
            return character
    return words[-1]
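# Example (hypothetical values): extract_core_name("young Mr. Darcy", ["Darcy", "Bingley"])
# returns "Darcy", since "darcy" occurs in the lowercased mention text.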
def calculate_pronoun_density(text):
    """Return (pronouns per PERSON entity, PERSON entity count) for the text."""
    doc = nlp(text)
    pronoun_count = sum(1 for token in doc if token.pos_ == "PRON" and token.text in REPLACE_PRONOUNS)
    named_entity_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
    return pronoun_count / max(named_entity_count, 1), named_entity_count
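# Illustrative only (actual counts depend on model output): for "John slept. He snored.",
# en_core_web_sm would typically tag one target pronoun ("He") and one PERSON entity
# ("John"), giving a return value of (1.0, 1).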
def resolve_coreferences_across_text(text, main_characters):
    """Replace third-person pronouns with the main character they refer to."""
    doc = nlp_coref(text)

    # Map token indices to the character name that should replace them.
    coref_mapping = {}
    for key, cluster in doc.spans.items():
        if key.startswith("coref_clusters"):
            main_mention = cluster[0]
            core_name = extract_core_name(main_mention.text, main_characters)
            if core_name in main_characters:
                for mention in cluster:
                    for token in mention:
                        if token.text in REPLACE_PRONOUNS:
                            # Preserve the pronoun's capitalization.
                            core_name_final = core_name if token.text.istitle() else core_name.lower()
                            coref_mapping[token.i] = core_name_final

    # Rebuild the text sentence by sentence, avoiding duplicate names within a sentence.
    resolved_tokens = []
    current_sentence_characters = set()
    current_sentence = []
    for i, token in enumerate(doc):
        if token.is_sent_start and current_sentence:
            resolved_tokens.extend(current_sentence)
            current_sentence_characters.clear()
            current_sentence = []
        if i in coref_mapping:
            core_name = coref_mapping[i]
            if core_name not in current_sentence_characters and core_name.lower() not in [t.lower() for t in current_sentence]:
                current_sentence.append(core_name)
                current_sentence_characters.add(core_name)
            else:
                current_sentence.append(token.text)
        else:
            current_sentence.append(token.text)
    resolved_tokens.extend(current_sentence)

    # Note: joining on spaces does not restore original punctuation spacing.
    resolved_text = " ".join(resolved_tokens)
    return remove_consecutive_duplicate_phrases(resolved_text)
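# Illustrative (hypothetical coref output, not guaranteed by the model):
#   resolve_coreferences_across_text("John woke up. He made coffee.", ["John"])
#   would ideally return "John woke up. John made coffee."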
def remove_consecutive_duplicate_phrases(text):
    """Collapse word sequences that repeat immediately after themselves."""
    words = text.split()
    i = 0
    while i < len(words) - 1:
        j = i + 1
        while j < len(words):
            # If the phrase words[i:j] repeats immediately after itself, drop the repeat.
            if words[i:j] == words[j:j + (j - i)]:
                del words[j:j + (j - i)]
            else:
                j += 1
        i += 1
    return " ".join(words)
def process_text(text, main_characters):
    # Only run the expensive transformer-based coreference pass when there are pronouns to resolve.
    pronoun_density, _ = calculate_pronoun_density(text)
    if pronoun_density > 0:
        return resolve_coreferences_across_text(text, main_characters)
    return text
# API endpoint to handle coreference resolution
@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json(silent=True) or {}
    text = data.get('text')
    main_characters = data.get('main_characters')
    if not text or not main_characters:
        return jsonify({"error": "Both 'text' and 'main_characters' are required."}), 400
    resolved_text = process_text(text, main_characters.split(","))
    return jsonify({"resolved_text": resolved_text})
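# Example request (assuming the server is reachable on localhost:7860):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "John woke up. He made coffee.", "main_characters": "John,Mary"}'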
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))