Benjamin Bossan
Fix bug with parsing img when val is no string
e83e18b
raw
history blame
4.26 kB
"""Utility functions for the app"""
from __future__ import annotations
import base64
import os
import re
from dataclasses import dataclass
from pathlib import Path
from skops import card
from skops.card._model_card import PlotSection, Section
PAT_MD_IMG = re.compile(
r'(!\[(?P<image_title>[^\]]+)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))'
)
def get_rendered_model_card(model_card: card.Card, hf_path: str) -> str:
# This is a bit hacky:
# As a space, the model card is created in a temporary hf_path directory,
# which is where all the files are put. So e.g. if a figure is added, it is
# found at /tmp/skops-jtyqdgk3/fig.png. However, when the model card is is
# actually used, we don't want that, since there, the files will be in the
# cwd. Therefore, we remove the tmp directory everywhere we find it in the
# file.
if not hf_path.endswith(os.path.sep):
hf_path += os.path.sep
rendered = model_card.render()
rendered = rendered.replace(hf_path, "")
return rendered
def process_card_for_rendering(rendered: str) -> tuple[str, str]:
idx = rendered[1:].index("\n---") + 1
metadata = rendered[3:idx]
rendered = rendered[idx + 4 :] # noqa: E203
# below is a hack to display the images in streamlit
# https://discuss.streamlit.io/t/image-in-markdown/13274/10 The problem is
# that streamlit does not display images in markdown, so we need to replace
# them with html. However, we only want that in the rendered markdown, not
# in the card that is produced for the hub
def markdown_images(markdown):
# example image markdown:
# ![Test image](images/test.png "Alternate text")
images = PAT_MD_IMG.findall(markdown)
return images
def img_to_bytes(img_path):
img_bytes = Path(img_path).read_bytes()
encoded = base64.b64encode(img_bytes).decode()
return encoded
def img_to_html(img_path, img_alt):
img_format = img_path.split(".")[-1]
img_html = (
f'<img src="data:image/{img_format.lower()};'
f'base64,{img_to_bytes(img_path)}" '
f'alt="{img_alt}" '
'style="max-width: 100%;">'
)
return img_html
def markdown_insert_images(markdown):
images = markdown_images(markdown)
for image in images:
image_markdown = image[0]
image_alt = image[1]
image_path = image[2]
markdown = markdown.replace(
image_markdown, img_to_html(image_path, image_alt)
)
return markdown
rendered_with_img = markdown_insert_images(rendered)
return metadata, rendered_with_img
@dataclass(frozen=True)
class SectionInfo:
return_key: str
title: str
content: str
is_fig: bool
level: int
def iterate_key_section_content(
data: dict[str, Section],
parent_section: str = "",
parent_keys: list[str] | None = None,
level: int = 0,
) -> SectionInfo:
parent_keys = parent_keys or []
for key, val in data.items():
if parent_section:
title = "/".join((parent_section, val.title))
else:
title = val.title
if not getattr(val, "visible", True):
continue
return_key = key if not parent_keys else "/".join(parent_keys + [key])
content = val.content
is_fig = getattr(val, "is_fig", False)
if isinstance(val.content, str):
img_match = PAT_MD_IMG.match(val.content)
if img_match: # image section found in parsed model card
is_fig = True
img_title = img_match.groupdict()["image_title"]
img_path = img_match.groupdict()["image_path"]
content = PlotSection(alt_text=img_title, path=img_path)
yield SectionInfo(
return_key=return_key,
title=title,
content=content,
is_fig=is_fig,
level=level,
)
if val.subsections:
yield from iterate_key_section_content(
val.subsections,
parent_section=title,
parent_keys=parent_keys + [key],
level=level + 1,
)