File size: 3,753 Bytes
3cef660
36069fa
 
3cef660
 
 
 
 
 
 
36069fa
 
3cef660
 
 
 
 
 
 
36069fa
 
 
 
3cef660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36069fa
 
 
 
 
 
3cef660
 
36069fa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import spacy
import re

# Redirect the transformers cache and matplotlib config dir to /tmp, a
# writable location; this must happen before the models below are loaded.
os.environ.update(
    {
        'TRANSFORMERS_CACHE': '/tmp/transformers_cache',
        'MPLCONFIGDIR': '/tmp/.matplotlib',
    }
)

# ASGI application; routes are registered on it further down the module.
app = FastAPI()

# Both spaCy pipelines are loaded once at import time and reused by every
# request: the small English pipeline (used for POS tags and PERSON entities)
# and the coreference pipeline.
nlp = spacy.load("en_core_web_sm")
nlp_coref = spacy.load("en_coreference_web_trf")

# Personal pronouns eligible for replacement, in lower-case and capitalised form.
REPLACE_PRONOUNS = {
    variant
    for pronoun in ("he", "she", "they")
    for variant in (pronoun, pronoun.capitalize())
}

class CorefRequest(BaseModel):
    """Request body accepted by the coreference endpoint."""

    # Text in which pronouns referring to main characters should be replaced.
    text: str
    # Comma-separated list of main character names (split on "," by /predict).
    main_characters: str

def extract_core_name(mention_text, main_characters):
    """Return the main character named in a mention, or a fallback word.

    Each entry of ``main_characters`` is compared case-insensitively, ignoring
    surrounding whitespace, as a substring of ``mention_text``. The first
    matching entry is returned exactly as it appears in ``main_characters``,
    so callers can test the result with ``in main_characters``.

    NOTE(review): matching is by substring, so a short name such as "Al" also
    matches "Alice"; consider word-boundary matching if that matters.

    Args:
        mention_text: Text of a coreference mention (e.g. "the brave Alice").
        main_characters: Iterable of character names.

    Returns:
        The matching character entry; otherwise the last word of the mention,
        or "" when the mention contains no words.
    """
    mention_lower = mention_text.lower()
    for character in main_characters:
        needle = character.strip().lower()
        # A blank name is a substring of every mention; skip it so a stray
        # comma in the request cannot map every cluster to an empty name.
        if needle and needle in mention_lower:
            return character
    words = mention_text.split()
    return words[-1] if words else ""

def calculate_pronoun_density(text):
    """Measure how many replaceable pronouns appear per PERSON entity.

    Args:
        text: Raw input text, parsed with the ``nlp`` pipeline.

    Returns:
        A ``(density, person_count)`` tuple: the number of tokens tagged
        ``PRON`` whose text is in ``REPLACE_PRONOUNS``, divided by the number
        of ``PERSON`` entities (a divisor of 1 is used when there are none),
        and that ``PERSON`` entity count.
    """
    parsed = nlp(text)
    replaceable = [
        tok for tok in parsed
        if tok.pos_ == "PRON" and tok.text in REPLACE_PRONOUNS
    ]
    person_count = len([ent for ent in parsed.ents if ent.label_ == "PERSON"])
    # Avoid division by zero when no people are detected.
    divisor = person_count if person_count > 0 else 1
    return len(replaceable) / divisor, person_count

def _match_pronoun_case(name, pronoun):
    """Return *name*, with its first letter upper-cased if *pronoun* is title-case.

    A title-case pronoun ("He") typically starts a sentence, so the name that
    replaces it should too. Otherwise the name is kept exactly as supplied;
    proper names are never lower-cased.
    """
    if pronoun.istitle():
        return name[:1].upper() + name[1:]
    return name


def _collect_pronoun_replacements(doc, main_characters):
    """Map token index -> replacement name for pronouns in main-character clusters.

    A cluster is used when the core name of its first mention (as computed by
    ``extract_core_name``) is one of ``main_characters``. Every token of that
    cluster's mentions whose text is in ``REPLACE_PRONOUNS`` is mapped to the
    name, cased to match the pronoun.
    """
    replacements = {}
    for key, cluster in doc.spans.items():
        # Only span groups whose key starts with "coref_clusters" hold
        # coreference clusters (presumably "coref_clusters_<n>"; confirm
        # against the model). Skip empty groups rather than index into them.
        if not key.startswith("coref_clusters") or len(cluster) == 0:
            continue
        core_name = extract_core_name(cluster[0].text, main_characters)
        if core_name not in main_characters:
            continue
        for mention in cluster:
            for token in mention:
                if token.text in REPLACE_PRONOUNS:
                    replacements[token.i] = _match_pronoun_case(core_name, token.text)
    return replacements


