"""Gradio Space that exports Hugging Face Hub models to the AWS Neuron format.

The export itself is delegated to ``optimum_neuron_export.convert``; this
module only builds the UI, forwards the logged-in user's token and, when a
dataset ``repo`` has been configured, records each export in a CSV file.
"""

import csv
import html
import os
from datetime import datetime
from typing import Optional, Sequence, Tuple, Union

import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi, Repository

from optimum_neuron_export import convert

DATASET_REPO_URL = "https://huggingface.co/datasets/optimum/neuron-exports"
DATA_FILENAME = "exports.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)
HF_TOKEN = os.environ.get("HF_WRITE_TOKEN")
DATADIR = "neuron_exports_data"

repo: Optional[Repository] = None
# Uncomment if you want to push to dataset repo with token
# if HF_TOKEN:
#     repo = Repository(local_dir=DATADIR, clone_from=DATASET_REPO_URL, token=HF_TOKEN)

# Define all possible tasks and their categories for coloring
TASK_CATEGORIES = {
    "auto": {"color": "#6b7280", "category": "Auto"},
    "feature-extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
    "fill-mask": {"color": "#8b5cf6", "category": "NLP"},
    "multiple-choice": {"color": "#8b5cf6", "category": "NLP"},
    "question-answering": {"color": "#8b5cf6", "category": "NLP"},
    "text-classification": {"color": "#8b5cf6", "category": "NLP"},
    "token-classification": {"color": "#8b5cf6", "category": "NLP"},
    "text-generation": {"color": "#10b981", "category": "Text Generation"},
    "text2text-generation": {"color": "#10b981", "category": "Text Generation"},
    "audio-classification": {"color": "#f59e0b", "category": "Audio"},
    "automatic-speech-recognition": {"color": "#f59e0b", "category": "Audio"},
    "audio-frame-classification": {"color": "#f59e0b", "category": "Audio"},
    "audio-xvector": {"color": "#f59e0b", "category": "Audio"},
    "image-classification": {"color": "#ef4444", "category": "Vision"},
    "object-detection": {"color": "#ef4444", "category": "Vision"},
    "semantic-segmentation": {"color": "#ef4444", "category": "Vision"},
    "text-to-image": {"color": "#ec4899", "category": "Multimodal"},
    "image-to-image": {"color": "#ec4899", "category": "Multimodal"},
    "inpaint": {"color": "#ec4899", "category": "Multimodal"},
    "zero-shot-image-classification": {"color": "#ec4899", "category": "Multimodal"},
    "sentence-similarity": {"color": "#06b6d4", "category": "Similarity"},
}

# Category names used by the legend; colors mirror those in TASK_CATEGORIES.
TAGS = {
    "Feature Extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
    "NLP": {"color": "#8b5cf6", "category": "NLP"},
    "Text Generation": {"color": "#10b981", "category": "Text Generation"},
    "Audio": {"color": "#f59e0b", "category": "Audio"},
    "Vision": {"color": "#ef4444", "category": "Vision"},
    "Multimodal": {"color": "#ec4899", "category": "Multimodal"},
    "Similarity": {"color": "#06b6d4", "category": "Similarity"},
}

# Get all tasks for dropdown
ALL_TASKS = list(TASK_CATEGORIES.keys())

# Fallback color for labels that are neither a known task nor a known category.
_DEFAULT_TAG_COLOR = "#6b7280"


def create_task_tag(task: str) -> str:
    """Create a colored HTML tag for a task.

    The color is looked up in ``TASK_CATEGORIES`` first, then in ``TAGS``;
    labels found in neither get a neutral gray. The label is HTML-escaped
    before being embedded.

    Args:
        task: Task identifier (e.g. ``"fill-mask"``) or category name
            (e.g. ``"NLP"``).

    Returns:
        A ``<span>`` element with an inline background color.
    """
    entry = TASK_CATEGORIES.get(task) or TAGS.get(task)
    color = entry["color"] if entry else _DEFAULT_TAG_COLOR
    return (
        f'<span style="background-color: {color}; color: white; '
        f'padding: 2px 8px; border-radius: 12px; font-size: 0.85em; '
        f'font-weight: 500;">{html.escape(task)}</span>'
    )


