Add documentation
- README.md +180 -3
- app.py +17 -0
- automatic_speech_recognition.py +33 -1
- chatbot.py +77 -3
- image_classification.py +33 -1
- image_to_text.py +27 -1
- text_to_image.py +20 -1
- text_to_speech.py +26 -1
- utils.py +144 -0
README.md
CHANGED

@@ -1,8 +1,8 @@
---
-title:
+title: AI Building Blocks
emoji: 👀
colorFrom: purple
-colorTo:
+colorTo: blue
sdk: gradio
sdk_version: 5.49.1
app_file: app.py

@@ -11,4 +11,181 @@ license: wtfpl
short_description: A gallery of building blocks for building AI applications
---

# AI Building Blocks

A gallery of AI building blocks for building AI applications, featuring a Gradio web interface with multiple tabs for different AI tasks.

## Features

This application provides the following AI building blocks:

- **Text-to-image Generation**: Generate images from text prompts using Hugging Face Inference API
- **Image-to-text (Image Captioning)**: Generate text descriptions of images using BLIP models
- **Image Classification**: Classify recyclable items using Trash-Net model
- **Text-to-speech (TTS)**: Convert text to speech audio
- **Automatic Speech Recognition (ASR)**: Transcribe audio to text using Whisper models
- **Chatbot**: Have conversations with AI chatbots supporting both modern chat models and seq2seq models

## Prerequisites

- Python 3.8 or higher
- PyTorch with hardware acceleration (strongly recommended - see [PyTorch Installation](#pytorch-installation))
- CUDA-capable GPU (optional, but recommended for better performance)

## Installation

1. Clone this repository:
   ```bash
   git clone <repository-url>
   cd ai-building-blocks
   ```

2. Create a virtual environment:
   ```bash
   python -m venv .venv
   source .venv/bin/activate  # On Windows: .venv\Scripts\activate
   ```

3. Install PyTorch with CUDA support (see [PyTorch Installation](#pytorch-installation) below).

4. Install the remaining dependencies:
   ```bash
   pip install -r requirements.txt
   ```

## PyTorch Installation

PyTorch is not included in `requirements.txt` because installation varies based on your hardware and operating system. **It is strongly recommended to install PyTorch with hardware acceleration support** for optimal performance.

For official installation instructions with CUDA support, please visit:

- **Official PyTorch Installation Guide**: https://pytorch.org/get-started/locally/

Select your platform, package manager, Python version, and CUDA version to get the appropriate installation command. For example:

- **CUDA 12.1** (recommended for modern NVIDIA GPUs):
  ```bash
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
  ```

- **CUDA 11.8**:
  ```bash
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
  ```

- **CPU only** (not recommended for production):
  ```bash
  pip install torch torchvision torchaudio
  ```

## Configuration

Create a `.env` file in the project root directory with the following environment variables:

### Required Environment Variables

```env
# Hugging Face API Token (required for Inference API access)
# Get your token from: https://huggingface.co/settings/tokens
HF_TOKEN=your_huggingface_token_here

# Model IDs for each building block
TEXT_TO_IMAGE_MODEL=model_id_for_text_to_image
IMAGE_TO_TEXT_MODEL=model_id_for_image_captioning
IMAGE_CLASSIFICATION_MODEL=model_id_for_image_classification
TEXT_TO_SPEECH_MODEL=model_id_for_text_to_speech
AUDIO_TRANSCRIPTION_MODEL=model_id_for_speech_recognition
CHAT_MODEL=model_id_for_chatbot
```

### Optional Environment Variables

```env
# Request timeout in seconds (default: 45)
REQUEST_TIMEOUT=45
```

### Example `.env` File

```env
HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

# Example model IDs (adjust based on your needs)
TEXT_TO_IMAGE_MODEL=black-forest-labs/FLUX.1-dev
IMAGE_CLASSIFICATION_MODEL=prithivMLmods/Trash-Net
IMAGE_TO_TEXT_MODEL=Salesforce/blip-image-captioning-large
TEXT_TO_SPEECH_MODEL=kakao-enterprise/vits-ljs
AUDIO_TRANSCRIPTION_MODEL=openai/whisper-large-v3
CHAT_MODEL=Qwen/Qwen2.5-1.5B-Instruct

REQUEST_TIMEOUT=45
```

**Note**: `.env` should already be included in the `.gitignore` file. Never force-add it (for example with `git add --force`) to avoid committing sensitive tokens.
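
For orientation, this is roughly how the values above reach the application at startup. A minimal sketch assuming `python-dotenv` is used to read `.env`; the variable names below are illustrative, not the app's exact code:

```python
from os import getenv

from dotenv import load_dotenv
from huggingface_hub import InferenceClient

load_dotenv()  # pull the .env values into the process environment

client = InferenceClient(token=getenv("HF_TOKEN"))      # shared client for Inference API calls
text_to_image_model = getenv("TEXT_TO_IMAGE_MODEL")     # e.g. black-forest-labs/FLUX.1-dev
request_timeout = int(getenv("REQUEST_TIMEOUT", "45"))  # optional, defaults to 45 seconds
```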

## Running the Application

1. Activate your virtual environment (if not already activated):
   ```bash
   source .venv/bin/activate  # On Windows: .venv\Scripts\activate
   ```

2. Run the application:
   ```bash
   python app.py
   ```

3. Open your web browser and navigate to the URL shown in the terminal (typically `http://127.0.0.1:7860`).

4. The Gradio interface will display multiple tabs, each corresponding to a different AI building block.

## Project Structure

```
ai-building-blocks/
├── app.py                            # Main application entry point
├── text_to_image.py                  # Text-to-image generation module
├── image_to_text.py                  # Image captioning module
├── image_classification.py           # Image classification module
├── text_to_speech.py                 # Text-to-speech module
├── automatic_speech_recognition.py   # Speech recognition module
├── chatbot.py                        # Chatbot module
├── utils.py                          # Utility functions
├── requirements.txt                  # Python dependencies
├── .env                              # Environment variables (create this)
└── README.md                         # This file
```

## Hardware Acceleration

This application is designed to leverage hardware acceleration when available:

- **NVIDIA CUDA**: Automatically detected and used if available
- **AMD ROCm**: Supported via CUDA compatibility
- **Intel XPU**: Automatically detected if available
- **Apple Silicon (MPS)**: Automatically detected and used on Apple devices
- **CPU**: Falls back to CPU if no GPU acceleration is available

The application automatically selects the best available device. For optimal performance, especially with local models (image-to-text, text-to-speech, chatbot), a CUDA-capable GPU is strongly recommended. This is _untested_ on other hardware. 😉
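
The device selection described above amounts to a cascade of PyTorch availability checks, mirroring the `get_pytorch_device` helper in `utils.py`. A minimal sketch (the function name here is illustrative):

```python
import torch

def pick_device() -> str:
    # Priority order: NVIDIA CUDA / AMD ROCm, then Intel XPU, then Apple Silicon (MPS), then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(pick_device())  # e.g. "cuda" on a machine with a working NVIDIA driver stack
```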

## Troubleshooting

### PyTorch Not Detecting GPU

If PyTorch is not detecting your GPU:

1. Verify CUDA is installed: `nvidia-smi`
2. Ensure PyTorch was installed with CUDA support (see [PyTorch Installation](#pytorch-installation))
3. Check PyTorch CUDA availability: `python -c "import torch; print(torch.cuda.is_available())"`

### Missing Environment Variables

Ensure all required environment variables are set in your `.env` file. Missing variables will cause the application to fail when trying to use the corresponding feature.

### Model Loading Errors

If you encounter errors loading models:

1. Verify your `HF_TOKEN` is valid and has access to the models. Some models are gated. (A quick token check is sketched below.)
2. Check that model IDs in your `.env` file are correct.
3. Ensure you have sufficient disk space for model downloads.
4. For local models, ensure you have sufficient RAM or VRAM.
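
One quick way to run that token check is to ask the Hub who the token belongs to. A sketch, assuming `python-dotenv` and `huggingface_hub` are installed:

```python
from os import getenv

from dotenv import load_dotenv
from huggingface_hub import whoami

load_dotenv()
try:
    info = whoami(token=getenv("HF_TOKEN"))  # raises if the token is missing or invalid
    print(f"Token OK, authenticated as: {info['name']}")
except Exception as error:
    print(f"Token check failed: {error}")
```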
app.py
CHANGED

@@ -10,11 +10,28 @@ from text_to_speech import create_text_to_speech_tab

class App:
    """Main application class for the AI Building Blocks Gradio interface.

    This class orchestrates the entire application by creating the Gradio UI
    and integrating all the individual building block tabs.
    """

    def __init__(self, client: InferenceClient):
        """Initialize the App with an InferenceClient instance.

        Args:
            client: Hugging Face InferenceClient instance for making API calls
                to Hugging Face's inference endpoints.
        """
        self.client = client

    def run(self):
        """Launch the Gradio application with all building block tabs.

        Creates a Gradio Blocks interface with multiple tabs, each representing
        a different AI building block. The application will block until the
        interface is closed.
        """
        with gr.Blocks(title="AI Building Blocks") as demo:
            gr.Markdown("# AI Building Blocks")
            gr.Markdown("A gallery of building blocks for building AI applications")
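
As a usage illustration (not part of the diff), the class above is presumably wired up along these lines in the entry point; the `load_dotenv` call and variable names are assumptions:

```python
from os import getenv

from dotenv import load_dotenv
from huggingface_hub import InferenceClient

from app import App  # the class shown above

load_dotenv()
client = InferenceClient(token=getenv("HF_TOKEN"))  # shared client for the API-backed tabs
App(client).run()  # builds the Blocks UI with all tabs and launches it; blocks until closed
```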
automatic_speech_recognition.py
CHANGED

@@ -5,6 +5,28 @@ import gradio as gr

from utils import save_audio_to_temp_file, get_model_sample_rate, request_audio

def automatic_speech_recognition(client: InferenceClient, audio: tuple[int, bytes]) -> str:
    """Transcribe audio to text using Hugging Face Inference API.

    This function converts speech audio into text transcription. The audio is
    resampled to match the model's expected sample rate, saved to a temporary
    file, and then sent to the Inference API for transcription.

    Args:
        client: Hugging Face InferenceClient instance for API calls.
        audio: Tuple containing:
            - int: Sample rate of the input audio (e.g., 44100 Hz)
            - bytes: Raw audio data as bytes

    Returns:
        String containing the transcribed text from the audio.

    Note:
        - The model ID is determined by the AUDIO_TRANSCRIPTION_MODEL environment variable.
        - Audio is automatically resampled to match the model's expected sample rate.
        - Audio is saved as a WAV file for InferenceClient compatibility.
        - Automatically cleans up temporary files after transcription.
        - Uses openai/whisper-large-v3 or similar ASR models.
    """
    temp_file_path = None
    try:
        model_id = getenv("AUDIO_TRANSCRIPTION_MODEL")

@@ -21,7 +43,17 @@ def automatic_speech_recognition(client: InferenceClient, audio: tuple[int, byte

def create_asr_tab(client: InferenceClient):
    """Create the automatic speech recognition tab in the Gradio interface.

    This function sets up all UI components for automatic speech recognition, including:
    - URL input textbox for fetching audio files from the web
    - Button to retrieve audio from URL
    - Audio input component for uploading or recording audio
    - Transcribe button and output textbox

    Args:
        client: Hugging Face InferenceClient instance to pass to the automatic_speech_recognition function.
    """
    gr.Markdown("Transcribe audio to text.")
    audio_transcription_url_input = gr.Textbox(label="Audio URL")
    audio_transcription_audio_request_button = gr.Button("Get Audio")
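
Stripped of the temp-file and resampling plumbing, the Inference API call described in the docstring looks roughly like this (a sketch; `transcribe_wav` is an illustrative name, not a function in the module):

```python
from os import getenv

from huggingface_hub import InferenceClient

def transcribe_wav(client: InferenceClient, wav_path: str) -> str:
    # InferenceClient accepts a local file path (or raw bytes) for ASR requests.
    result = client.automatic_speech_recognition(wav_path, model=getenv("AUDIO_TRANSCRIPTION_MODEL"))
    return result.text  # the transcription is exposed on the .text field of the response
```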
chatbot.py
CHANGED

@@ -9,7 +9,26 @@ _tokenizer = None

_is_seq2seq = None

def get_chatbot():
    """Get or create the chatbot model instance.

    This function implements a singleton pattern to load and cache the chatbot
    model and tokenizer. It supports both causal language models (like GPT-style
    models) and sequence-to-sequence models (like BlenderBot). The model type
    is automatically detected from the model configuration.

    Returns:
        Tuple containing:
        - Model: The loaded transformer model (AutoModelForCausalLM or AutoModelForSeq2SeqLM)
        - Tokenizer: The corresponding tokenizer
        - bool: Whether the model is a seq2seq model (True) or causal LM (False)

    Note:
        - The model ID is determined by the CHAT_MODEL environment variable.
        - Models are loaded with safetensors for secure loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Sets pad_token to eos_token if pad_token is not configured.
        - Model is cached globally after first load for performance.
    """
    global _chatbot, _tokenizer, _is_seq2seq
    if _chatbot is None:
        model_id = getenv("CHAT_MODEL")

@@ -46,6 +65,33 @@ def get_chatbot():

@spaces_gpu
def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, list[dict]]:
    """Generate a chatbot response given a user message and conversation history.

    This function handles conversation with AI chatbots, supporting both modern
    chat models with chat templates (like Qwen, Mistral) and older models
    without templates (like BlenderBot). It manages conversation history and
    formats inputs appropriately based on the model type.

    Args:
        message: The user's current message as a string.
        conversation_history: Optional list of previous conversation messages.
            Each message is a dict with "role" ("user" or "assistant") and "content".
            If None, starts a new conversation.

    Returns:
        Tuple containing:
        - str: The assistant's response message
        - list[dict]: Updated conversation history including the new exchange

    Note:
        - Supports models with chat templates (uses apply_chat_template)
        - Falls back to manual formatting for models without templates
        - Handles both causal LM and seq2seq model architectures
        - Uses sampling with temperature=0.7 for varied responses
        - Generates up to 256 new tokens
        - Automatically manages conversation context and history
        - Extracts only newly generated text for causal LMs with chat templates
    """
    model, tokenizer, is_seq2seq = get_chatbot()

    # Initialize conversation history if this is the first message

@@ -129,7 +175,19 @@ def chat(message: str, conversation_history: list[dict] | None) -> tuple[str, li

def create_chatbot_tab():
    """Create the chatbot tab in the Gradio interface.

    This function sets up all UI components for the conversational chatbot,
    including:
    - Chatbot component for displaying conversation history
    - Text input box for user messages
    - Send button and Enter key submission support
    - Internal state management for conversation history

    It also wires up event handlers for both button clicks and Enter key presses,
    and manages the conversion between Gradio's chat format and the internal
    conversation history format.
    """
    gr.Markdown("Have a conversation with an AI chatbot.")
    chatbot_history = gr.State(value=None)  # Store the conversation history.
    chatbot_output = gr.Chatbot(label="Conversation")

@@ -137,7 +195,23 @@ def create_chatbot_tab():

    chatbot_send_button = gr.Button("Send")

    def chat_interface(message: str, history: list | None, conversation_state: list[dict] | None):
        """Handle chatbot interaction with Gradio chat format.

        This function serves as the bridge between Gradio's chat interface format
        and the internal chatbot API. It converts formats, handles empty messages,
        and manages state updates.

        Args:
            message: The user's message string from the input box.
            history: Gradio's chat history format (list of [user_msg, bot_msg] pairs).
            conversation_state: Internal conversation history format (list of dicts).

        Returns:
            Tuple containing:
            - Updated Gradio chat history
            - Updated internal conversation state
            - Empty string (to clear the input field)
        """
        if not message.strip():
            return history, conversation_state, ""
        response, updated_conversation = chat(message, conversation_state)  # Get response from chatbot.
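
The chat-template path described in the `chat` docstring is essentially the standard `transformers` pattern. A minimal sketch for a causal LM, with the sampling settings taken from the docstring (model and tokenizer loading elided; `generate_reply` is an illustrative name):

```python
import torch

def generate_reply(model, tokenizer, conversation: list[dict]) -> str:
    # conversation is a list of {"role": ..., "content": ...} dicts, newest user message last.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,  # append the assistant turn marker
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=256,  # matches the limit described in the docstring
            do_sample=True,
            temperature=0.7,
        )
    # Keep only the newly generated tokens, then decode them to text.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```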
image_classification.py
CHANGED

@@ -9,6 +9,28 @@ from utils import save_image_to_temp_file, request_image

def image_classification(client: InferenceClient, image: Image) -> DataFrame:
    """Classify an image using Hugging Face Inference API.

    This function classifies a recyclable item image into categories:
    cardboard, glass, metal, paper, plastic, or other. The image is saved
    to a temporary file since InferenceClient requires a file path rather than
    a PIL Image object directly.

    Args:
        client: Hugging Face InferenceClient instance for API calls.
        image: PIL Image object to classify.

    Returns:
        Pandas DataFrame with two columns:
        - Label: The classification label (e.g., "cardboard", "glass")
        - Probability: The confidence score as a percentage string (e.g., "95.23%")

    Note:
        - The model ID is determined by the IMAGE_CLASSIFICATION_MODEL environment variable.
        - Uses Trash-Net model for recyclable item classification.
        - Automatically cleans up temporary files after classification.
        - Temporary file is created with format preservation if possible.
    """
    try:
        temp_file_path = save_image_to_temp_file(image)  # Needed because InferenceClient does not accept PIL Images directly.
        classifications = client.image_classification(temp_file_path, model=getenv("IMAGE_CLASSIFICATION_MODEL"))

@@ -27,7 +49,17 @@ def image_classification(client: InferenceClient, image: Image) -> DataFrame:

def create_image_classification_tab(client: InferenceClient):
    """Create the image classification tab in the Gradio interface.

    This function sets up all UI components for image classification, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve image from URL
    - Image preview component
    - Classify button and output dataframe showing labels and probabilities

    Args:
        client: Hugging Face InferenceClient instance to pass to the image_classification function.
    """
    gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")
    image_classification_url_input = gr.Textbox(label="Image URL")
    image_classification_image_request_button = gr.Button("Get Image")
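
For context, the DataFrame described under Returns corresponds to a transformation along these lines (a sketch, assuming the usual `InferenceClient.image_classification` output of label/score elements; `to_dataframe` is an illustrative name):

```python
from pandas import DataFrame

def to_dataframe(classifications) -> DataFrame:
    # Each element carries a .label and a .score in [0, 1]; render the score as a percentage string.
    return DataFrame(
        [(c.label, f"{c.score * 100:.2f}%") for c in classifications],
        columns=["Label", "Probability"],
    )
```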
image_to_text.py
CHANGED

@@ -8,6 +8,25 @@ from utils import get_pytorch_device, spaces_gpu, request_image

@spaces_gpu
def image_to_text(image: Image) -> list[str]:
    """Generate text captions for an image using BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to generate multiple caption candidates for the input image. The model is
    loaded, inference is performed, and then cleaned up to free GPU memory.

    Args:
        image: PIL Image object to generate captions for.

    Returns:
        List of string captions describing the image.

    Note:
        - The model ID is determined by the IMAGE_TO_TEXT_MODEL environment variable.
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up model and GPU memory after inference.
        - Uses beam search with 3 beams, max length 20, min length 5.
    """
    image_to_text_model_id = getenv("IMAGE_TO_TEXT_MODEL")
    pytorch_device = get_pytorch_device()
    processor = AutoProcessor.from_pretrained(image_to_text_model_id)

@@ -24,7 +43,14 @@ def image_to_text(image: Image) -> list[str]:

def create_image_to_text_tab():
    """Create the image-to-text captioning tab in the Gradio interface.

    This function sets up all UI components for image captioning, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve image from URL
    - Image preview component
    - Caption button and output list
    """
    gr.Markdown("Generate a text description of an image.")
    image_to_text_url_input = gr.Textbox(label="Image URL")
    image_to_text_image_request_button = gr.Button("Get Image")
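
The generation settings listed in the Note (beam search with 3 beams, length 5 to 20) map onto the usual BLIP pattern in `transformers`. A sketch under those assumptions, not the module's exact code:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

def caption(image: Image.Image, model_id: str, device: str) -> list[str]:
    processor = AutoProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id, use_safetensors=True).to(device)
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, num_beams=3, max_length=20, min_length=5)
    # Decode every returned sequence into a plain-text caption.
    return [processor.decode(ids, skip_special_tokens=True) for ids in output_ids]
```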
text_to_image.py
CHANGED

@@ -6,11 +6,30 @@ from huggingface_hub import InferenceClient

def text_to_image(client: InferenceClient, prompt: str) -> Image:
    """Generate an image from a text prompt using Hugging Face Inference API.

    Args:
        client: Hugging Face InferenceClient instance for API calls.
        prompt: Text description of the desired image.

    Returns:
        PIL Image object representing the generated image.

    Note:
        The model to use is determined by the TEXT_TO_IMAGE_MODEL environment variable.
    """
    return client.text_to_image(prompt, model=getenv("TEXT_TO_IMAGE_MODEL"))


def create_text_to_image_tab(client: InferenceClient):
    """Create the text-to-image generation tab in the Gradio interface.

    This function sets up all UI components for text-to-image generation,
    including input textbox, generate button, and output image display.

    Args:
        client: Hugging Face InferenceClient instance to pass to the text_to_image function.
    """
    gr.Markdown("Generate an image from a text prompt.")
    text_to_image_prompt = gr.Textbox(label="Prompt")
    text_to_image_generate_button = gr.Button("Generate")
text_to_speech.py
CHANGED

@@ -7,6 +7,27 @@ from utils import spaces_gpu

@spaces_gpu
def text_to_speech(text: str) -> tuple[int, bytes]:
    """Convert text to speech audio using a TTS (Text-to-Speech) model.

    This function uses a transformer pipeline to generate speech audio from
    text input. The model is loaded, inference is performed, and then cleaned
    up to free GPU memory.

    Args:
        text: Input text string to convert to speech.

    Returns:
        Tuple containing:
        - int: Sampling rate of the generated audio (e.g., 22050 Hz)
        - bytes: Raw audio data as bytes

    Note:
        - The model ID is determined by the TEXT_TO_SPEECH_MODEL environment variable.
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up model and GPU memory after inference.
        - Returns audio in format compatible with Gradio Audio component.
    """
    narrator = pipeline(
        "text-to-speech",
        getenv("TEXT_TO_SPEECH_MODEL"),

@@ -19,7 +40,11 @@ def text_to_speech(text: str) -> tuple[int, bytes]:

def create_text_to_speech_tab():
    """Create the text-to-speech tab in the Gradio interface.

    This function sets up all UI components for text-to-speech generation,
    including input textbox, generate button, and output audio player.
    """
    gr.Markdown("Generate speech from text.")
    text_to_speech_text = gr.Textbox(label="Text")
    text_to_speech_generate_button = gr.Button("Generate")
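
The pipeline call the docstring refers to typically looks like this (a sketch; the exact conversion into the `(sample_rate, data)` tuple that Gradio's Audio component expects may differ in the module, and `synthesize` is an illustrative name):

```python
import numpy as np
from transformers import pipeline

def synthesize(text: str, model_id: str, device: str) -> tuple[int, np.ndarray]:
    # "text-to-speech" is an alias of the text-to-audio pipeline in transformers.
    narrator = pipeline("text-to-speech", model_id, device=device)
    speech = narrator(text)  # returns {"audio": np.ndarray, "sampling_rate": int}
    return speech["sampling_rate"], np.squeeze(speech["audio"])
```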
utils.py
CHANGED

@@ -17,15 +17,48 @@ try:

except ImportError:
    # For local development, use a no-op decorator because spaces is not available.
    def spaces_gpu(func):
        """No-op decorator for local development when spaces module is not available."""
        return func

def get_pytorch_device() -> str:
    """Determine the best available PyTorch device for computation.

    Checks for available hardware accelerators in priority order:
    1. CUDA (Nvidia GPUs and AMD ROCm)
    2. XPU (Intel GPUs)
    3. MPS (Apple Silicon/Metal Performance Shaders)
    4. CPU (fallback)

    Returns:
        String device name: "cuda", "xpu", "mps", or "cpu"
    """
    return ("cuda" if torch.cuda.is_available()  # Nvidia CUDA and AMD ROCm
            else "xpu" if torch.xpu.is_available()  # Intel XPU
            else "mps" if torch.mps.is_available()  # Apple Silicon
            else "cpu")  # gl bro 🫠

def request_image(url: str) -> Image:
    """Fetch an image from a URL and return it as a PIL Image.

    Downloads an image from the provided URL and converts it to a PIL Image
    object for processing. Handles various HTTP errors and timeouts gracefully.

    Args:
        url: HTTP/HTTPS URL pointing to an image file.

    Returns:
        PIL Image object loaded from the URL.

    Raises:
        gr.Error: If the image cannot be fetched due to:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Other request exceptions

    Note:
        - Timeout is configurable via REQUEST_TIMEOUT environment variable (default: 45 seconds)
        - Supports common image formats (JPEG, PNG, GIF, WebP, etc.)
    """
    try:
        response = requests.get(url, timeout=int(getenv("REQUEST_TIMEOUT", "45")))
        response.raise_for_status()

@@ -38,6 +71,33 @@ def request_image(url: str) -> Image:

        raise gr.Error(f"Failed to fetch image from URL: {str(e)}")

def request_audio(url: str) -> tuple[int, np.ndarray]:
    """Fetch an audio file from a URL and return it as audio data.

    Downloads an audio file from the provided URL and loads it using librosa,
    which supports many audio formats. Returns the audio data in a format
    compatible with Gradio's Audio component.

    Args:
        url: HTTP/HTTPS URL pointing to an audio file.

    Returns:
        Tuple containing:
        - int: Sample rate of the audio in Hz (e.g., 44100, 22050)
        - np.ndarray: Audio waveform data as a numpy array (float32, normalized)

    Raises:
        gr.Error: If the audio cannot be fetched or loaded due to:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Unsupported audio formats
            - Other request or audio loading exceptions

    Note:
        - Timeout is configurable via REQUEST_TIMEOUT environment variable (default: 45 seconds)
        - Supports many audio formats (MP3, WAV, FLAC, OGG, M4A, etc.)
        - Audio is loaded at its native sample rate (sr=None)
        - Returns normalized float32 audio data suitable for processing
    """
    try:
        response = requests.get(url, timeout=int(getenv("REQUEST_TIMEOUT", "45")))
        response.raise_for_status()

@@ -53,6 +113,25 @@ def request_audio(url: str) -> tuple[int, np.ndarray]:

        raise gr.Error(f"Failed to load audio file: {str(e)}")

def save_image_to_temp_file(image: Image) -> str:
    """Save a PIL Image to a temporary file on disk.

    Creates a temporary file with an appropriate extension based on the image's
    format and saves the image to it. This is needed because some APIs (like
    Hugging Face InferenceClient) require file paths rather than PIL Image objects.

    Args:
        image: PIL Image object to save.

    Returns:
        String path to the temporary file where the image was saved.

    Note:
        - Preserves the original image format if available
        - Falls back to PNG format if image.format is None
        - Temporary file is not automatically deleted (caller is responsible for cleanup)
        - File extension is determined from the image format
        - Useful for APIs that require local file paths rather than in-memory objects
    """
    image_format = image.format if image.format else 'PNG'
    format_extension = image_format.lower() if image_format else 'png'
    temp_file = NamedTemporaryFile(delete=False, suffix=f".{format_extension}")

@@ -62,6 +141,24 @@ def save_image_to_temp_file(image: Image) -> str:

    return temp_path

def get_model_sample_rate(model_id: str) -> int:
    """Get the expected sample rate for an audio processing model.

    Retrieves the sample rate configuration from a Hugging Face model's
    feature extractor. This is useful for ensuring audio is resampled to
    match the model's expected input format.

    Args:
        model_id: Hugging Face model identifier (e.g., "openai/whisper-large-v3").

    Returns:
        Integer sample rate in Hz that the model expects (e.g., 16000).
        Defaults to 16000 Hz if the sample rate cannot be determined.

    Note:
        - Most ASR models use 16kHz sample rate
        - Uses AutoProcessor to access the model's feature extractor configuration
        - Returns a sensible default (16kHz) if the model config cannot be loaded
    """
    try:
        processor = AutoProcessor.from_pretrained(model_id)
        return processor.feature_extractor.sampling_rate

@@ -69,6 +166,31 @@ def get_model_sample_rate(model_id: str) -> int:

        return 16000  # Fallback value as most ASR models use 16kHz

def resample_audio(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> np.ndarray:
    """Resample audio data to a target sample rate.

    Converts audio data to the target sample rate using librosa's resampling.
    Handles both bytes and numpy array input formats, converting bytes to
    float32 numpy arrays as needed.

    Args:
        target_sample_rate: Desired output sample rate in Hz (e.g., 16000).
        audio: Tuple containing:
            - int: Current sample rate of the audio
            - bytes | np.ndarray: Audio data (can be raw bytes or numpy array)

    Returns:
        Numpy array (float32) containing the resampled audio waveform.
        If sample rates match, returns the audio data unchanged.

    Raises:
        ValueError: If audio_data is neither bytes nor np.ndarray.

    Note:
        - Converts bytes to float32 by assuming int16 PCM format
        - Normalizes int16 values to [-1.0, 1.0] range
        - Only resamples if source and target sample rates differ
        - Uses librosa's high-quality resampling algorithm
    """
    sample_rate, audio_data = audio

    # Convert audio data to a numpy array if it’s bytes

@@ -86,6 +208,28 @@ def resample_audio(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray

    return audio_array

def save_audio_to_temp_file(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> str:
    """Resample audio to target sample rate and save to a temporary WAV file.

    This function resamples audio data to match a target sample rate and saves
    it as a WAV file. This is useful for preparing audio for APIs that require
    specific sample rates and file formats.

    Args:
        target_sample_rate: Target sample rate in Hz for the output file (e.g., 16000).
        audio: Tuple containing:
            - int: Current sample rate of the input audio
            - bytes | np.ndarray: Audio data to process

    Returns:
        String path to the temporary WAV file where the audio was saved.

    Note:
        - Automatically resamples audio if sample rates don't match
        - Saves audio as WAV format (16-bit PCM)
        - Temporary file is not automatically deleted (caller is responsible for cleanup)
        - Audio is normalized and converted to float32 before saving
        - Useful for preparing audio for Hugging Face InferenceClient APIs
    """
    audio_array = resample_audio(target_sample_rate, audio)
    temp_file = NamedTemporaryFile(delete=False, suffix='.wav')
    temp_path = temp_file.name
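
Taken together, the last two docstrings describe roughly this flow. A sketch under the stated assumptions (int16 PCM bytes in, `librosa` for resampling, `soundfile` for the 16-bit WAV write), not necessarily the module's exact implementation:

```python
from tempfile import NamedTemporaryFile

import librosa
import numpy as np
import soundfile as sf

def resample_and_save(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> str:
    sample_rate, audio_data = audio
    if isinstance(audio_data, bytes):
        # Interpret raw bytes as 16-bit PCM and normalize to [-1.0, 1.0].
        audio_data = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
    if sample_rate != target_sample_rate:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_sample_rate)
    temp_file = NamedTemporaryFile(delete=False, suffix=".wav")
    sf.write(temp_file.name, audio_data, target_sample_rate, subtype="PCM_16")  # 16-bit WAV
    return temp_file.name  # caller is responsible for deleting the file
```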