import functools
from typing import Dict

from PIL import ImageFont

# NOTE(review): these templates only reference {text}/{tag}, {content} and
# {label}. Every other keyword passed to .format() below (x, y, id, i, arc,
# head, width, height, color, ...) is silently ignored by str.format, so the
# rendered output contains no SVG markup or geometry. Confirm whether the
# markup is intentionally absent.
TPL_DEP_WORDS = """ {text} {tag} """
TPL_DEP_SVG = """ {content} """
TPL_DEP_ARCS = """ {label} """


@functools.lru_cache(maxsize=32)
def _load_font(font_name, font_size):
    """Load a TrueType font once per (name, size) and reuse it."""
    return ImageFont.truetype(font_name, font_size)


def get_pil_text_size(text, font_size, font_name):
    """Measure rendered text with Pillow.

    text (str): Text to measure.
    font_size (int): Font size passed to ImageFont.truetype.
    font_name (str): TrueType font file name or path (e.g. 'arial.ttf').
    RETURNS (tuple): (width, height) in pixels. The width may be a float on
        Pillow >= 8.0.
    """
    # Cached so that measuring every word of a sentence does not re-open
    # and re-parse the font file each time.
    font = _load_font(font_name, font_size)
    if not hasattr(font, "getlength"):
        # Pillow < 8.0 only offers getsize (which Pillow 10 removed).
        return font.getsize(text)
    # getlength is the horizontal advance, i.e. how far following text should
    # be offset, which is what callers use to space words. The height comes
    # from the text's bounding box.
    _, top, _, bottom = font.getbbox(text)
    return font.getlength(text), bottom - top


def render_arrow(
    label: str, start: int, end: int, direction: str, i: int
) -> str:
    """Render individual arrow.

    label (str): Dependency label.
    start (int): X-coordinate of the leftmost word spanned by the arrow.
    end (int): X-coordinate of the rightmost word spanned by the arrow.
    direction (str): 'left' puts the arrowhead at `start`; any other value
        puts it at `end`. The value 'rtl' additionally sets label_side to
        'right' (otherwise 'left').
    i (int): Unique ID, typically arrow index.
    RETURNS (str): Rendered SVG markup.
    """
    # Both ends are shifted 10px right of the given word x positions.
    x_start, x_end = start + 10, end + 10
    y, y_curve = 50, 5
    arc = get_arc(x_start, y, y_curve, x_end)
    arrowhead = get_arrowhead(direction, x_start, y, x_end)
    label_side = "right" if direction == "rtl" else "left"
    return TPL_DEP_ARCS.format(
        id=0,
        i=i,
        stroke=2,
        head=arrowhead,
        label=label,
        label_side=label_side,
        arc=arc,
    )


def get_arc(x_start: int, y: int, y_curve: int, x_end: int) -> str:
    """Render individual arc.

    x_start (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    y_curve (int): Y-coordinate of the cubic Bézier control points.
    x_end (int): X-coordinate of arrow end point.
    RETURNS (str): Definition of the arc path ('d' attribute).
    """
    template = "M{x},{y} C{x},{c} {e},{c} {e},{y}"
    return template.format(x=x_start, y=y, c=y_curve, e=x_end)


def get_arrowhead(direction: str, x: int, y: int, end: int) -> str:
    """Render individual arrow head.

    direction (str): 'left' places the head at `x`; any other value places
        it at `end`.
    x (int): X-coordinate of arrow start point.
    y (int): Y-coordinate of arrow start and end point.
    end (int): X-coordinate of arrow end point.
    RETURNS (str): Definition of the arrow head path ('d' attribute).
    """
    arrow_width = 6
    if direction == "left":
        p1, p2, p3 = (x, x - arrow_width + 2, x + arrow_width - 2)
    else:
        p1, p2, p3 = (end, end + arrow_width - 2, end - arrow_width + 2)
    return f"M{p1},{y + 2} L{p2},{y - arrow_width} {p3},{y - arrow_width}"


def render_sentence_custom(unmatched_list: Dict, nlp):
    """Render a sentence with one labelled arc between two of its words.

    unmatched_list (dict): Must provide 'sentence' (text passed to `nlp`),
        'cur_word_index' and 'target_word_index' (token indices the arc
        connects) and 'dep' (the arc label).
    nlp: Callable returning an iterable of tokens for the sentence; only
        iteration and str(token) are used (presumably a spaCy pipeline).
    RETURNS (str): The words and the arc rendered through the templates.
    RAISES IndexError: If no token index falls between the two word indices
        (e.g. an empty sentence or out-of-range indices).
    """
    cur_index = unmatched_list["cur_word_index"]
    target_index = unmatched_list["target_word_index"]
    min_index, max_index = sorted((cur_index, target_index))
    # The arrowhead always lands on the current word: 'left' draws it at the
    # arc's start (min_index), any other value at its end (max_index).
    direction = "left" if cur_index < target_index else "rtl"

    words_svg = []
    arc_anchors = []  # x positions of the words in [min_index, max_index]
    x = 10
    for index, token in enumerate(nlp(unmatched_list["sentence"])):
        word = str(token) + " "
        word_width = get_pil_text_size(word, 16, "arial.ttf")[0]
        words_svg.append(TPL_DEP_WORDS.format(text=word, tag="", x=x, y=70))
        if min_index <= index <= max_index:
            arc_anchors.append(x)
            # Widen the gaps between words spanned by the arc.
            # NOTE(review): because of the `- 1`, the gap between the last
            # two spanned words is not widened; confirm this is intended.
            if index < max_index - 1:
                x += 50
        x += word_width + 4

    if not arc_anchors:
        raise IndexError(
            f"no token between indices {min_index} and {max_index} "
            "to attach the arc to"
        )
    arcs_svg = [
        render_arrow(
            unmatched_list["dep"], arc_anchors[0], arc_anchors[-1], direction, 0
        )
    ]

    content = "".join(words_svg) + "".join(arcs_svg)
    return TPL_DEP_SVG.format(
        id=0,
        width=1200,
        height=75,
        color="#000000",
        bg="#ffffff",
        font="Arial",
        content=content,
        dir="ltr",
        lang="en",
    )