Spaces:
Sleeping
Sleeping
import os.path as osp | |
from mmengine.fileio import load | |
from tabulate import tabulate | |
class BaseWeightList:
    """Base class that renders one task's model list as a markdown table.

    Subclasses configure the table purely through class attributes; the
    constructor takes no arguments.

    Attributes:
        dataset_list (list[str]): Dataset names, one table column each.
        table_header (list[str]): Column headers, built in ``__init__`` from
            ``dataset_list`` and ``metric_name``.
        msg (str): Markdown emitted above the table (e.g. a section title).
        task_abbr (str): Task abbreviation; a model is listed only when this
            string occurs in its ``Config`` path.
        metric_name (str): Key looked up in each result's ``Metrics`` dict.
        model_index_root (str): Directory holding ``model-index.yml``; the
            metafile paths it imports are resolved against it as well.
    """

    base_url: str = 'https://github.com/open-mmlab/mmocr/blob/1.x/'
    table_cfg: dict = dict(
        tablefmt='pipe', floatfmt='.2f', numalign='right', stralign='center')
    # Resolved against the current working directory (the docs build dir).
    model_index_root: str = '../../'
    dataset_list: list
    table_header: list
    msg: str
    task_abbr: str
    metric_name: str

    def __init__(self):
        # Header cells look like "<dataset> (<metric>)".
        data = (d + f' ({self.metric_name})' for d in self.dataset_list)
        self.table_header = ['模型', 'README', *data]

    @staticmethod
    def _format_name(item: dict) -> str:
        """Return the markdown name cell for one metafile model entry.

        Aliases (a single string or a list of strings) are listed before
        the canonical ``Name``, joined by ``' / '``. ``item`` is not
        modified.
        """
        name = f'`{item["Name"]}`'
        aliases = item.get('Alias', None)
        if not aliases:
            return name
        if isinstance(aliases, str):
            aliases = [aliases]
        return ' / '.join([*(f'`{alias}`' for alias in aliases), name])

    def _get_model_info(self, task_name: str):
        """Yield one table row per model whose ``Config`` mentions ``task_name``.

        Each row is ``(name, readme_link, *metric_values)``.
        """
        meta_indexes = load(osp.join(self.model_index_root, 'model-index.yml'))
        for meta_path in meta_indexes['Import']:
            metainfo = load(osp.join(self.model_index_root, meta_path))
            # Map collection name -> markdown link to its README.
            collection2md = {}
            for item in metainfo['Collections']:
                url = self.base_url + item['README']
                collection2md[item['Name']] = f'[链接]({url})'
            for item in metainfo['Models']:
                if task_name not in item['Config']:
                    continue
                name = self._format_name(item)
                readme = collection2md[item['In Collection']]
                eval_res = self._get_eval_res(item)
                yield (name, readme, *eval_res)

    def _get_eval_res(self, item):
        """Return the metric values for one model, ordered by ``dataset_list``.

        Datasets with no reported result are shown as ``'-'``.
        """
        eval_res = {k: '-' for k in self.dataset_list}
        for res in item['Results']:
            if res['Dataset'] in self.dataset_list:
                eval_res[res['Dataset']] = res['Metrics'][self.metric_name]
        return (eval_res[k] for k in self.dataset_list)

    def gen_model_list(self):
        """Return ``msg`` followed by the model table in a MyST table block."""
        content = f'\n{self.msg}\n'
        content += '```{table}\n:class: model-summary nowrap field-list '
        content += 'table table-hover\n'
        content += tabulate(
            self._get_model_info(self.task_abbr), self.table_header,
            **self.table_cfg)
        content += '\n```\n'
        return content
class TextDetWeightList(BaseWeightList):
    """Model list for text detection (``textdet``) models.

    Every dataset column reports the model's ``hmean-iou`` score.
    """

    task_abbr = 'textdet'
    metric_name = 'hmean-iou'
    msg = '### 文字检测'
    dataset_list = ['ICDAR2015', 'CTW1500', 'Totaltext']
class TextRecWeightList(BaseWeightList):
    """Model list for text recognition (``textrecog``) models.

    Besides the per-dataset ``word_acc`` columns, the leading ``Avg``
    column holds the mean ``word_acc`` over the benchmark datasets for
    which the model reports a result.
    """

    dataset_list = [
        'Avg', 'IIIT5K', 'SVT', 'ICDAR2013', 'ICDAR2015', 'SVTP', 'CT80'
    ]
    msg = ('### 文字识别\n'
           '```{note}\n'
           'Avg 指该模型在 IIIT5K、SVT、ICDAR2013、ICDAR2015、SVTP、'
           'CT80 上的平均结果。\n```\n')
    task_abbr = 'textrecog'
    metric_name = 'word_acc'

    def _get_eval_res(self, item):
        """Return the metric values for one model, ``Avg`` column first.

        Args:
            item (dict): One entry of a metafile's ``Models`` list; its
                ``Results`` list is read.

        Returns:
            generator: One value per entry of ``dataset_list``, in order.
            ``'-'`` marks a dataset without a reported result, and ``Avg``
            when no benchmark result is reported at all.
        """
        # dataset -> score; a later duplicate entry overrides an earlier one,
        # so the mean matches exactly the values shown in the table.
        scores = {}
        for res in item['Results']:
            dataset = res['Dataset']
            # 'Avg' is derived below and must not feed into its own mean.
            if dataset != 'Avg' and dataset in self.dataset_list:
                scores[dataset] = res['Metrics'][self.metric_name]
        eval_res = {k: '-' for k in self.dataset_list}
        eval_res.update(scores)
        # Without this guard a model with no benchmark result raised
        # ZeroDivisionError and aborted the whole docs build.
        if scores:
            eval_res['Avg'] = sum(scores.values()) / len(scores)
        return (eval_res[k] for k in self.dataset_list)
class KIEWeightList(BaseWeightList):
    """Model list for key information extraction (``kie``) models.

    The single dataset column reports the model's ``macro_f1`` score.
    """

    msg = '### 关键信息提取'
    dataset_list = ['wildreceipt']
    metric_name = 'macro_f1'
    task_abbr = 'kie'
def gen_weight_list():
    """Return the markdown model lists of all tasks, concatenated.

    Sections appear in the order: text detection, text recognition, key
    information extraction.
    """
    sections = (TextDetWeightList, TextRecWeightList, KIEWeightList)
    return ''.join(cls().gen_model_list() for cls in sections)