def format_tasks_for_table(tasks_str: str) -> str:
    """Convert comma-separated tasks into colored tags.

    Args:
        tasks_str: Task names separated by commas; surrounding whitespace is
            ignored and empty entries are skipped.

    Returns:
        The tags produced by ``create_task_tag`` joined by single spaces, or
        an empty string when ``tasks_str`` contains no task names.
    """
    names = (part.strip() for part in tasks_str.split(","))
    return " ".join(create_task_tag(name) for name in names if name)


def _render_architecture_table(rows: Sequence[Tuple[str, str]]) -> str:
    """Render ``(architecture, comma-separated tasks)`` pairs as an HTML table."""
    body = "\n".join(
        f"<tr><td>{html.escape(architecture)}</td>"
        f'<td class="task-tags">{format_tasks_for_table(tasks)}</td></tr>'
        for architecture, tasks in rows
    )
    return (
        "<table>\n"
        "<thead><tr><th>Architecture</th><th>Supported Tasks</th></tr></thead>\n"
        f"<tbody>\n{body}\n</tbody>\n"
        "</table>"
    )


def _log_export(model_id: str, pr_url: str) -> None:
    """Append one export record to the dataset CSV and push it to the Hub.

    Only called when the module-level ``repo`` has been initialized.
    """
    repo.git_pull(rebase=True)
    # newline="" is what the csv module expects so it controls line endings itself.
    with open(
        os.path.join(DATADIR, DATA_FILE), "a", newline="", encoding="utf-8"
    ) as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["model_id", "pr_url", "time"])
        writer.writerow(
            {
                "model_id": model_id,
                "pr_url": pr_url,
                "time": str(datetime.now()),
            }
        )
    commit_url = repo.push_to_hub()
    print("[dataset]", commit_url)


def neuron_export(model_id: str, task: str, oauth_token: gr.OAuthToken) -> str:
    """Export ``model_id`` to Neuron format and report the outcome as Markdown.

    Args:
        model_id: Hub model ID selected in the UI.
        task: Task to export for; one of ``ALL_TASKS``.
        oauth_token: Token of the logged-in user, supplied by Gradio.

    Returns:
        A Markdown message: a login reminder, an invalid-input notice, the
        error string reported by ``convert``, a success message containing the
        PR link, or ``"#### Error: ..."`` when an exception was raised.
    """
    # Guard the object as well as its token so a missing login cannot raise.
    if oauth_token is None or oauth_token.token is None:
        return "You must be logged in to use this space"

    if not model_id:
        return f"### Invalid input 🐞 Please specify a model name, got {model_id}"

    try:
        api = HfApi(token=oauth_token.token)
        error, commit_info = convert(
            api=api, model_id=model_id, task=task, token=oauth_token.token
        )
        # convert signals success with the string "0"; anything else is shown as-is.
        if error != "0":
            return error

        print("[commit_info]", commit_info)

        # Save in a private dataset if repo initialized
        if repo is not None:
            _log_export(model_id, commit_info.pr_url)

        # The revision contains "/" which must be percent-encoded inside the URL path.
        pr_revision = commit_info.pr_revision.replace("/", "%2F")

        return (
            "#### Success 🔥 Yay! This model was successfully exported and a PR "
            f"was opened using your token: [{commit_info.pr_url}]({commit_info.pr_url}). "
            "If you would like to use the exported model without waiting for the "
            "PR to be approved, head to "
            f"https://huggingface.co/{model_id}/tree/{pr_revision}"
        )
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing the handler.
        return f"#### Error: {e}"


# NOTE(review): this holds only a newline here; any banner markup it was meant
# to carry is not present — confirm the intended content.
TITLE_IMAGE = """
"""

TITLE = """
<h1 style="text-align: center;">
🤗 Optimum Neuron Model Exporter 🏎️ (WIP)
</h1>
"""

DESCRIPTION = """
This Space allows you to automatically export 🤗 transformers models hosted on the Hugging Face Hub to AWS Neuron-optimized format for Inferentia/Trainium acceleration. It opens a PR on the target model, and it is up to the owner of the original model to merge the PR to allow people to leverage Neuron optimization!

**Features:**
- Automatically opens PR with Neuron-optimized model
- Preserves original model weights
- Adds proper tags to model card

**Requirements:**
- Model must be compatible with [Optimum Neuron](https://huggingface.co/docs/optimum-neuron)
- User must be logged in with write token
"""

