#!/usr/bin/env python3
"""Extract Python usage snippets from an embedded copy of ``Libraries.ts``.

The TypeScript source is stored verbatim in the module-level string ``file``;
nothing is read from disk. ``read_file`` locates a snippet template such as
``const keras = (model: ModelData) => `...`;`` and fills in the model id.

Usage::

    python3 <this script> [LIBRARY MODEL_ID]

With no arguments, the ``keras`` snippet for ``Distillgpt2`` is printed.
"""
import re
import sys
from typing import Optional

# Placeholder produced in the templates by the TypeScript `nameWithoutNamespace`
# helper (defined inside `file`).
_NAME_WITHOUT_NAMESPACE = "${nameWithoutNamespace(model.id)}"


def _name_without_namespace(model_id: str) -> str:
    """Python port of the TypeScript ``nameWithoutNamespace`` helper in ``file``.

    Returns ``model_id`` unchanged when it contains no ``/``; otherwise returns
    the second ``/``-separated segment (``"org/model"`` -> ``"model"``).
    """
    parts = model_id.split("/")
    return parts[0] if len(parts) == 1 else parts[1]


def _snippet_identifier(library: str, text: str) -> str:
    """Map a ``MODEL_LIBRARIES_UI_ELEMENTS`` key to its ``snippet:`` function name.

    For example ``"sentence-transformers"`` -> ``"sentenceTransformers"``.
    Any value that is not such a key (typically a TypeScript identifier like
    ``"espnetASR"``) is returned unchanged.
    """
    entry = re.search(
        r'"' + re.escape(library) + r'"\s*:\s*\{[^}]*?\bsnippet:\s*([A-Za-z_$][\w$]*)',
        text,
    )
    return entry.group(1) if entry else library


def read_file(library: str, model_name: str, source: Optional[str] = None) -> str:
    """Return the snippet template for ``library`` with ``model_name`` filled in.

    Args:
        library: Either a key of ``MODEL_LIBRARIES_UI_ELEMENTS`` (e.g. ``"keras"``,
            ``"sentence-transformers"``), which is resolved to its ``snippet``
            function name, or a TypeScript identifier. The identifier is matched
            as a prefix, and the first ``const`` whose body is directly a template
            literal wins. So ``"allennlp"`` yields ``allennlpUnknown``, and
            ``"espnetASR"`` can be requested explicitly.
        model_name: Substituted for ``${model.id}``. Its namespace-stripped form
            is substituted for ``${nameWithoutNamespace(model.id)}``. Other
            ``${...}`` expressions are left as-is.
        source: TypeScript text to search. Defaults to the module-level ``file``.

    Returns:
        The text between the opening backtick and the next ``\\`;``, with the
        substitutions above applied.

    Raises:
        ValueError: if ``library`` is empty, if no matching template-literal
            snippet exists (dispatcher functions such as ``transformers`` build
            their snippet at runtime and cannot be extracted), or if the
            template is never terminated.
    """
    if not library:
        raise ValueError("library name must be non-empty")
    text = file if source is None else source
    identifier = _snippet_identifier(library, text)

    # Only accept `const <identifier...> = (<params>) => \`` so the backtick we
    # start from belongs to this const, not to some later definition.
    match = re.search(
        r"\bconst\s+" + re.escape(identifier) + r"\w*\s*=\s*\([^)]*\)\s*=>\s*`",
        text,
    )
    if match is None:
        raise ValueError(
            f"no template-literal snippet found for {library!r} "
            f"(looked for `const {identifier}... = (...) => `...`;`)"
        )

    start = match.end()
    end = text.find("`;", start)
    if end == -1:
        raise ValueError(f"unterminated template literal for {library!r}")

    snippet = text[start:end]
    snippet = snippet.replace(_NAME_WITHOUT_NAMESPACE, _name_without_namespace(model_name))
    return snippet.replace("${model.id}", model_name)


