File size: 4,686 Bytes
7a3d7a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
from typing import Dict, Any, Optional, List
import re
from abc import ABC, abstractmethod

from huggingface_hub import (ModelCard, comment_discussion,
                             create_discussion, get_discussion_details,
                             get_repo_discussions)
import markdown
from bs4 import BeautifulSoup
from tabulate import tabulate
from difflib import SequenceMatcher

# Value of the ``KEY`` environment variable, or None when it is not set.
KEY = os.getenv("KEY")


def similar(a, b):
    """Return the similarity ratio of sequences *a* and *b*.

    The ratio comes from ``difflib.SequenceMatcher`` (no junk filter) and
    lies in [0.0, 1.0]; 1.0 means the sequences are identical.
    """
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()


class ComplianceCheck(ABC):
    """Abstract base class for one model-card compliance check.

    A subclass passes a human-readable ``name`` to this constructor and
    implements :meth:`check`. ``check_compliance`` uses ``name`` as the key
    under which the check's result is reported.
    """

    def __init__(self, name):
        # Label under which this check's result is reported.
        self.name = name

    @abstractmethod
    def check(self, card: BeautifulSoup) -> bool:
        """Evaluate this check against *card*, the parsed HTML of a model card.

        Subclasses must override this; the base implementation only raises.
        """
        raise NotImplementedError


class ModelProviderIdentityCheck(ComplianceCheck):
    """Check that the model card states who developed the model.

    The lookup assumes the card renders the entry as an element whose first
    child is a label containing "Developed by" and whose second child is the
    value, e.g. ``<li><strong>Developed by:</strong> value</li>`` — TODO
    confirm against the model card template in use.
    """

    def __init__(self):
        super().__init__("Identity and Contact Details")

    def check(self, card: BeautifulSoup) -> bool:
        """Return True when the card's "Developed by" value is filled in.

        Returns False when no "Developed by" text exists, when nothing follows
        the label, when the value is blank, or when it is still the template
        placeholder "[More Information Needed]".
        """
        # ``find_all(string=...)`` replaces the deprecated ``findAll(text=...)``.
        matches = card.find_all(string=re.compile("Developed by"))
        if not matches:
            # Entry missing entirely; indexing [0] here used to raise IndexError.
            return False
        # The matched text lives inside the label element; the label's parent
        # is the entry that holds the label followed by the value.
        entry = matches[0].parent.parent
        if entry is None:
            return False
        parts = list(entry.children)
        if len(parts) < 2:
            # Label present but no value node after it.
            return False
        developed_by = parts[1].text.strip()
        return bool(developed_by) and developed_by != "[More Information Needed]"


class IntendedPurposeCheck(ComplianceCheck):
    """Placeholder check for the model card's "Intended Purpose" information.

    Not implemented yet: :meth:`check` always reports the check as failing,
    whatever the card contains.
    """

    def __init__(self):
        super().__init__("Intended Purpose")

    def check(self, card: BeautifulSoup):
        # TODO: inspect the card's "Direct Use" section and pass only when it
        # holds real content instead of "[More Information Needed]".
        return False


# Checks run by ``run_compliance_check``, in the order they appear in the report.
# TODO: add ComplianceCheck subclasses for "General Limitations",
# "Computational and Hardware Requirements" and "Carbon Emissions".
compliance_checks = [
    ModelProviderIdentityCheck(),
    IntendedPurposeCheck(),
]


def parse_webhook_post(data: Dict[str, Any]) -> Optional[str]:
    """Extract the model repository name from a webhook payload.

    Args:
        data: Decoded webhook body; reads ``data["event"]["scope"]`` and, for
            repo-scoped events, ``data["repo"]["name"]`` / ``["type"]``.

    Returns:
        The repository name, or None when the event scope is not exactly
        ``"repo"``.

    Raises:
        ValueError: If the event concerns a repository that is not a model.
        KeyError: If an expected key is missing from the payload.
    """
    if data["event"]["scope"] != "repo":
        return None
    repo_info = data["repo"]
    name, kind = repo_info["name"], repo_info["type"]
    if kind == "model":
        return name
    raise ValueError("Incorrect repo type.")