# Custom CSS to fix dark mode compatibility and transparency issues
CUSTOM_CSS = """
/* Fix for HuggingfaceHubSearch component visibility in both light and dark modes */
.gradio-container .gr-form {
    background: var(--background-fill-primary) !important;
    border: 1px solid var(--border-color-primary) !important;
}

/* Ensure text is visible in both modes */
.gradio-container input[type="text"],
.gradio-container textarea,
.gradio-container .gr-textbox input {
    color: var(--body-text-color) !important;
    background: var(--input-background-fill) !important;
    border: 1px solid var(--border-color-primary) !important;
}

/* Fix dropdown/search results visibility */
.gradio-container .gr-dropdown,
.gradio-container .gr-dropdown .gr-box,
.gradio-container [data-testid="textbox"] {
    background: var(--background-fill-primary) !important;
    color: var(--body-text-color) !important;
    border: 1px solid var(--border-color-primary) !important;
}

/* Fix for search component specifically */
.gradio-container .gr-form > div,
.gradio-container .gr-form input {
    background: var(--input-background-fill) !important;
    color: var(--body-text-color) !important;
}

/* Ensure proper contrast for placeholder text */
.gradio-container input::placeholder {
    color: var(--body-text-color-subdued) !important;
    opacity: 0.7;
}

/* Fix any remaining transparent backgrounds */
.gradio-container .gr-box,
.gradio-container .gr-panel {
    background: var(--background-fill-primary) !important;
}

/* Make sure search results are visible */
.gradio-container .gr-dropdown-item {
    color: var(--body-text-color) !important;
    background: var(--background-fill-primary) !important;
}

.gradio-container .gr-dropdown-item:hover {
    background: var(--background-fill-secondary) !important;
}

/* Task tag styling improvements */
.task-tags {
    line-height: 1.8;
}

.task-tags span {
    display: inline-block;
    margin: 2px;
}
"""

# Task lists shared by several architectures in the tables below.
_ENCODER_TASKS = (
    "feature-extraction, fill-mask, multiple-choice, question-answering, "
    "text-classification, token-classification"
)
_MASKED_LM_TASKS = "feature-extraction, fill-mask, text-classification, token-classification"
_IMAGE_TASKS = "feature-extraction, image-classification"
_IMAGE_SEGMENTATION_TASKS = "feature-extraction, image-classification, semantic-segmentation"
_SPEECH_TASKS = "feature-extraction, automatic-speech-recognition, audio-classification"
_SPEECH_EXTENDED_TASKS = (
    "feature-extraction, automatic-speech-recognition, audio-classification, "
    "audio-frame-classification, audio-xvector"
)
_DIFFUSION_TASKS = "text-to-image, image-to-image, inpaint"

_TRANSFORMERS_ARCHITECTURES = (
    ("ALBERT", _ENCODER_TASKS),
    ("AST", "feature-extraction, audio-classification"),
    ("BERT", _ENCODER_TASKS),
    ("BLOOM", "text-generation"),
    ("Beit", _IMAGE_TASKS),
    ("CamemBERT", _ENCODER_TASKS),
    ("CLIP", _IMAGE_TASKS),
    ("ConvBERT", _ENCODER_TASKS),
    ("ConvNext", _IMAGE_TASKS),
    ("ConvNextV2", _IMAGE_TASKS),
    ("CvT", _IMAGE_TASKS),
    ("DeBERTa (INF2 only)", _ENCODER_TASKS),
    ("DeBERTa-v2 (INF2 only)", _ENCODER_TASKS),
    ("Deit", _IMAGE_TASKS),
    ("DistilBERT", _ENCODER_TASKS),
    ("DonutSwin", "feature-extraction"),
    ("Dpt", "feature-extraction"),
    ("ELECTRA", _ENCODER_TASKS),
    ("ESM", _MASKED_LM_TASKS),
    ("FlauBERT", _ENCODER_TASKS),
    ("GPT2", "text-generation"),
    ("Hubert", _SPEECH_TASKS),
    ("Levit", _IMAGE_TASKS),
    ("Llama, Llama 2, Llama 3", "text-generation"),
    ("Mistral", "text-generation"),
    ("Mixtral", "text-generation"),
    ("MobileBERT", _ENCODER_TASKS),
    ("MobileNetV2", _IMAGE_SEGMENTATION_TASKS),
    ("MobileViT", _IMAGE_SEGMENTATION_TASKS),
    ("ModernBERT", _MASKED_LM_TASKS),
    ("MPNet", _ENCODER_TASKS),
    ("OPT", "text-generation"),
    ("Phi", "feature-extraction, text-classification, token-classification"),
    ("RoBERTa", _ENCODER_TASKS),
    ("RoFormer", _ENCODER_TASKS),
    ("Swin", _IMAGE_TASKS),
    ("T5", "text2text-generation"),
    ("UniSpeech", _SPEECH_TASKS),
    ("UniSpeech-SAT", _SPEECH_EXTENDED_TASKS),
    ("ViT", _IMAGE_TASKS),
    ("Wav2Vec2", _SPEECH_EXTENDED_TASKS),
    ("WavLM", _SPEECH_EXTENDED_TASKS),
    ("Whisper", "automatic-speech-recognition"),
    ("XLM", _ENCODER_TASKS),
    ("XLM-RoBERTa", _ENCODER_TASKS),
    ("Yolos", "feature-extraction, object-detection"),
)

