Benjamin Bossan commited on
Commit
273184e
1 Parent(s): d8d447c

When loading cards, recognize image sections

Browse files
Files changed (1) hide show
  1. utils.py +17 -6
utils.py CHANGED
@@ -9,7 +9,12 @@ from dataclasses import dataclass
9
  from pathlib import Path
10
 
11
  from skops import card
12
- from skops.card._model_card import Section
 
 
 
 
 
13
 
14
 
15
  def get_rendered_model_card(model_card: card.Card, hf_path: str) -> str:
@@ -42,10 +47,7 @@ def process_card_for_rendering(rendered: str) -> tuple[str, str]:
42
  def markdown_images(markdown):
43
  # example image markdown:
44
  # ![Test image](images/test.png "Alternate text")
45
- images = re.findall(
46
- r'(!\[(?P<image_title>[^\]]+)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))',
47
- markdown,
48
- )
49
  return images
50
 
51
  def img_to_bytes(img_path):
@@ -106,11 +108,20 @@ def iterate_key_section_content(
106
  continue
107
 
108
  return_key = key if not parent_keys else "/".join(parent_keys + [key])
 
 
109
  is_fig = getattr(val, "is_fig", False)
 
 
 
 
 
 
 
110
  yield SectionInfo(
111
  return_key=return_key,
112
  title=title,
113
- content=val.content,
114
  is_fig=is_fig,
115
  level=level,
116
  )
 
9
  from pathlib import Path
10
 
11
  from skops import card
12
+ from skops.card._model_card import PlotSection, Section
13
+
14
+
15
+ PAT_MD_IMG = re.compile(
16
+ r'(!\[(?P<image_title>[^\]]+)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))'
17
+ )
18
 
19
 
20
  def get_rendered_model_card(model_card: card.Card, hf_path: str) -> str:
 
47
  def markdown_images(markdown):
48
  # example image markdown:
49
  # ![Test image](images/test.png "Alternate text")
50
+ images = PAT_MD_IMG.findall(markdown)
 
 
 
51
  return images
52
 
53
  def img_to_bytes(img_path):
 
108
  continue
109
 
110
  return_key = key if not parent_keys else "/".join(parent_keys + [key])
111
+ content = val.content
112
+
113
  is_fig = getattr(val, "is_fig", False)
114
+ img_match = PAT_MD_IMG.match(val.content)
115
+ if img_match: # image section found in parsed model card
116
+ is_fig = True
117
+ img_title = img_match.groupdict()["image_title"]
118
+ img_path = img_match.groupdict()["image_path"]
119
+ content = PlotSection(alt_text=img_title, path=img_path)
120
+
121
  yield SectionInfo(
122
  return_key=return_key,
123
  title=title,
124
+ content=content,
125
  is_fig=is_fig,
126
  level=level,
127
  )