| """ | |
| Dataset sampling tool for retrieving samples from HuggingFace datasets. | |
| This module provides tools for efficiently sampling data from HuggingFace datasets | |
| with support for different splits, configurable sample sizes, and streaming for large datasets. | |
| """ | |
| import logging | |
| import gradio as gr | |
| from typing import Optional, Dict, Any | |
| from hf_eda_mcp.config import get_config | |
| from hf_eda_mcp.services.dataset_service import get_dataset_service, DatasetServiceError | |
| from hf_eda_mcp.integrations.hf_client import DatasetNotFoundError, AuthenticationError, NetworkError | |
| from hf_eda_mcp.validation import ( | |
| validate_dataset_id, | |
| validate_config_name, | |
| validate_split_name, | |
| validate_sample_size, | |
| ValidationError, | |
| format_validation_error, | |
| ) | |
| from hf_eda_mcp.error_handling import format_error_response, log_error_with_context | |
| logger = logging.getLogger(__name__) | |
| # Default constants (can be overridden by config) | |
| DEFAULT_SAMPLE_SIZE = 10 | |
| VALID_SPLITS = {"train", "validation", "test", "dev", "val"} | |


def get_dataset_sample(
    dataset_id: str,
    split: str = "train",
    num_samples: int = DEFAULT_SAMPLE_SIZE,
    config_name: Optional[str] = None,
    streaming: bool = True,
    hf_api_token: gr.Header = "",
) -> Dict[str, Any]:
| """ | |
| Retrieve a sample of rows from a HuggingFace dataset. | |
| This function efficiently samples data from datasets with support for different | |
| splits and configurable sample sizes. It uses streaming by default for large | |
| datasets to minimize memory usage and loading time. | |
| Args: | |
| dataset_id: HuggingFace dataset identifier (e.g., 'imdb', 'squad', 'glue') | |
| split: Dataset split to sample from (default: 'train') | |
| num_samples: Number of samples to retrieve (default: 10, max: 10000) | |
| config_name: Optional configuration name for multi-config datasets | |
| streaming: Whether to use streaming mode for efficient loading (default: True) | |
| hf_api_token: Header parsed by Gradio when hf_api_token is provided in MCP configuration headers | |
| Returns: | |
| Dictionary containing sampled data and metadata: | |
| - dataset_id: Original dataset identifier | |
| - config_name: Configuration name used (if any) | |
| - split: Split name sampled from | |
| - num_samples: Actual number of samples returned | |
| - requested_samples: Number of samples originally requested | |
| - data: List of sample dictionaries | |
| - schema: Dictionary describing the dataset features/columns | |
| - sample_info: Additional information about the sampling process | |
| Raises: | |
| ValueError: If inputs are invalid (empty dataset_id, invalid split, etc.) | |
| DatasetNotFoundError: If dataset or split doesn't exist | |
| AuthenticationError: If dataset is private and authentication fails | |
| DatasetServiceError: If sampling fails for other reasons | |
| Example: | |
| >>> # Basic sampling | |
| >>> sample = get_dataset_sample("imdb", split="train", num_samples=5) | |
| >>> print(f"Got {sample['num_samples']} samples from {sample['dataset_id']}") | |
| >>> for i, row in enumerate(sample['data']): | |
| ... print(f"Sample {i+1}: {list(row.keys())}") | |
| >>> # Multi-config dataset sampling | |
| >>> sample = get_dataset_sample("glue", split="validation", | |
| ... num_samples=3, config_name="cola") | |
| >>> print(f"Schema: {sample['schema']}") | |
| """ | |
    # Handle empty strings from Gradio (convert to None)
    if config_name == "":
        config_name = None

    # Input validation using centralized validation
    try:
        dataset_id = validate_dataset_id(dataset_id)
        config_name = validate_config_name(config_name)
        split = validate_split_name(split)
        num_samples = validate_sample_size(num_samples, "num_samples")
    except ValidationError as e:
        logger.error(f"Validation error: {format_validation_error(e)}")
        raise ValueError(format_validation_error(e))

    context = {
        "dataset_id": dataset_id,
        "split": split,
        "num_samples": num_samples,
        "config_name": config_name,
        "operation": "get_dataset_sample",
    }

    logger.info(
        f"Sampling {num_samples} rows from dataset: {dataset_id}, "
        f"split: {split}" + (f", config: {config_name}" if config_name else "")
    )
    try:
        # Get dataset service and load sample
        service = get_dataset_service(hf_api_token=hf_api_token)
        sample_data = service.load_dataset_sample(
            dataset_id=dataset_id,
            split=split,
            num_samples=num_samples,
            config_name=config_name,
            streaming=streaming,
        )

        # Enhance the response with additional metadata
        config = get_config()
        sample_data["sample_info"] = {
            "streaming_used": streaming,
            "sampling_strategy": "sequential_head",  # We take the first N samples
            "max_sample_size": config.max_sample_size,
            "truncated": sample_data["num_samples"] < sample_data["requested_samples"],
        }

        # Add data preview information
        if sample_data["data"]:
            first_sample = sample_data["data"][0]
            sample_data["sample_info"]["preview"] = {
                "columns": list(first_sample.keys())
                if isinstance(first_sample, dict)
                else [],
                "first_sample_types": {
                    k: type(v).__name__ for k, v in first_sample.items()
                }
                if isinstance(first_sample, dict)
                else {},
            }

        # Add summary
        sample_data["summary"] = _generate_sample_summary(sample_data)

        logger.info(
            f"Successfully sampled {sample_data['num_samples']} rows from {dataset_id}"
        )
        return sample_data
    except DatasetNotFoundError as e:
        log_error_with_context(e, context, level=logging.WARNING)
        error_response = format_error_response(e, context)
        logger.info(f"Dataset/split not found suggestions: {error_response.get('suggestions', [])}")
        raise
    except AuthenticationError as e:
        log_error_with_context(e, context, level=logging.WARNING)
        error_response = format_error_response(e, context)
        logger.info(f"Authentication error guidance: {error_response.get('suggestions', [])}")
        raise
    except NetworkError as e:
        log_error_with_context(e, context)
        error_response = format_error_response(e, context)
        logger.info(f"Network error guidance: {error_response.get('suggestions', [])}")
        raise
    except Exception as e:
        log_error_with_context(e, context)
        raise DatasetServiceError(f"Failed to sample dataset: {str(e)}") from e


