Model_Cards_Writing_Tool / extract_code.py
Ezi Ozoani
live push
2d5ffb9
raw history blame
No virus
14.6 kB
#!/usr/bin/env python3
"""Extract Python usage snippets from an embedded copy of ``Libraries.ts``.

The TypeScript source is stored in the module-level ``file`` string (nothing
is read from disk). ``read_file`` locates a snippet by its ``const <name>``
declaration and substitutes the literal ``${model.id}`` placeholder with a
model name.
"""
import re

# Placeholder only: rebound to the embedded TypeScript source later in this
# module. read_file looks the global up at call time, so it sees that value.
file = None
def read_file(library: str, model_name: str, source: "str | None" = None) -> str:
    """Return the snippet template declared as ``const <library>``, model id filled in.

    The first occurrence of ``const <library>`` in the source is located. The
    text between the first backtick after it and the first "`;" after it is
    returned, with every literal ``${model.id}`` replaced by ``model_name``.
    Other template interpolations are left untouched.

    The lookup is a plain prefix match, so ``"allennlp"`` also matches
    ``const allennlpUnknown``. Whichever declaration comes first in the source
    wins.

    Args:
        library: Identifier of the TypeScript ``const`` holding the snippet
            (e.g. ``"keras"``, ``"fastai"``). This is the name used in the
            source, which is not always the ``ModelLibrary`` key
            (``adapter-transformers`` vs ``adapter_transformers``).
        model_name: Text substituted for ``${model.id}``.
        source: Text to search. Defaults to the module-level ``file`` string.

    Returns:
        The extracted snippet text.

    Raises:
        ValueError: If no ``const <library>`` declaration exists, or if no
            backtick-delimited template follows it.
    """
    text = file if source is None else source
    # Literal search: library names are not regex patterns, so characters
    # such as "." must not act as wildcards.
    start = text.find('const ' + library)
    if start == -1:
        raise ValueError("no snippet named {!r} found".format(library))
    tail = text[start:]
    # The first backtick opens the template literal; the first "`;" closes it.
    # str.index raises ValueError if either delimiter is missing.
    body = tail[tail.index('`') + 1:tail.index('`;')]
    return body.replace('${model.id}', model_name)