def _sentence_start_indices(doc, text):
    """Return the set of indices of tokens in *doc* that begin a sentence."""
    if doc.has_annotation("SENT_START"):
        return {token.i for token in doc if token.is_sent_start}
    # The coreference pipeline may not set sentence boundaries, in which case
    # is_sent_start is None for every non-initial token and the whole text
    # would be one "sentence". Segment with the general pipeline instead and
    # align the two tokenisations by character offset.
    start_chars = {sent.start_char for sent in nlp(text).sents}
    return {token.i for token in doc if token.idx in start_chars}


def resolve_coreferences_across_text(text, main_characters):
    """Replace pronouns that refer to a main character with that character's name.

    Pronouns (``REPLACE_PRONOUNS``) belonging to a coreference cluster whose
    first mention names one of ``main_characters`` are replaced by that name,
    unless the name already occurs earlier in the same sentence; in that case
    the pronoun is kept, to avoid repeating the name within a sentence. The
    original inter-token whitespace is preserved, and the result is passed
    through ``remove_consecutive_duplicate_phrases``.

    Args:
        text: Input text.
        main_characters: Sequence of character names.

    Returns:
        The rewritten text.
    """
    doc = nlp_coref(text)
    replacements = _collect_pronoun_replacements(doc, main_characters)
    sentence_starts = _sentence_start_indices(doc, text)

    pieces = []
    # Lower-cased token texts and inserted names seen so far in this sentence.
    seen_in_sentence = set()
    for token in doc:
        if token.i in sentence_starts:
            seen_in_sentence.clear()
        name = replacements.get(token.i)
        if name is not None and name.lower() not in seen_in_sentence:
            # Keep the pronoun's trailing whitespace so punctuation spacing
            # in the output matches the input.
            pieces.append(name + token.whitespace_)
            seen_in_sentence.add(name.lower())
        else:
            pieces.append(token.text_with_ws)
            seen_in_sentence.add(token.text.lower())
    return remove_consecutive_duplicate_phrases("".join(pieces))

def remove_consecutive_duplicate_phrases(text):
    """Collapse immediately repeated word sequences in *text*.

    The text is split on whitespace. For every start position ``i`` and every
    phrase ``words[i:j]``, copies of that phrase that directly follow it are
    deleted. For example, "Alice Alice went" becomes "Alice went" and
    "the cat the cat sat" becomes "the cat sat". Legitimate repetitions such
    as "very very" are collapsed too. The remaining words are re-joined with
    single spaces, so newlines and runs of spaces are not preserved.

    Args:
        text: Text to clean.

    Returns:
        The de-duplicated text ("" for empty or whitespace-only input).
    """
    words = text.split()
    i = 0
    while i < len(words) - 1:
        j = i + 1
        # Once the words after j are fewer than the phrase words[i:j], no copy
        # can follow, and that stays true as j grows, so stop scanning here
        # instead of comparing slices that can never be equal.
        while j < len(words) and len(words) - j >= j - i:
            size = j - i
            # Cheap first-word check before the O(size) slice comparison.
            if words[j] == words[i] and words[i:j] == words[j:j + size]:
                del words[j:j + size]
            else:
                j += 1
        i += 1
    return " ".join(words)

def process_text(text, main_characters):
    """Resolve main-character pronouns in *text* when it contains any.

    Coreference resolution is only run when ``calculate_pronoun_density``
    reports a positive density, i.e. at least one replaceable pronoun.

    Args:
        text: Input text.
        main_characters: Sequence of character names.

    Returns:
        The rewritten text, or *text* unchanged when there is nothing to replace.
    """
    pronoun_density, _ = calculate_pronoun_density(text)
    if pronoun_density > 0:
        return resolve_coreferences_across_text(text, main_characters)
    return text

@app.post("/predict")
async def predict(coref_request: CorefRequest):
    """Replace pronouns referring to the requested main characters.

    ``main_characters`` is split on commas. Each name is stripped of
    surrounding whitespace and blank entries are dropped, so inputs like
    "Alice, Bob" or "Alice,,Bob," behave like "Alice,Bob".

    Returns:
        ``{"resolved_text": ...}`` with the rewritten text.

    Raises:
        HTTPException: 400 when the resulting text is empty.
    """
    main_characters = [
        name.strip()
        for name in coref_request.main_characters.split(",")
        if name.strip()
    ]
    # NOTE(review): spaCy inference runs synchronously inside this async
    # handler, so it blocks the event loop for the duration of the request.
    resolved_text = process_text(coref_request.text, main_characters)
    if resolved_text:
        return {"resolved_text": resolved_text}
    raise HTTPException(status_code=400, detail="Coreference resolution failed")

if __name__ == "__main__":
    # Serve the app directly; the PORT environment variable overrides 7860.
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)