# def get_dataset_sample_with_indices(
#     dataset_id: str,
#     indices: List[int],
#     split: str = "train",
#     config_name: Optional[str] = None,
# ) -> Dict[str, Any]:
#     """
#     Retrieve specific samples by their indices from a HuggingFace dataset.
#
#     This function allows for targeted sampling by specifying exact row indices.
#     Note: This requires loading the dataset in non-streaming mode.
#
#     Args:
#         dataset_id: HuggingFace dataset identifier
#         indices: List of row indices to retrieve
#         split: Dataset split to sample from (default: 'train')
#         config_name: Optional configuration name for multi-config datasets
#
#     Returns:
#         Dictionary containing the requested samples and metadata
#
#     Raises:
#         ValueError: If inputs are invalid
#         DatasetServiceError: If sampling fails
#     """
#     # Handle empty strings from Gradio (convert to None)
#     if config_name == "":
#         config_name = None
#
#     # Input validation using centralized validation
#     try:
#         dataset_id = validate_dataset_id(dataset_id)
#         config_name = validate_config_name(config_name)
#         split = validate_split_name(split)
#         indices = validate_indices(indices)
#     except ValidationError as e:
#         logger.error(f"Validation error: {format_validation_error(e)}")
#         raise ValueError(format_validation_error(e))
#
#     logger.info(f"Sampling {len(indices)} specific indices from dataset: {dataset_id}")
#
#     try:
#         from datasets import load_dataset
#
#         # Load dataset without streaming to access by index
#         dataset = load_dataset(
#             dataset_id, name=config_name, split=split, streaming=False
#         )
#
#         # Validate indices are within bounds
#         max_index = max(indices)
#         if max_index >= len(dataset):
#             raise ValueError(
#                 f"Index {max_index} is out of bounds for dataset with {len(dataset)} rows"
#             )
#
#         # Get samples by indices
#         samples = [dataset[i] for i in indices]
#
#         # Get dataset info for schema
#         service = get_dataset_service(hf_api_token=hf_api_token)
#         dataset_info = service.load_dataset_info(dataset_id, config_name)
#
#         # Prepare response
#         sample_data = {
#             "dataset_id": dataset_id,
#             "config_name": config_name,
#             "split": split,
#             "num_samples": len(samples),
#             "requested_indices": indices,
#             "data": samples,
#             "schema": dataset_info.get("features", {}),
#             "sample_info": {
#                 "sampling_strategy": "by_indices",
#                 "streaming_used": False,
#                 "indices_requested": len(indices),
#             },
#         }
#
#         sample_data["summary"] = _generate_sample_summary(sample_data)
#         return sample_data
#
#     except Exception as e:
#         logger.error(f"Failed to sample by indices from {dataset_id}: {str(e)}")
#         raise DatasetServiceError(f"Failed to sample by indices: {str(e)}")


def _generate_sample_summary(sample_data: Dict[str, Any]) -> str:
    """Generate a human-readable summary of the sample data."""
    summary_parts = []

    # Basic info
    summary_parts.append(f"Dataset: {sample_data.get('dataset_id', 'Unknown')}")
    summary_parts.append(f"Split: {sample_data.get('split', 'Unknown')}")
    if sample_data.get("config_name"):
        summary_parts.append(f"Config: {sample_data['config_name']}")

    # Sample info
    num_samples = sample_data.get("num_samples", 0)
    requested = sample_data.get("requested_samples", num_samples)
    if num_samples == requested:
        summary_parts.append(f"Samples: {num_samples}")
    else:
        summary_parts.append(f"Samples: {num_samples}/{requested} (truncated)")

    # Schema info
    schema = sample_data.get("schema", {})
    if schema:
        summary_parts.append(f"Columns: {len(schema)}")

    # Sampling strategy
    sample_info = sample_data.get("sample_info", {})
    strategy = sample_info.get("sampling_strategy", "unknown")
    if strategy == "by_indices":
        summary_parts.append("Strategy: by indices")
    elif strategy == "sequential_head":
        summary_parts.append("Strategy: first N rows")

    return " | ".join(summary_parts)