|
import streamlit as st |
|
from persist import persist, load_widget_state |
|
from modelcards import CardData, ModelCard |
|
from huggingface_hub import create_repo |
|
|
|
|
|
def is_float(value):
    """Return True if ``float(value)`` succeeds, False otherwise.

    Args:
        value: Any object; typically the raw text of an input widget.

    Returns:
        bool: True when ``value`` can be converted with ``float()``.
    """
    try:
        float(value)
    # Only conversion failures mean "not a float": TypeError for unsupported
    # types such as None, ValueError for unparsable text, OverflowError for
    # ints too large for a float. A bare ``except`` would also swallow
    # KeyboardInterrupt and SystemExit.
    except (TypeError, ValueError, OverflowError):
        return False
    return True
|
|
|
def get_card():
    """Assemble a ``ModelCard`` from the values held in ``st.session_state``.

    Languages and license are mandatory: when either is empty, an error
    listing the missing fields is shown with ``st.error`` and ``st.stop()``
    is called. Every other empty field is forwarded as ``None``. The tag
    ``"autogenerated-modelcard"`` is always appended to the user's tags.

    Returns:
        The card produced by ``ModelCard.from_template`` with ``template.md``.
    """
    state = st.session_state

    def split_csv(raw):
        # Comma-separated text -> list of stripped, non-empty entries.
        return [piece.strip() for piece in raw.split(',') if piece.strip()]

    # --- Card metadata ---
    languages = state.languages or None
    license_id = state.license or None
    library_name = state.library_name or None
    tags = split_csv(state.tags) + ["autogenerated-modelcard"]
    datasets = split_csv(state.datasets) or None
    metrics = state.metrics or None
    model_name = state.model_name or None
    model_description = state.model_description or None

    # --- Free-text template fields ---
    authors = state.authors or None
    paper_url = state.paper_url or None
    github_url = state.github_url or None
    bibtex_citations = state.bibtex_citations or None
    raw_emissions = state.emissions
    emissions = float(raw_emissions) if is_float(raw_emissions) else None

    # --- Required-field check ---
    missing = []
    if not languages:
        missing.append("\n- Languages")
    if not license_id:
        missing.append("\n- License")
    if missing:
        st.error(
            "Warning: The following fields are required but have not been filled in: "
            + "".join(missing)
        )
        st.stop()

    card_data = CardData(
        language=languages,
        license=license_id,
        library_name=library_name,
        tags=tags,
        datasets=datasets,
        metrics=metrics,
    )
    # Truthiness test: a missing value and 0.0 both leave the field unset.
    if emissions:
        card_data.co2_eq_emissions = {'emissions': emissions}

    template_fields = {
        'model_id': model_name,
        'model_description': model_description,
        'license': license_id,
        'authors': authors,
        'paper_url': paper_url,
        'github_url': github_url,
        'bibtex_citations': bibtex_citations,
        'emissions': emissions,
    }
    return ModelCard.from_template(
        card_data,
        template_path='template.md',
        **template_fields,
    )
|
|
|
|
|
def main():
    """Render the generated model card and offer a sidebar form to upload it.

    The card built by ``get_card()`` is saved to ``current_card.md`` and shown
    either as raw text or as rendered markdown, depending on the "View Raw"
    sidebar checkbox. When the sidebar form is submitted with a repo ID of the
    form ``username/repo-name``, ``create_repo`` is called with
    ``exist_ok=True`` and the card is pushed to that repo; otherwise an error
    message is shown.
    """
    card = get_card()
    card.save('current_card.md')

    view_raw = st.sidebar.checkbox("View Raw")
    if view_raw:
        st.text(card)
    else:
        st.markdown(card.text, unsafe_allow_html=True)

    # U+1F917 (hugging face emoji) written as an escape so the label cannot be
    # corrupted again by a wrong file encoding.
    upload_label = "Upload to \U0001F917 Hub"

    with st.sidebar:
        with st.form(upload_label):
            st.markdown("Use a token with write access from [here](https://hf.co/settings/tokens)")
            token = st.text_input("Token", type='password')
            repo_id = st.text_input("Repo ID")
            submit = st.form_submit_button(upload_label)

    if submit:
        repo_id = repo_id.strip()
        parts = repo_id.split('/')
        # Exactly two non-empty parts: rejects "repo", "a/b/c", "/repo" and "user/".
        if len(parts) == 2 and all(parts):
            repo_url = create_repo(repo_id, exist_ok=True, token=token)
            card.push_to_hub(repo_id, token=token)
            # The closing ")" must directly follow the URL for a valid markdown link.
            st.success(f"Pushed the card to the repo [here]({repo_url})!")
        else:
            st.error("Repo ID invalid. It should be username/repo-name. For example: nateraw/food")
|
|
|
|
|
# Script entry point: runs only when this file is executed directly.
if __name__ == "__main__":
    # NOTE(review): presumably restores persisted widget values before the
    # page is drawn; confirm against the `persist` module.
    load_widget_state()
    main()