# Verbatim copy of Libraries.ts. Keep this byte-identical: read_file() parses it.
# NOTE(review): most original newlines appear to have been collapsed to single
# spaces, so extracted snippets come out on one line.
# NOTE(review): this is a non-raw string, so each `.join("\n")` in the
# `transformers` function holds a real newline here. read_file() never extracts
# that function, so this should not affect results.
# NOTE(review): the adapter_transformers template contains
# `${model.config?.adapter_transformers?.{model.id}}`, which is not valid
# TypeScript. It is possibly a mangled field name (e.g. `model_name`); confirm
# before relying on that snippet.
file = """ import type { ModelData } from "./Types"; /** * Add your new library here. */ export enum ModelLibrary { "adapter-transformers" = "Adapter Transformers", "allennlp" = "allenNLP", "asteroid" = "Asteroid", "diffusers" = "Diffusers", "espnet" = "ESPnet", "fairseq" = "Fairseq", "flair" = "Flair", "keras" = "Keras", "nemo" = "NeMo", "pyannote-audio" = "pyannote.audio", "sentence-transformers" = "Sentence Transformers", "sklearn" = "Scikit-learn", "spacy" = "spaCy", "speechbrain" = "speechbrain", "tensorflowtts" = "TensorFlowTTS", "timm" = "Timm", "fastai" = "fastai", "transformers" = "Transformers", "stanza" = "Stanza", "fasttext" = "fastText", "stable-baselines3" = "Stable-Baselines3", "ml-agents" = "ML-Agents", } export const ALL_MODEL_LIBRARY_KEYS = Object.keys(ModelLibrary) as (keyof typeof ModelLibrary)[]; /** * Elements configurable by a model library. */ export interface LibraryUiElement { /** * Name displayed on the main * call-to-action button on the model page. */ btnLabel: string; /** * Repo name */ repoName: string; /** * URL to library's repo */ repoUrl: string; /** * Code snippet displayed on model page */ snippet: (model: ModelData) => string; } function nameWithoutNamespace(modelId: string): string { const splitted = modelId.split("/"); return splitted.length === 1 ?
splitted[0] : splitted[1]; } //#region snippets const adapter_transformers = (model: ModelData) => `from transformers import ${model.config?.adapter_transformers?.model_class} model = ${model.config?.adapter_transformers?.model_class}.from_pretrained("${model.config?.adapter_transformers?.{model.id}}") model.load_adapter("${model.id}", source="hf")`; const allennlpUnknown = (model: ModelData) => `import allennlp_models from allennlp.predictors.predictor import Predictor predictor = Predictor.from_path("hf://${model.id}")`; const allennlpQuestionAnswering = (model: ModelData) => `import allennlp_models from allennlp.predictors.predictor import Predictor predictor = Predictor.from_path("hf://${model.id}") predictor_input = {"passage": "My name is Wolfgang and I live in Berlin", "question": "Where do I live?"} predictions = predictor.predict_json(predictor_input)`; const allennlp = (model: ModelData) => { if (model.tags?.includes("question-answering")) { return allennlpQuestionAnswering(model); } return allennlpUnknown(model); }; const asteroid = (model: ModelData) => `from asteroid.models import BaseModel model = BaseModel.from_pretrained("${model.id}")`; const diffusers = (model: ModelData) => `from diffusers import DiffusionPipeline pipeline = DiffusionPipeline.from_pretrained("${model.id}"${model.private ?
", use_auth_token=True" : ""})`; const espnetTTS = (model: ModelData) => `from espnet2.bin.tts_inference import Text2Speech model = Text2Speech.from_pretrained("${model.id}") speech, *_ = model("text to generate speech from")`; const espnetASR = (model: ModelData) => `from espnet2.bin.asr_inference import Speech2Text model = Speech2Text.from_pretrained( "${model.id}" ) speech, rate = soundfile.read("speech.wav") text, *_ = model(speech)`; const espnetUnknown = () => `unknown model type (must be text-to-speech or automatic-speech-recognition)`; const espnet = (model: ModelData) => { if (model.tags?.includes("text-to-speech")) { return espnetTTS(model); } else if (model.tags?.includes("automatic-speech-recognition")) { return espnetASR(model); } return espnetUnknown(); }; const fairseq = (model: ModelData) => `from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub models, cfg, task = load_model_ensemble_and_task_from_hf_hub( "${model.id}" )`; const flair = (model: ModelData) => `from flair.models import SequenceTagger tagger = SequenceTagger.load("${model.id}")`; const keras = (model: ModelData) => `from huggingface_hub import from_pretrained_keras model = from_pretrained_keras("${model.id}") `; const pyannote_audio_pipeline = (model: ModelData) => `from pyannote.audio import Pipeline pipeline = Pipeline.from_pretrained("${model.id}") # inference on the whole file pipeline("file.wav") # inference on an excerpt from pyannote.core import Segment excerpt = Segment(start=2.0, end=5.0) from pyannote.audio import Audio waveform, sample_rate = Audio().crop("file.wav", excerpt) pipeline({"waveform": waveform, "sample_rate": sample_rate})`; const pyannote_audio_model = (model: ModelData) => `from pyannote.audio import Model, Inference model = Model.from_pretrained("${model.id}") inference = Inference(model) # inference on the whole file inference("file.wav") # inference on an excerpt from pyannote.core import Segment excerpt = Segment(start=2.0,
end=5.0) inference.crop("file.wav", excerpt)`; const pyannote_audio = (model: ModelData) => { if (model.tags?.includes("pyannote-audio-pipeline")) { return pyannote_audio_pipeline(model); } return pyannote_audio_model(model); }; const tensorflowttsTextToMel = (model: ModelData) => `from tensorflow_tts.inference import AutoProcessor, TFAutoModel processor = AutoProcessor.from_pretrained("${model.id}") model = TFAutoModel.from_pretrained("${model.id}") `; const tensorflowttsMelToWav = (model: ModelData) => `from tensorflow_tts.inference import TFAutoModel model = TFAutoModel.from_pretrained("${model.id}") audios = model.inference(mels) `; const tensorflowttsUnknown = (model: ModelData) => `from tensorflow_tts.inference import TFAutoModel model = TFAutoModel.from_pretrained("${model.id}") `; const tensorflowtts = (model: ModelData) => { if (model.tags?.includes("text-to-mel")) { return tensorflowttsTextToMel(model); } else if (model.tags?.includes("mel-to-wav")) { return tensorflowttsMelToWav(model); } return tensorflowttsUnknown(model); }; const timm = (model: ModelData) => `import timm model = timm.create_model("hf_hub:${model.id}", pretrained=True)`; const sklearn = (model: ModelData) => `from huggingface_hub import hf_hub_download import joblib model = joblib.load( hf_hub_download("${model.id}", "sklearn_model.joblib") )`; const fastai = (model: ModelData) => `from huggingface_hub import from_pretrained_fastai learn = from_pretrained_fastai("${model.id}")`; const sentenceTransformers = (model: ModelData) => `from sentence_transformers import SentenceTransformer model = SentenceTransformer("${model.id}")`; const spacy = (model: ModelData) => `!pip install https://huggingface.co/${model.id}/resolve/main/${nameWithoutNamespace(model.id)}-any-py3-none-any.whl # Using spacy.load(). import spacy nlp = spacy.load("${nameWithoutNamespace(model.id)}") # Importing as module.
import ${nameWithoutNamespace(model.id)} nlp = ${nameWithoutNamespace(model.id)}.load()`; const stanza = (model: ModelData) => `import stanza stanza.download("${nameWithoutNamespace(model.id).replace("stanza-", "")}") nlp = stanza.Pipeline("${nameWithoutNamespace(model.id).replace("stanza-", "")}")`; const speechBrainMethod = (speechbrainInterface: string) => { switch (speechbrainInterface) { case "EncoderClassifier": return "classify_file"; case "EncoderDecoderASR": case "EncoderASR": return "transcribe_file"; case "SpectralMaskEnhancement": return "enhance_file"; case "SepformerSeparation": return "separate_file"; default: return undefined; } }; const speechbrain = (model: ModelData) => { const speechbrainInterface = model.config?.speechbrain?.interface; if (speechbrainInterface === undefined) { return `# interface not specified in config.json`; } const speechbrainMethod = speechBrainMethod(speechbrainInterface); if (speechbrainMethod === undefined) { return `# interface in config.json invalid`; } return `from speechbrain.pretrained import ${speechbrainInterface} model = ${speechbrainInterface}.from_hparams( "${model.id}" ) model.${speechbrainMethod}("file.wav")`; }; const transformers = (model: ModelData) => { const info = model.transformersInfo; if (!info) { return `# ⚠️ Type of model unknown`; } if (info.processor) { const varName = info.processor === "AutoTokenizer" ? "tokenizer" : info.processor === "AutoFeatureExtractor" ? "extractor" : "processor" ; return [ `from transformers import ${info.processor}, ${info.auto_model}`, "", `${varName} = ${info.processor}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`, "", `model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""})`, ].join("\n"); } else { return [ `from transformers import ${info.auto_model}`, "", `model = ${info.auto_model}.from_pretrained("${model.id}"${model.private ?
", use_auth_token=True" : ""})`, ].join("\n"); } }; const fasttext = (model: ModelData) => `from huggingface_hub import hf_hub_download import fasttext model = fasttext.load_model(hf_hub_download("${model.id}", "model.bin"))`; const stableBaselines3 = (model: ModelData) => `from huggingface_sb3 import load_from_hub checkpoint = load_from_hub( repo_id="${model.id}", filename="{MODEL FILENAME}.zip", )`; const nemoDomainResolver = (domain: string, model: ModelData): string | undefined => { const modelName = `${nameWithoutNamespace(model.id)}.nemo`; switch (domain) { case "ASR": return `import nemo.collections.asr as nemo_asr asr_model = nemo_asr.models.ASRModel.from_pretrained("${model.id}") transcriptions = asr_model.transcribe(["file.wav"])`; default: return undefined; } }; const mlAgents = (model: ModelData) => `mlagents-load-from-hf --repo-id="${model.id}" --local-dir="./downloads"`; const nemo = (model: ModelData) => { let command: string | undefined = undefined; // Resolve the tag to a nemo domain/sub-domain if (model.tags?.includes("automatic-speech-recognition")) { command = nemoDomainResolver("ASR", model); } return command ?? `# tag did not correspond to a valid NeMo domain.`; }; //#endregion export const MODEL_LIBRARIES_UI_ELEMENTS: { [key in keyof typeof ModelLibrary]?: LibraryUiElement } = { // ^^ TODO(remove the optional ?
marker when Stanza snippet is available) "adapter-transformers": { btnLabel: "Adapter Transformers", repoName: "adapter-transformers", repoUrl: "https://github.com/Adapter-Hub/adapter-transformers", snippet: adapter_transformers, }, "allennlp": { btnLabel: "AllenNLP", repoName: "AllenNLP", repoUrl: "https://github.com/allenai/allennlp", snippet: allennlp, }, "asteroid": { btnLabel: "Asteroid", repoName: "Asteroid", repoUrl: "https://github.com/asteroid-team/asteroid", snippet: asteroid, }, "diffusers": { btnLabel: "Diffusers", repoName: "🤗/diffusers", repoUrl: "https://github.com/huggingface/diffusers", snippet: diffusers, }, "espnet": { btnLabel: "ESPnet", repoName: "ESPnet", repoUrl: "https://github.com/espnet/espnet", snippet: espnet, }, "fairseq": { btnLabel: "Fairseq", repoName: "fairseq", repoUrl: "https://github.com/pytorch/fairseq", snippet: fairseq, }, "flair": { btnLabel: "Flair", repoName: "Flair", repoUrl: "https://github.com/flairNLP/flair", snippet: flair, }, "keras": { btnLabel: "Keras", repoName: "Keras", repoUrl: "https://github.com/keras-team/keras", snippet: keras, }, "nemo": { btnLabel: "NeMo", repoName: "NeMo", repoUrl: "https://github.com/NVIDIA/NeMo", snippet: nemo, }, "pyannote-audio": { btnLabel: "pyannote.audio", repoName: "pyannote-audio", repoUrl: "https://github.com/pyannote/pyannote-audio", snippet: pyannote_audio, }, "sentence-transformers": { btnLabel: "sentence-transformers", repoName: "sentence-transformers", repoUrl: "https://github.com/UKPLab/sentence-transformers", snippet: sentenceTransformers, }, "sklearn": { btnLabel: "Scikit-learn", repoName: "Scikit-learn", repoUrl: "https://github.com/scikit-learn/scikit-learn", snippet: sklearn, }, "fastai": { btnLabel: "fastai", repoName: "fastai", repoUrl: "https://github.com/fastai/fastai", snippet: fastai, }, "spacy": { btnLabel: "spaCy", repoName: "spaCy", repoUrl: "https://github.com/explosion/spaCy", snippet: spacy, }, "speechbrain": { btnLabel: "speechbrain", repoName:
"speechbrain", repoUrl: "https://github.com/speechbrain/speechbrain", snippet: speechbrain, }, "stanza": { btnLabel: "Stanza", repoName: "stanza", repoUrl: "https://github.com/stanfordnlp/stanza", snippet: stanza, }, "tensorflowtts": { btnLabel: "TensorFlowTTS", repoName: "TensorFlowTTS", repoUrl: "https://github.com/TensorSpeech/TensorFlowTTS", snippet: tensorflowtts, }, "timm": { btnLabel: "timm", repoName: "pytorch-image-models", repoUrl: "https://github.com/rwightman/pytorch-image-models", snippet: timm, }, "transformers": { btnLabel: "Transformers", repoName: "🤗/transformers", repoUrl: "https://github.com/huggingface/transformers", snippet: transformers, }, "fasttext": { btnLabel: "fastText", repoName: "fastText", repoUrl: "https://fasttext.cc/", snippet: fasttext, }, "stable-baselines3": { btnLabel: "stable-baselines3", repoName: "stable-baselines3", repoUrl: "https://github.com/huggingface/huggingface_sb3", snippet: stableBaselines3, }, "ml-agents": { btnLabel: "ml-agents", repoName: "ml-agents", repoUrl: "https://github.com/huggingface/ml-agents", snippet: mlAgents, }, } as const; """


if __name__ == "__main__":
    # Usage: [LIBRARY MODEL_ID]. Falls back to the previously hard-coded
    # defaults when fewer than two arguments are given.
    cli_args = sys.argv[1:]
    if len(cli_args) >= 2:
        library_name, model_name = cli_args[0], cli_args[1]
    else:
        library_name, model_name = "keras", "Distillgpt2"
    try:
        print(read_file(library_name, model_name))
    except ValueError as err:
        sys.exit(f"error: {err}")