_DIFFUSERS_ARCHITECTURES = (
    ("Stable Diffusion", _DIFFUSION_TASKS),
    ("Stable Diffusion XL Base", _DIFFUSION_TASKS),
    ("Stable Diffusion XL Refiner", "image-to-image, inpaint"),
    ("SDXL Turbo", _DIFFUSION_TASKS),
    ("LCM", "text-to-image"),
    ("PixArt-α", "text-to-image"),
    ("PixArt-Σ", "text-to-image"),
)

_SENTENCE_TRANSFORMERS_ARCHITECTURES = (
    ("Transformer", "feature-extraction, sentence-similarity"),
    ("CLIP", "feature-extraction, zero-shot-image-classification"),
)

with gr.Blocks(css=CUSTOM_CSS) as demo:
    # Login requirement notice and button
    gr.Markdown("**You must be logged in to use this space**")
    gr.LoginButton(min_width=250)

    # Centered title and image
    gr.HTML(TITLE_IMAGE)
    gr.HTML(TITLE)

    # Full-width description
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        with gr.Tab("Export Model"):
            # Input controls in a row
            with gr.Row():
                input_model = HuggingfaceHubSearch(
                    label="Hub model ID",
                    placeholder="Search for model ID on the hub",
                    search_type="model",
                )
                input_task = gr.Dropdown(
                    choices=ALL_TASKS,
                    value="auto",
                    label='Task (auto could infer task from model)',
                )

            # Export button below the inputs
            btn = gr.Button("Export to Neuron", size="lg")

            # Output section
            output = gr.Markdown(label="Output")

            # oauth_token is not listed in inputs; Gradio supplies it from the
            # handler's type annotation.
            btn.click(
                fn=neuron_export,
                inputs=[input_model, input_task],
                outputs=output,
            )

        with gr.Tab("Supported Architectures"):
            # Legend: one tag per category, in TAGS order.
            gr.HTML(f"""
<div>
<h3>🎨 Task Categories Legend</h3>
<div class="task-tags">
{" ".join(create_task_tag(category) for category in TAGS)}
</div>
</div>
""")

            gr.HTML(f"""
<div>
<h3>🤗 Transformers</h3>
{_render_architecture_table(_TRANSFORMERS_ARCHITECTURES)}
<h3>🧨 Diffusers</h3>
{_render_architecture_table(_DIFFUSERS_ARCHITECTURES)}
<h3>🤖 Sentence Transformers</h3>
{_render_architecture_table(_SENTENCE_TRANSFORMERS_ARCHITECTURES)}
<p>💡 Note: Some architectures may have specific requirements or limitations. DeBERTa models are only supported on INF2 instances.</p>
<p>For more details, check the <a href="https://huggingface.co/docs/optimum-neuron" target="_blank">Optimum Neuron documentation</a>.</p>
</div>
""")

    # Add spacing between tabs and content
    gr.Markdown("<br><br><br><br>")

if __name__ == "__main__":

    def restart_space():
        """Factory-reboot the Space; does nothing when no write token is set."""
        if HF_TOKEN:
            HfApi().restart_space(
                repo_id="optimum/neuron-export", token=HF_TOKEN, factory_reboot=True
            )

    scheduler = BackgroundScheduler()
    scheduler.add_job(restart_space, "interval", seconds=21600)  # Restart every 6 hours
    scheduler.start()

    demo.launch()