def check_compliance(comp_checks: List[ComplianceCheck], card: BeautifulSoup) -> Dict[str, bool]:
    """Run each check in *comp_checks* against *card*.

    Returns:
        A dict mapping every check's ``name`` to its result, in the order the
        checks were supplied. A later check with a duplicate name overwrites
        an earlier one.
    """
    results: Dict[str, bool] = {}
    for compliance_check in comp_checks:
        label = compliance_check.name
        results[label] = compliance_check.check(card)
    return results


def run_compliance_check(repo_name):
    """Load *repo_name*'s model card from the Hub and run every compliance check.

    The card's Markdown body is converted to HTML and parsed with
    BeautifulSoup's ``html.parser``. The parsed card is then passed to
    ``check_compliance`` together with the module-level ``compliance_checks``.

    Returns:
        A dict mapping each check's name to its result.
    """
    model_card = ModelCard.load(repo_id_or_path=repo_name)
    rendered = markdown.markdown(model_card.content)
    return check_compliance(
        compliance_checks,
        BeautifulSoup(rendered, features="html.parser"),
    )


def create_metadata_breakdown_table(compliance_check_dictionary):
    """Render compliance results as a GitHub-flavoured Markdown table.

    Each ``(key, value)`` pair of *compliance_check_dictionary* becomes one
    row under the headers "Compliance Check" and "Present".
    """
    rows = list(compliance_check_dictionary.items())
    return tabulate(rows, tablefmt="github", headers=("Compliance Check", "Present"))


def create_markdown_report(
    desired_metadata_dictionary, repo_name, update: bool = False
) -> str:
    """Build the Markdown text of the compliance report.

    Args:
        desired_metadata_dictionary: Mapping rendered into a table by
            ``create_metadata_breakdown_table`` (keys fill the "Compliance
            Check" column, values the "Present" column).
        repo_name: Repository id interpolated into the report text.
        update: When True, "(updated)" is appended to the report title.

    Returns:
        The report as a Markdown string.
    """
    # NOTE: reports are compared with previously posted comments via
    # ``similar`` before posting, so the exact text and whitespace of this
    # literal affect whether a new comment is made.
    report = f"""# Model Card Regulatory Compliance report card {"(updated)" if update else ""}
    \n
This is an automatically produced model card regulatory compliance report card for {repo_name}.
This report is meant as a POC!
    \n 
## Breakdown of metadata fields for your model
\n
{create_metadata_breakdown_table(desired_metadata_dictionary)}
\n
    """
    return report


def create_or_update_report(compliance_check, repo_name):
    """Open a compliance report discussion on *repo_name*, or update the open one.

    The function looks for an open discussion with the title this function
    creates. If it finds one, it posts the "(updated)" report as a new comment,
    but only when the latest comment matches neither variant of the current
    report. If no such discussion exists, it opens a new discussion with the
    report as its description.

    Args:
        compliance_check: Results mapping passed to ``create_markdown_report``.
        repo_name: Hub model repository id.

    Returns:
        True once the discussion has been created, updated, or found current.
    """
    # The same title must be used for lookup and creation. Otherwise the bot
    # never finds its own thread and opens a new discussion on every call.
    title = "Model Card Regulatory Compliance Report Card"
    report = create_markdown_report(compliance_check, repo_name, update=False)
    updated_report = create_markdown_report(compliance_check, repo_name, update=True)
    repo_discussions = get_repo_discussions(repo_name, repo_type="model")
    for discussion in repo_discussions:
        if discussion.title != title or discussion.status != "open":
            continue
        discussion_details = get_discussion_details(
            repo_name, discussion.num, repo_type="model"
        )
        # ``events`` can also hold status changes, title changes and commits,
        # which have no ``content``; only compare against real comments.
        comments = [
            event for event in discussion_details.events if event.type == "comment"
        ]
        last_comment = comments[-1].content if comments else ""
        # The last comment may be the original report or an "(updated)" one.
        # Compare against both so an unchanged card is not re-posted forever.
        best_match = max(
            similar(report, last_comment), similar(updated_report, last_comment)
        )
        if best_match <= 0.999:
            comment_discussion(
                repo_name,
                discussion.num,
                comment=updated_report,
                repo_type="model",
            )
        return True
    create_discussion(
        repo_name,
        title,
        description=report,
        repo_type="model",
    )
    return True