# TypeScript source of the model-library definitions (an embedded copy of the
# Libraries.ts file): the ModelLibrary enum, per-library snippet builders and
# UI metadata. It is kept here purely as data for read_file().
# read_file() finds the first "const <name>" in this text and takes the
# template literal between the next backtick and the next "`;". It then
# replaces only the literal "${model.id}" placeholder; other interpolations
# (e.g. ${nameWithoutNamespace(model.id)}, ${model.private ? ...}) come out
# unchanged.
# NOTE(review): this is a plain (non-raw) Python string, so TypeScript escapes
# such as the "\n" in .join("\n") become real newlines in this text.
# NOTE(review): the adapter_transformers template contains
# "${model.config?.adapter_transformers?.{model.id}}". That is not valid
# TypeScript, and the ${model.id} substitution does not touch it. Confirm the
# intended placeholder against the upstream file.
file = """
import type { ModelData } from "./Types";
/**
* Add your new library here.
*/
export enum ModelLibrary {
"adapter-transformers" = "Adapter Transformers",
"allennlp" = "allenNLP",
"asteroid" = "Asteroid",
"diffusers" = "Diffusers",
"espnet" = "ESPnet",
"fairseq" = "Fairseq",
"flair" = "Flair",
"keras" = "Keras",
"nemo" = "NeMo",
"pyannote-audio" = "pyannote.audio",
"sentence-transformers" = "Sentence Transformers",
"sklearn" = "Scikit-learn",
"spacy" = "spaCy",
"speechbrain" = "speechbrain",
"tensorflowtts" = "TensorFlowTTS",
"timm" = "Timm",
"fastai" = "fastai",
"transformers" = "Transformers",
"stanza" = "Stanza",
"fasttext" = "fastText",
"stable-baselines3" = "Stable-Baselines3",
"ml-agents" = "ML-Agents",
}
export const ALL_MODEL_LIBRARY_KEYS = Object.keys(ModelLibrary) as (keyof typeof ModelLibrary)[];
/**
* Elements configurable by a model library.
*/
export interface LibraryUiElement {
/**
* Name displayed on the main
* call-to-action button on the model page.
*/
btnLabel: string;
/**
* Repo name
*/
repoName: string;
/**
* URL to library's repo
*/
repoUrl: string;
/**
* Code snippet displayed on model page
*/
snippet: (model: ModelData) => string;
}
function nameWithoutNamespace(modelId: string): string {
const splitted = modelId.split("/");
return splitted.length === 1 ? splitted[0] : splitted[1];
}
//#region snippets
const adapter_transformers = (model: ModelData) =>
`from transformers import ${model.config?.adapter_transformers?.model_class}
model = ${model.config?.adapter_transformers?.model_class}.from_pretrained("${model.config?.adapter_transformers?.{model.id}}")
model.load_adapter("${model.id}", source="hf")`;
const allennlpUnknown = (model: ModelData) =>
`import allennlp_models
from allennlp.predictors.predictor import Predictor
predictor = Predictor.from_path("hf://${model.id}")`;
const allennlpQuestionAnswering = (model: ModelData) =>
`import allennlp_models
from allennlp.predictors.predictor import Predictor
predictor = Predictor.from_path("hf://${model.id}")
predictor_input = {"passage": "My name is Wolfgang and I live in Berlin", "question": "Where do I live?"}
predictions = predictor.predict_json(predictor_input)`;
const allennlp = (model: ModelData) => {
if (model.tags?.includes("question-answering")) {
return allennlpQuestionAnswering(model);
}
return allennlpUnknown(model);
};
const asteroid = (model: ModelData) =>
`from asteroid.models import BaseModel
model = BaseModel.from_pretrained("${model.id}")`;
const diffusers = (model: ModelData) =>
`from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`;
const espnetTTS = (model: ModelData) =>
`from espnet2.bin.tts_inference import Text2Speech
model = Text2Speech.from_pretrained("${model.id}")
speech, *_ = model("text to generate speech from")`;
const espnetASR = (model: ModelData) =>
`from espnet2.bin.asr_inference import Speech2Text
model = Speech2Text.from_pretrained(
"${model.id}"
)
speech, rate = soundfile.read("speech.wav")
text, *_ = model(speech)`;
const espnetUnknown = () =>
`unknown model type (must be text-to-speech or automatic-speech-recognition)`;
const espnet = (model: ModelData) => {
if (model.tags?.includes("text-to-speech")) {
return espnetTTS(model);
} else if (model.tags?.includes("automatic-speech-recognition")) {
return espnetASR(model);
}
return espnetUnknown();
};
const fairseq = (model: ModelData) =>
`from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
"${model.id}"
)`;
const flair = (model: ModelData) =>
`from flair.models import SequenceTagger
tagger = SequenceTagger.load("${model.id}")`;
const keras = (model: ModelData) =>
`from huggingface_hub import from_pretrained_keras
model = from_pretrained_keras("${model.id}")
`;
const pyannote_audio_pipeline = (model: ModelData) =>
`from pyannote.audio import Pipeline
pipeline = Pipeline.from_pretrained("${model.id}")
# inference on the whole file
pipeline("file.wav")
# inference on an excerpt
from pyannote.core import Segment
excerpt = Segment(start=2.0, end=5.0)
from pyannote.audio import Audio
waveform, sample_rate = Audio().crop("file.wav", excerpt)
pipeline({"waveform": waveform, "sample_rate": sample_rate})`;
const pyannote_audio_model = (model: ModelData) =>
`from pyannote.audio import Model, Inference
model = Model.from_pretrained("${model.id}")
inference = Inference(model)
# inference on the whole file
inference("file.wav")
# inference on an excerpt
from pyannote.core import Segment
excerpt = Segment(start=2.0, end=5.0)
inference.crop("file.wav", excerpt)`;
const pyannote_audio = (model: ModelData) => {
if (model.tags?.includes("pyannote-audio-pipeline")) {
return pyannote_audio_pipeline(model);
}
return pyannote_audio_model(model);
};
const tensorflowttsTextToMel = (model: ModelData) =>
`from tensorflow_tts.inference import AutoProcessor, TFAutoModel
processor = AutoProcessor.from_pretrained("${model.id}")
model = TFAutoModel.from_pretrained("${model.id}")
`;
const tensorflowttsMelToWav = (model: ModelData) =>
`from tensorflow_tts.inference import TFAutoModel
model = TFAutoModel.from_pretrained("${model.id}")
audios = model.inference(mels)
`;
const tensorflowttsUnknown = (model: ModelData) =>
`from tensorflow_tts.inference import TFAutoModel
model = TFAutoModel.from_pretrained("${model.id}")
`;
const tensorflowtts = (model: ModelData) => {
if (model.tags?.includes("text-to-mel")) {
return tensorflowttsTextToMel(model);
} else if (model.tags?.includes("mel-to-wav")) {
return tensorflowttsMelToWav(model);
}
return tensorflowttsUnknown(model);
};
const timm = (model: ModelData) =>
`import timm
model = timm.create_model("hf_hub:${model.id}", pretrained=True)`;
const sklearn = (model: ModelData) =>
`from huggingface_hub import hf_hub_download
import joblib
model = joblib.load(
hf_hub_download("${model.id}", "sklearn_model.joblib")
)`;
const fastai = (model: ModelData) =>
`from huggingface_hub import from_pretrained_fastai
learn = from_pretrained_fastai("${model.id}")`;
const sentenceTransformers = (model: ModelData) =>
`from sentence_transformers import SentenceTransformer
model = SentenceTransformer("${model.id}")`;
const spacy = (model: ModelData) =>
`!pip install https://huggingface.co/${model.id}/resolve/main/${nameWithoutNamespace(model.id)}-any-py3-none-any.whl
# Using spacy.load().
import spacy
nlp = spacy.load("${nameWithoutNamespace(model.id)}")
# Importing as module.
import ${nameWithoutNamespace(model.id)}
nlp = ${nameWithoutNamespace(model.id)}.load()`;
const stanza = (model: ModelData) =>
`import stanza
stanza.download("${nameWithoutNamespace(model.id).replace("stanza-", "")}")
nlp = stanza.Pipeline("${nameWithoutNamespace(model.id).replace("stanza-", "")}")`;
const speechBrainMethod = (speechbrainInterface: string) => {
switch (speechbrainInterface) {
case "EncoderClassifier":
return "classify_file";
case "EncoderDecoderASR":
case "EncoderASR":
return "transcribe_file";
case "SpectralMaskEnhancement":
return "enhance_file";
case "SepformerSeparation":
return "separate_file";
default:
return undefined;
}
};
const speechbrain = (model: ModelData) => {
const speechbrainInterface = model.config?.speechbrain?.interface;
if (speechbrainInterface === undefined) {
return `# interface not specified in config.json`;
}
const speechbrainMethod = speechBrainMethod(speechbrainInterface);
if (speechbrainMethod === undefined) {
return `# interface in config.json invalid`;
}
return `from speechbrain.pretrained import ${speechbrainInterface}
model = ${speechbrainInterface}.from_hparams(
"${model.id}"
)
model.${speechbrainMethod}("file.wav")`;
};
const transformers = (model: ModelData) => {
const info = model.transformersInfo;
if (!info) {
return `# ⚠️ Type of model unknown`;
}
if (info.processor) {
const varName = info.processor === "AutoTokenizer" ? "tokenizer"
: info.processor === "AutoFeatureExtractor" ? "extractor"
: "processor"
;
return [
`from transformers import ${info.processor}, ${info.auto_model}`,
"",
`${varName} = ${info.processor}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
"",
`model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
].join("\n");
} else {
return [
`from transformers import ${info.auto_model}`,
"",
`model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`,
].join("\n");
}
};
const fasttext = (model: ModelData) =>
`from huggingface_hub import hf_hub_download
import fasttext
model = fasttext.load_model(hf_hub_download("${model.id}", "model.bin"))`;
const stableBaselines3 = (model: ModelData) =>
`from huggingface_sb3 import load_from_hub
checkpoint = load_from_hub(
repo_id="${model.id}",
filename="{MODEL FILENAME}.zip",
)`;
const nemoDomainResolver = (domain: string, model: ModelData): string | undefined => {
const modelName = `${nameWithoutNamespace(model.id)}.nemo`;
switch (domain) {
case "ASR":
return `import nemo.collections.asr as nemo_asr
asr_model = nemo_asr.models.ASRModel.from_pretrained("${model.id}")
transcriptions = asr_model.transcribe(["file.wav"])`;
default:
return undefined;
}
};
const mlAgents = (model: ModelData) =>
`mlagents-load-from-hf --repo-id="${model.id}" --local-dir="./downloads"`;
const nemo = (model: ModelData) => {
let command: string | undefined = undefined;
// Resolve the tag to a nemo domain/sub-domain
if (model.tags?.includes("automatic-speech-recognition")) {
command = nemoDomainResolver("ASR", model);
}
return command ?? `# tag did not correspond to a valid NeMo domain.`;
};
//#endregion
export const MODEL_LIBRARIES_UI_ELEMENTS: { [key in keyof typeof ModelLibrary]?: LibraryUiElement } = {
// ^^ TODO(remove the optional ? marker when Stanza snippet is available)
"adapter-transformers": {
btnLabel: "Adapter Transformers",
repoName: "adapter-transformers",
repoUrl: "https://github.com/Adapter-Hub/adapter-transformers",
snippet: adapter_transformers,
},
"allennlp": {
btnLabel: "AllenNLP",
repoName: "AllenNLP",
repoUrl: "https://github.com/allenai/allennlp",
snippet: allennlp,
},
"asteroid": {
btnLabel: "Asteroid",
repoName: "Asteroid",
repoUrl: "https://github.com/asteroid-team/asteroid",
snippet: asteroid,
},
"diffusers": {
btnLabel: "Diffusers",
repoName: "🤗/diffusers",
repoUrl: "https://github.com/huggingface/diffusers",
snippet: diffusers,
},
"espnet": {
btnLabel: "ESPnet",
repoName: "ESPnet",
repoUrl: "https://github.com/espnet/espnet",
snippet: espnet,
},
"fairseq": {
btnLabel: "Fairseq",
repoName: "fairseq",
repoUrl: "https://github.com/pytorch/fairseq",
snippet: fairseq,
},
"flair": {
btnLabel: "Flair",
repoName: "Flair",
repoUrl: "https://github.com/flairNLP/flair",
snippet: flair,
},
"keras": {
btnLabel: "Keras",
repoName: "Keras",
repoUrl: "https://github.com/keras-team/keras",
snippet: keras,
},
"nemo": {
btnLabel: "NeMo",
repoName: "NeMo",
repoUrl: "https://github.com/NVIDIA/NeMo",
snippet: nemo,
},
"pyannote-audio": {
btnLabel: "pyannote.audio",
repoName: "pyannote-audio",
repoUrl: "https://github.com/pyannote/pyannote-audio",
snippet: pyannote_audio,
},
"sentence-transformers": {
btnLabel: "sentence-transformers",
repoName: "sentence-transformers",
repoUrl: "https://github.com/UKPLab/sentence-transformers",
snippet: sentenceTransformers,
},
"sklearn": {
btnLabel: "Scikit-learn",
repoName: "Scikit-learn",
repoUrl: "https://github.com/scikit-learn/scikit-learn",
snippet: sklearn,
},
"fastai": {
btnLabel: "fastai",
repoName: "fastai",
repoUrl: "https://github.com/fastai/fastai",
snippet: fastai,
},
"spacy": {
btnLabel: "spaCy",
repoName: "spaCy",
repoUrl: "https://github.com/explosion/spaCy",
snippet: spacy,
},
"speechbrain": {
btnLabel: "speechbrain",
repoName: "speechbrain",
repoUrl: "https://github.com/speechbrain/speechbrain",
snippet: speechbrain,
},
"stanza": {
btnLabel: "Stanza",
repoName: "stanza",
repoUrl: "https://github.com/stanfordnlp/stanza",
snippet: stanza,
},
"tensorflowtts": {
btnLabel: "TensorFlowTTS",
repoName: "TensorFlowTTS",
repoUrl: "https://github.com/TensorSpeech/TensorFlowTTS",
snippet: tensorflowtts,
},
"timm": {
btnLabel: "timm",
repoName: "pytorch-image-models",
repoUrl: "https://github.com/rwightman/pytorch-image-models",
snippet: timm,
},
"transformers": {
btnLabel: "Transformers",
repoName: "🤗/transformers",
repoUrl: "https://github.com/huggingface/transformers",
snippet: transformers,
},
"fasttext": {
btnLabel: "fastText",
repoName: "fastText",
repoUrl: "https://fasttext.cc/",
snippet: fasttext,
},
"stable-baselines3": {
btnLabel: "stable-baselines3",
repoName: "stable-baselines3",
repoUrl: "https://github.com/huggingface/huggingface_sb3",
snippet: stableBaselines3,
},
"ml-agents": {
btnLabel: "ml-agents",
repoName: "ml-agents",
repoUrl: "https://github.com/huggingface/ml-agents",
snippet: mlAgents,
},
} as const;
"""
if __name__ == '__main__':
    import sys

    # Usage: extract_code.py [LIBRARY MODEL_NAME]
    # LIBRARY is the TypeScript const name of the snippet (e.g. "keras").
    # With no arguments, print the original demo: the "keras" snippet for
    # model "Distillgpt2".
    args = sys.argv[1:]
    if not args:
        library_name, model_name = "keras", "Distillgpt2"
    elif len(args) == 2:
        library_name, model_name = args
    else:
        # Previously a single argument was silently ignored; report it instead.
        sys.exit("usage: extract_code.py [LIBRARY MODEL_NAME]")
    print(read_file(library_name, model_name))