File size: 4,261 Bytes
f4f6aba
 
31a1df6
 
 
dd68837
31a1df6
2481b28
31a1df6
 
dd68837
273184e
 
 
 
 
 
31a1df6
 
dd68837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a1df6
 
 
 
 
 
 
 
 
 
 
 
 
 
273184e
31a1df6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2481b28
 
 
 
 
 
 
 
 
31a1df6
 
 
 
e83e18b
2481b28
31a1df6
 
 
 
 
 
 
 
825249c
31a1df6
 
 
273184e
 
31a1df6
e83e18b
 
 
 
 
 
 
273184e
2481b28
 
 
273184e
2481b28
 
 
31a1df6
 
 
 
 
 
2481b28
31a1df6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Utility functions for the app"""

from __future__ import annotations

import base64
import html
import os
import re
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path

from skops import card
from skops.card._model_card import PlotSection, Section


# Matches one markdown image, e.g. ![Test image](images/test.png "Alt text").
# Capture groups, in order (this is the tuple layout ``findall`` returns):
#   1. the whole image markdown;
#   2. ``image_title``: the bracketed text (must be non-empty);
#   3. ``image_path``: the link target, up to the first whitespace, ``"``
#      or ``)``;
#   4. whatever follows the path (after optional whitespace) before the
#      closing ``)``, e.g. a quoted title.
# NOTE: a ``]`` inside the title or a ``)`` inside the path is not supported.
PAT_MD_IMG = re.compile(
    r'(!\[(?P<image_title>[^\]]+)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))'
)


def get_rendered_model_card(model_card: card.Card, hf_path: str) -> str:
    """Render ``model_card`` and strip the temporary ``hf_path`` prefix.

    This is a bit hacky: when running as a space, the model card is created
    in a temporary ``hf_path`` directory, which is where all the files are
    put. So e.g. if a figure is added, it is found at
    ``/tmp/skops-jtyqdgk3/fig.png``. However, when the model card is actually
    used, we don't want that, since there the files will be in the cwd.
    Therefore, the tmp directory is removed everywhere it occurs in the
    rendered text.

    Parameters
    ----------
    model_card : skops.card.Card
        The card to render; only its ``render()`` method is used.
    hf_path : str or os.PathLike
        The temporary directory holding the card's files. A trailing path
        separator is optional. If empty, nothing is stripped.

    Returns
    -------
    str
        The rendered card with every occurrence of ``hf_path`` followed by a
        path separator removed.
    """
    # Accept pathlib.Path and other path-like objects as well as str.
    hf_path = os.fspath(hf_path)
    rendered = model_card.render()
    if not hf_path:
        # An empty prefix would otherwise turn into a bare separator and
        # strip every separator out of the card (e.g. from URLs).
        return rendered

    if not hf_path.endswith(os.path.sep):
        hf_path += os.path.sep
    return rendered.replace(hf_path, "")


def process_card_for_rendering(rendered: str) -> tuple[str, str]:
    """Split a rendered card into its metadata and a Streamlit-ready body.

    Parameters
    ----------
    rendered : str
        The full rendered model card. It is assumed to start with the
        opening ``---`` line of the metadata header.

    Returns
    -------
    metadata : str
        The text between the opening ``---`` and the next line starting
        with ``---``.
    rendered_with_img : str
        The text after that closing ``---``, with every markdown image
        replaced by an inline, base64-encoded HTML ``<img>`` tag.

    Raises
    ------
    ValueError
        If no closing ``---`` line is found.
    OSError
        If an image referenced in the body cannot be read (relative paths
        are resolved against the current working directory).
    """
    idx = rendered[1:].index("\n---") + 1
    metadata = rendered[3:idx]
    rendered = rendered[idx + 4 :]  # noqa: E203

    # Below is a hack to display the images in streamlit, see
    # https://discuss.streamlit.io/t/image-in-markdown/13274/10. The problem
    # is that streamlit does not display images in markdown, so we need to
    # replace them with html. However, we only want that in the rendered
    # markdown, not in the card that is produced for the hub.

    # File extensions whose MIME subtype differs from the extension itself;
    # e.g. "image/svg" is not a valid type, so SVGs would not display.
    mime_subtypes = {"jpg": "jpeg", "svg": "svg+xml"}

    def img_to_bytes(img_path):
        img_bytes = Path(img_path).read_bytes()
        return base64.b64encode(img_bytes).decode()

    def img_to_html(img_path, img_alt):
        img_format = img_path.split(".")[-1].lower()
        subtype = mime_subtypes.get(img_format, img_format)
        # Escape the alt text so quotes or "<" in the image title cannot
        # break out of the attribute / tag.
        alt = html.escape(img_alt, quote=True)
        return (
            f'<img src="data:image/{subtype};'
            f'base64,{img_to_bytes(img_path)}" '
            f'alt="{alt}" '
            'style="max-width: 100%;">'
        )

    def replace_image(match):
        # example image markdown:
        # ![Test image](images/test.png "Alternate text")
        return img_to_html(
            match.group("image_path"), match.group("image_title")
        )

    # A single substitution pass; each image occurrence is read exactly once.
    rendered_with_img = PAT_MD_IMG.sub(replace_image, rendered)
    return metadata, rendered_with_img


@dataclass(frozen=True)
class SectionInfo:
    """One flattened section entry, as yielded by
    ``iterate_key_section_content``.

    Instances are immutable (``frozen=True``).
    """

    # "/"-joined path of dict keys from the top-level section to this one
    return_key: str
    # "/"-joined path of section titles from the top-level section to this one
    title: str
    # The section's content: usually a str, a PlotSection when the section
    # holds an image, and possibly some other non-str content object.
    content: object
    # whether the section should be treated as a figure
    is_fig: bool
    # nesting depth; 0 for top-level sections
    level: int


def iterate_key_section_content(
    data: dict[str, Section],
    parent_section: str = "",
    parent_keys: list[str] | None = None,
    level: int = 0,
) -> Iterator[SectionInfo]:
    """Walk a (nested) mapping of sections depth-first, yielding each one.

    Parents are yielded before their subsections. Sections whose ``visible``
    attribute is false are skipped together with all of their subsections.
    A section whose string content starts with a markdown image (as matched
    by ``PAT_MD_IMG``) is reported as a figure, with its content converted
    to a ``PlotSection``.

    Parameters
    ----------
    data : dict of str to Section
        The sections at the current level, keyed by section name.
    parent_section : str, default=""
        The "/"-joined title of the enclosing section; empty at top level.
    parent_keys : list of str or None, default=None
        The keys leading to the enclosing section; ``None`` at top level.
    level : int, default=0
        The nesting depth of the sections in ``data``.

    Yields
    ------
    SectionInfo
        One entry per visible section.
    """
    parent_keys = parent_keys or []

    for key, val in data.items():
        if not getattr(val, "visible", True):
            # hidden sections are skipped along with their subsections
            continue

        if parent_section:
            title = "/".join((parent_section, val.title))
        else:
            title = val.title

        return_key = "/".join(parent_keys + [key]) if parent_keys else key
        content = val.content
        is_fig = getattr(val, "is_fig", False)

        if isinstance(content, str):
            img_match = PAT_MD_IMG.match(content)
            if img_match:  # image section found in parsed model card
                is_fig = True
                content = PlotSection(
                    alt_text=img_match.group("image_title"),
                    path=img_match.group("image_path"),
                )

        yield SectionInfo(
            return_key=return_key,
            title=title,
            content=content,
            is_fig=is_fig,
            level=level,
        )

        if val.subsections:
            yield from iterate_key_section_content(
                val.subsections,
                parent_section=title,
                parent_keys=parent_keys + [key],
                level=level + 1,
            )