import os
import re
from typing import List

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import spacy

# Set environment variables for writable directories
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'

# Initialize FastAPI app
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust the origins as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the spaCy models once at startup
nlp = spacy.load("en_core_web_sm")
nlp_coref = spacy.load("en_coreference_web_trf")

# Pronouns that may be replaced by a main character's name
REPLACE_PRONOUNS = {"he", "his", "she", "her", "they", "He", "His", "She", "Her", "They"}


class CorefRequest(BaseModel):
    text: str
    main_characters: List[str]


def extract_core_name(mention_text, main_characters):
    """Return the main character contained in a mention, or its last word as a fallback."""
    words = mention_text.split()
    for character in main_characters:
        if character.lower() in mention_text.lower():
            return character
    return words[-1]


def calculate_pronoun_density(text):
    """Return (replaceable pronouns per PERSON entity, number of PERSON entities)."""
    doc = nlp(text)
    pronoun_count = sum(1 for token in doc if token.pos_ == "PRON" and token.text in REPLACE_PRONOUNS)
    named_entity_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
    return pronoun_count / max(named_entity_count, 1), named_entity_count


def resolve_coreferences_across_text(text, main_characters):
    """Replace pronouns that corefer with a main character by that character's name."""
    doc = nlp_coref(text)
    coref_mapping = {}

    # Map each replaceable pronoun token to the name of the character its cluster refers to
    for key, cluster in doc.spans.items():
        if re.match(r"coref_clusters_\d+", key):
            main_mention = cluster[0]
            core_name = extract_core_name(main_mention.text, main_characters)
            if core_name in main_characters:
                for mention in cluster:
                    for token in mention:
                        if token.text in REPLACE_PRONOUNS:
                            core_name_final = core_name if token.text.istitle() else core_name.lower()
                            coref_mapping[token.i] = core_name_final

    # Rebuild the text sentence by sentence, naming each character at most once per sentence
    resolved_tokens = []
    current_sentence_characters = set()
    current_sentence = []

    for i, token in enumerate(doc):
        if token.is_sent_start and current_sentence:
            resolved_tokens.extend(current_sentence)
            current_sentence_characters.clear()
            current_sentence = []

        if i in coref_mapping:
            core_name = coref_mapping[i]
            if core_name not in current_sentence_characters and core_name.lower() not in [t.lower() for t in current_sentence]:
                current_sentence.append(core_name)
                current_sentence_characters.add(core_name)
            else:
                # The character is already named in this sentence; keep the original pronoun
                current_sentence.append(token.text)
        else:
            current_sentence.append(token.text)

    resolved_tokens.extend(current_sentence)
    resolved_text = " ".join(resolved_tokens)
    return remove_consecutive_duplicate_phrases(resolved_text)


def remove_consecutive_duplicate_phrases(text):
    """Collapse immediately repeated word sequences, e.g. 'Anna Anna went' -> 'Anna went'."""
    words = text.split()
    i = 0
    while i < len(words) - 1:
        j = i + 1
        while j < len(words):
            if words[i:j] == words[j:j + (j - i)]:
                del words[j:j + (j - i)]
            else:
                j += 1
        i += 1
    return " ".join(words)


def process_text(text, main_characters):
    # Only run the expensive coreference pipeline when there are pronouns to replace
    pronoun_density, _named_entity_count = calculate_pronoun_density(text)
    if pronoun_density > 0:
        return resolve_coreferences_across_text(text, main_characters)
    return text


@app.post("/predict")
async def predict(coref_request: CorefRequest):
    resolved_text = process_text(coref_request.text, coref_request.main_characters)
    if resolved_text:
        return {"resolved_text": resolved_text}
    raise HTTPException(status_code=400, detail="Coreference resolution failed")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
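
# --- Example usage (illustrative) ---
# A minimal sketch of a client call, assuming the server is running locally on the
# default port 7860. The resolved output shown is hypothetical and depends on the
# clusters produced by en_coreference_web_trf:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={
#           "text": "Anna met Tom at the station. She waved at him.",
#           "main_characters": ["Anna", "Tom"],
#       },
#   )
#   print(resp.json())
#   # e.g. {"resolved_text": "Anna met Tom at the station . Anna waved at him ."}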