from typing import Dict, Any

import spacy
from PIL import ImageFont
from spacy.tokens import Doc


def get_pil_text_size(text, font_size, font_name):
    font = ImageFont.truetype(font_name, font_size)
    size = font.getsize(text)
    return size
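
# NOTE: ImageFont.truetype() assumes the font file ("arial.ttf" in the calls below)
# can be located by Pillow on this machine. font.getsize() was removed in Pillow 10.0;
# on newer versions the text width could be derived from font.getbbox(text)
# (right minus left) instead.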


def render_arrow(
    label: str, start: int, end: int, direction: str, i: int
) -> str:
    """Render individual arrow.

    label (str): Dependency label.
    start (int): Index of start word.
    end (int): Index of end word.
    direction (str): Arrow direction, 'left' or 'right'.
    i (int): Unique ID, typically arrow index.
    RETURNS (str): Rendered SVG markup.
    """
    TPL_DEP_ARCS = """
    <g class="displacy-arrow">
        <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="red"/>
        <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
            <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="red" text-anchor="middle">{label}</textPath>
        </text>
        <path class="displacy-arrowhead" d="{head}" fill="red"/>
    </g>
    """
    arc = get_arc(start + 20, 50, 5, end + 20)
    arrowhead = get_arrowhead(direction, start + 20, 50, end + 20)
    label_side = "right" if direction == "rtl" else "left"
    return TPL_DEP_ARCS.format(
        id=0,
        i=0,
        stroke=2,
        head=arrowhead,
        label=label,
        label_side=label_side,
        arc=arc,
    )
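
# For example (hypothetical coordinates), render_arrow("amod", 10, 160, "right", 0)
# produces a <g class="displacy-arrow"> group whose arc runs from x=30 to x=180 at
# y=50 and whose arrowhead sits at the right-hand (end) side of the arc.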


def get_arc(x_start: int, y: int, y_curve: int, x_end: int) -> str:
    """Render individual arc.

    x_start (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    y_curve (int): Y-coordinate of the cubic Bézier curve's control point.
    x_end (int): X-coordinate of arrow end point.
    RETURNS (str): Definition of the arc path ('d' attribute).
    """
    template = "M{x},{y} C{x},{c} {e},{c} {e},{y}"
    return template.format(x=x_start, y=y, c=y_curve, e=x_end)
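
# For example, get_arc(30, 50, 5, 180) returns "M30,50 C30,5 180,5 180,50":
# a cubic Bézier curve that starts and ends at y=50 and bows up towards y=5.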


def get_arrowhead(direction: str, x: int, y: int, end: int) -> str:
    """Render individual arrow head.

    direction (str): Arrow direction, 'left' or 'right'.
    x (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    end (int): X-coordinate of arrow end point.
    RETURNS (str): Definition of the arrow head path ('d' attribute).
    """
    arrow_width = 6
    if direction == "left":
        p1, p2, p3 = (x, x - arrow_width + 2, x + arrow_width - 2)
    else:
        p1, p2, p3 = (end, end + arrow_width - 2, end - arrow_width + 2)
    return f"M{p1},{y + 2} L{p2},{y - arrow_width} {p3},{y - arrow_width}"


def render_sentence_custom(parsed: str):
    TPL_DEP_WORDS = """
    <text class="displacy-token" fill="currentColor" text-anchor="start" y="{y}">
        <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
        <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
    </text>
    """

    TPL_DEP_SVG = """
    <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
    """
    arcs_svg = []
    couples = []
    nlp = spacy.load('en_core_web_sm')
    doc = nlp(parsed)
    arcs = {}
    words = {}
    parsed = [parse_deps(doc)]
    for i, p in enumerate(parsed):
        arcs = p["arcs"]
        words = p["words"]
        for i, a in enumerate(arcs):
            if a["label"] == "amod":
                couples = (a["start"], a["end"])

    print(couples)
    x_value_counter = 10
    index_counter = 0
    svg_words = []
    coords_test = []
    for i, word in enumerate(words):
        word = word["text"]
        word = word + " "
        pixel_x_length = get_pil_text_size(word, 16, 'arial.ttf')[0]
        svg_words.append(TPL_DEP_WORDS.format(text=word, tag="", x=x_value_counter, y=70))
        print(index_counter)
        if index_counter >= couples[0] and index_counter <= couples[1]:
            coords_test.append(x_value_counter)
        x_value_counter += 50
        index_counter += 1
        x_value_counter += pixel_x_length + 4
    print(coords_test)
    for i, a in enumerate(arcs):
        if a["label"] == "amod":
            arcs_svg.append(render_arrow(a["label"], coords_test[0], coords_test[-1], a["dir"], i))

    content = "".join(svg_words) + "".join(arcs_svg)

    full_svg = TPL_DEP_SVG.format(
        id=0,
        width=1975,
        height=574.5,
        color="#000000",
        bg="#ffffff",
        font="Arial",
        content=content,
        dir="ltr",
        lang="en",
    )

    return full_svg
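
# NOTE: render_sentence_custom() assumes the parsed sentence contains at least one
# "amod" (adjectival modifier) arc; if it does not, `couples` stays an empty list and
# `couples[0]` (and later `coords_test[0]`) will raise an IndexError.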


def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
    """Generate dependency parse in {'words': [], 'arcs': []} format.

    orig_doc (Doc): Document to parse.
    RETURNS (dict): Generated dependency parse keyed by words and arcs.
    """
    doc = Doc(orig_doc.vocab).from_bytes(orig_doc.to_bytes(exclude=["user_data"]))
    if not doc.has_annotation("DEP"):
        print("WARNING")
    if options.get("collapse_phrases", False):
        with doc.retokenize() as retokenizer:
            for np in list(doc.noun_chunks):
                attrs = {
                    "tag": np.root.tag_,
                    "lemma": np.root.lemma_,
                    "ent_type": np.root.ent_type_,
                }
                retokenizer.merge(np, attrs=attrs)
    if options.get("collapse_punct", True):
        spans = []
        for word in doc[:-1]:
            if word.is_punct or not word.nbor(1).is_punct:
                continue
            start = word.i
            end = word.i + 1
            while end < len(doc) and doc[end].is_punct:
                end += 1
            span = doc[start:end]
            spans.append((span, word.tag_, word.lemma_, word.ent_type_))
        with doc.retokenize() as retokenizer:
            for span, tag, lemma, ent_type in spans:
                attrs = {"tag": tag, "lemma": lemma, "ent_type": ent_type}
                retokenizer.merge(span, attrs=attrs)
    fine_grained = options.get("fine_grained")
    add_lemma = options.get("add_lemma")
    words = [
        {
            "text": w.text,
            "tag": w.tag_ if fine_grained else w.pos_,
            "lemma": w.lemma_ if add_lemma else None,
        }
        for w in doc
    ]
    arcs = []
    for word in doc:
        if word.i < word.head.i:
            arcs.append(
                {"start": word.i, "end": word.head.i, "label": word.dep_, "dir": "left"}
            )
        elif word.i > word.head.i:
            arcs.append(
                {
                    "start": word.head.i,
                    "end": word.i,
                    "label": word.dep_,
                    "dir": "right",
                }
            )
    return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}


def get_doc_settings(doc: Doc) -> Dict[str, Any]:
    return {
        "lang": doc.lang_,
        "direction": doc.vocab.writing_system.get("direction", "ltr"),
    }
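

# Minimal usage sketch (assumes en_core_web_sm is installed, "arial.ttf" is available
# to Pillow, and the sentence contains at least one "amod" arc):
if __name__ == "__main__":
    svg = render_sentence_custom("The quick brown fox jumps over the lazy dog")
    with open("custom_parse.svg", "w", encoding="utf-8") as f:
        f.write(svg)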