import os

# Point the ML libraries at writable cache directories before they are
# imported, so the paths are picked up at import time.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'

import re

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import spacy

# Initialize FastAPI app
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust the origins as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the spaCy models once at startup: en_core_web_sm for tagging/NER and
# the experimental en_coreference_web_trf pipeline for coreference clusters.
nlp = spacy.load("en_core_web_sm")
nlp_coref = spacy.load("en_coreference_web_trf")

REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}

class CorefRequest(BaseModel):
    text: str
    main_characters: str  # comma-separated character names, e.g. "Alice,Bob"

def extract_core_name(mention_text, main_characters):
    # Map a cluster mention to one of the main characters if that name occurs
    # in the mention; otherwise fall back to the mention's last word.
    words = mention_text.split()
    for character in main_characters:
        if character.lower() in mention_text.lower():
            return character
    return words[-1]
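# Example (hypothetical values): extract_core_name("the young Harry", ["Harry", "Ron"])
# returns "Harry", since the character name occurs inside the mention text.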

def calculate_pronoun_density(text):
    # Ratio of replaceable third-person pronouns to PERSON entities; used as a
    # cheap signal for whether coreference resolution is worth running.
    doc = nlp(text)
    pronoun_count = sum(1 for token in doc if token.pos_ == "PRON" and token.text in REPLACE_PRONOUNS)
    named_entity_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
    return pronoun_count / max(named_entity_count, 1), named_entity_count
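# Example of the heuristic above (hypothetical input): for "She waved. He laughed."
# the small model tags "She" and "He" as PRON and finds no PERSON entities,
# giving a density of 2 / max(0, 1) = 2.0.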

def resolve_coreferences_across_text(text, main_characters):
    doc = nlp_coref(text)
    # Build a token-index -> replacement-name mapping from the coreference
    # clusters the model exposes in doc.spans (keys like "coref_clusters_1").
    coref_mapping = {}
    for key, cluster in doc.spans.items():
        if re.match(r"coref_clusters_*", key):
            main_mention = cluster[0]
            core_name = extract_core_name(main_mention.text, main_characters)
            if core_name in main_characters:
                for mention in cluster:
                    for token in mention:
                        if token.text in REPLACE_PRONOUNS:
                            core_name_final = core_name if token.text.istitle() else core_name.lower()
                            coref_mapping[token.i] = core_name_final
    # Rebuild the text token by token, inserting each character's name at most
    # once per sentence so the replacements do not pile up.
    resolved_tokens = []
    current_sentence_characters = set()
    current_sentence = []
    for i, token in enumerate(doc):
        if token.is_sent_start and current_sentence:
            resolved_tokens.extend(current_sentence)
            current_sentence_characters.clear()
            current_sentence = []
        if i in coref_mapping:
            core_name = coref_mapping[i]
            if core_name not in current_sentence_characters and core_name.lower() not in [t.lower() for t in current_sentence]:
                current_sentence.append(core_name)
                current_sentence_characters.add(core_name)
            else:
                current_sentence.append(token.text)
        else:
            current_sentence.append(token.text)
    resolved_tokens.extend(current_sentence)
    resolved_text = " ".join(resolved_tokens)
    return remove_consecutive_duplicate_phrases(resolved_text)

def remove_consecutive_duplicate_phrases(text):
    # Collapse a word sequence that is immediately repeated, which the
    # replacement step can occasionally produce.
    words = text.split()
    i = 0
    while i < len(words) - 1:
        j = i + 1
        while j < len(words):
            if words[i:j] == words[j:j + (j - i)]:
                del words[j:j + (j - i)]
            else:
                j += 1
        i += 1
    return " ".join(words)

def process_text(text, main_characters):
    # Only run the expensive coreference pass when there are pronouns to replace.
    pronoun_density, _ = calculate_pronoun_density(text)
    if pronoun_density > 0:
        return resolve_coreferences_across_text(text, main_characters)
    return text

@app.post("/predict")
async def predict(coref_request: CorefRequest):
    # main_characters arrives as a comma-separated string; strip whitespace so
    # inputs like "Alice, Bob" still match.
    main_characters = [name.strip() for name in coref_request.main_characters.split(",") if name.strip()]
    resolved_text = process_text(coref_request.text, main_characters)
    if resolved_text:
        return {"resolved_text": resolved_text}
    raise HTTPException(status_code=400, detail="Coreference resolution failed")
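# Example request against the endpoint above (hypothetical values):
#   POST /predict
#   {"text": "Alice met Bob. She smiled.", "main_characters": "Alice,Bob"}
# The response has the form {"resolved_text": ...}, with pronouns such as
# "She" replaced by the character name when the coreference model links them.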

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))