HF-Dataset-Commander / processor.py
import json
import logging
import datasets
import requests
import math
import re
from datasets import load_dataset, get_dataset_config_names, get_dataset_infos
from huggingface_hub import HfApi, DatasetCard, DatasetCardData
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DatasetCommandCenter:
def __init__(self, token=None):
self.token = token
self.api = HfApi(token=token)
        self.username = self.api.whoami()['name']
        logger.info(f"Authenticated as: {self.username}")
# ==========================================
# 1. METADATA & SCHEMA INSPECTION
# ==========================================
def get_dataset_metadata(self, dataset_id):
"""
        Fetches configs, splits, and the detected license for a dataset.
"""
configs = ['default']
splits = ['train', 'test', 'validation']
license_name = "unknown"
try:
# 1. Fetch Configs
try:
found_configs = get_dataset_config_names(dataset_id, token=self.token)
if found_configs:
configs = found_configs
except Exception:
pass
# 2. Fetch Metadata (Splits & License)
try:
selected = configs[0]
infos = get_dataset_infos(dataset_id, token=self.token)
                logger.debug(infos)
info = None
if selected in infos:
info = infos[selected]
elif 'default' in infos:
info = infos['default']
elif infos:
info = list(infos.values())[0]
if info:
splits = list(info.splits.keys())
license_name = info.license or "unknown"
except Exception:
pass
return {
"status": "success",
"configs": configs,
"splits": splits,
"license_detected": license_name
}
except Exception as e:
return {"status": "error", "message": str(e)}
def get_splits_for_config(self, dataset_id, config_name):
try:
infos = get_dataset_infos(dataset_id, config_name=config_name, token=self.token)
if config_name in infos:
splits = list(infos[config_name].splits.keys())
elif len(infos) > 0:
                splits = list(list(infos.values())[0].splits.keys())
else:
splits = ['train', 'test']
return {"status": "success", "splits": splits}
        except Exception:
return {"status": "success", "splits": ['train', 'test', 'validation']}
def _sanitize_for_json(self, obj):
"""
Recursively cleans data for JSON serialization.
"""
if isinstance(obj, float):
if math.isnan(obj) or math.isinf(obj):
return None
return obj
elif isinstance(obj, dict):
return {k: self._sanitize_for_json(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._sanitize_for_json(v) for v in obj]
elif isinstance(obj, (str, int, bool, type(None))):
return obj
else:
return str(obj)
def _flatten_object(self, obj, parent_key='', sep='.'):
"""
Recursively finds keys for the UI dropdowns.
"""
items = {}
# Transparently parse JSON strings
if isinstance(obj, str):
s = obj.strip()
if (s.startswith('{') and s.endswith('}')) or (s.startswith('[') and s.endswith(']')):
try:
obj = json.loads(s)
                except Exception:
pass
if isinstance(obj, dict):
for k, v in obj.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
items.update(self._flatten_object(v, new_key, sep=sep))
elif isinstance(obj, list):
new_key = f"{parent_key}" if parent_key else "list_content"
items[new_key] = "List"
else:
items[parent_key] = type(obj).__name__
return items
def inspect_dataset(self, dataset_id, config, split):
try:
conf = config if config != 'default' else None
ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)
sample_rows = []
available_paths = set()
schema_map = {}
for i, row in enumerate(ds_stream):
if i >= 10: break
# CRITICAL FIX: Force Materialization
row = dict(row)
# Clean row for UI
clean_row = self._sanitize_for_json(row)
sample_rows.append(clean_row)
# Schema Discovery
flattened = self._flatten_object(row)
available_paths.update(flattened.keys())
# List Mode Detection
for k, v in row.items():
if k not in schema_map:
schema_map[k] = {"type": "Object"}
val = v
if isinstance(val, str):
try: val = json.loads(val)
                        except Exception: pass
if isinstance(val, list):
schema_map[k]["type"] = "List"
sorted_paths = sorted(list(available_paths))
schema_tree = {}
for path in sorted_paths:
root = path.split('.')[0]
if root not in schema_tree:
schema_tree[root] = []
schema_tree[root].append(path)
return {
"status": "success",
"samples": sample_rows,
"schema_tree": schema_tree,
"schema": schema_map,
"dataset_id": dataset_id
}
except Exception as e:
return {"status": "error", "message": str(e)}
# ==========================================
# 2. CORE EXTRACTION LOGIC
# ==========================================
def _get_value_by_path(self, obj, path):
"""
        Retrieves a value from `obj` by path. Tries direct key access first (fastest),
        then falls back to dot-notation traversal.
"""
        # Handle None/empty path edge cases
        if not path:
            return obj
# 1. Try Direct Access First (handles simple column names)
# This works for dict, UserDict, LazyRow due to duck-typing
try:
# For simple paths (no dots), this is all we need
if '.' not in path:
return obj[path]
except (KeyError, TypeError, AttributeError):
pass
# 2. If direct access failed OR path contains dots, try dot notation
keys = path.split('.')
current = obj
for i, key in enumerate(keys):
if current is None:
return None
try:
# Array/List index access support (e.g. solutions.0.code)
if isinstance(current, list) and key.isdigit():
current = current[int(key)]
else:
# Try dictionary-style access
current = current[key]
except (KeyError, TypeError, IndexError, AttributeError):
return None
# Lazy Parsing: Only parse JSON string if we need to go deeper
is_last_key = (i == len(keys) - 1)
if not is_last_key and isinstance(current, str):
s = current.strip()
if (s.startswith('{') and s.endswith('}')) or (s.startswith('[') and s.endswith(']')):
try:
current = json.loads(s)
                    except Exception:
return None
return current
def _extract_from_list_logic(self, row, source_col, filter_key, filter_val, target_path):
"""
FROM source_col FIND ITEM WHERE filter_key == filter_val EXTRACT target_path
"""
data = row.get(source_col)
if isinstance(data, str):
try:
data = json.loads(data)
            except Exception:
return None
if not isinstance(data, list):
return None
matched_item = None
for item in data:
            # String comparison for safety; skip anything that is not a dict
            if isinstance(item, dict) and str(item.get(filter_key, '')) == str(filter_val):
matched_item = item
break
if matched_item:
return self._get_value_by_path(matched_item, target_path)
return None
def _apply_projection(self, row, recipe):
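        """
        Builds one output row from `row` according to `recipe['columns']`.

        Illustrative recipe (field names taken from the branches below; values are
        hypothetical):
            {"columns": [
                {"type": "simple",      "name": "question", "source": "prompt"},
                {"type": "list_search", "name": "py_code",  "source": "solutions",
                 "filter_key": "lang", "filter_val": "py", "target_key": "code"},
                {"type": "python",      "name": "n_words",  "expression": "len(prompt.split())"},
                {"type": "requests",    "name": "api_out",  "rurl": "https://example.com/api",
                 "rpay": '{"q": 1}'}
            ]}
        """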
new_row = {}
# OPTIMIZATION: Only create eval_context if we actually have a Python column.
# This prevents expensive row.copy() calls for Simple Path operations.
eval_context = None
for col_def in recipe['columns']:
t_type = col_def.get('type', 'simple')
target_col = col_def['name']
try:
if t_type == 'simple':
# Fast path - no context needed
new_row[target_col] = self._get_value_by_path(row, col_def['source'])
elif t_type == 'list_search':
# Fast path - no context needed
new_row[target_col] = self._extract_from_list_logic(
row,
col_def['source'],
col_def['filter_key'],
col_def['filter_val'],
col_def['target_key']
)
elif t_type == 'python':
if eval_context is None:
eval_context = row.copy()
eval_context['row'] = row
eval_context['json'] = json
eval_context['re'] = re
eval_context['requests'] = requests
# This evaluates the ENTIRE expression as Python
val = eval(col_def['expression'], {}, eval_context)
new_row[target_col] = val
                elif t_type == 'requests':
                    # POST the parsed JSON payload to the configured URL and store
                    # the raw response text in the target column.
                    payload = json.loads(col_def['rpay'])
                    new_row[target_col] = requests.post(col_def['rurl'], json=payload).text
except Exception as e:
raise ValueError(f"Column '{target_col}' failed: {str(e)}")
return new_row
# ==========================================
# 3. DOCUMENTATION (MODEL CARD)
# ==========================================
def _generate_card(self, source_id, target_id, recipe, license_name):
        logger.info(f"Generating dataset card: {source_id} -> {target_id}")
card_data = DatasetCardData(
language="en",
license=license_name,
tags=["dataset-command-center", "etl", "generated-dataset"],
base_model=source_id,
)
content = f"""
# {target_id.split('/')[-1]}
This dataset is a transformation of [{source_id}](https://huggingface.co/datasets/{source_id}).
It was generated using the **Hugging Face Dataset Command Center**.
## Transformation Recipe
The following operations were applied to the source data:
| Target Column | Operation Type | Source / Logic |
|---------------|----------------|----------------|
"""
for col in recipe['columns']:
c_type = col.get('type', 'simple')
c_name = col['name']
c_src = col.get('source', '-')
logic = "-"
if c_type == 'simple':
logic = f"Mapped from `{c_src}`"
elif c_type == 'list_search':
logic = f"Get `{col['target_key']}` where `{col['filter_key']} == {col['filter_val']}`"
            elif c_type == 'python':
                logic = f"Python: `{col.get('expression')}`"
            elif c_type == 'requests':
                logic = f"POST request to `{col.get('rurl')}`"
content += f"| **{c_name}** | {c_type} | {logic} |\n"
if recipe.get('filter_rule'):
content += f"\n### Row Filtering\n**Filter Applied:** `{recipe['filter_rule']}`\n"
content += f"\n## Original License\nThis dataset inherits the license: `{license_name}` from the source."
card = DatasetCard.from_template(card_data, content=content)
return card
# ==========================================
# 4. EXECUTION
# ==========================================
def process_and_push(self, source_id, config, split, target_id, recipe, max_rows=None, new_license=None):
logger.info(f"Job started: {source_id} -> {target_id}")
conf = config if config != 'default' else None
def gen():
ds_stream = load_dataset(source_id, name=conf, split=split, streaming=True, token=self.token)
count = 0
for i, row in enumerate(ds_stream):
if max_rows and count >= int(max_rows):
break
# CRITICAL FIX: Force Materialization
row = dict(row)
# 1. Filter
if recipe.get('filter_rule'):
try:
ctx = row.copy()
ctx['row'] = row
ctx['json'] = json
ctx['re'] = re
ctx['requests'] = requests
if not eval(recipe['filter_rule'], {}, ctx):
continue
except Exception as e:
raise ValueError(f"Filter crashed on row {i}: {e}")
# 2. Projection
try:
yield self._apply_projection(row, recipe)
count += 1
except ValueError as ve:
raise ve
except Exception as e:
raise ValueError(f"Unexpected crash on row {i}: {e}")
try:
# 1. Process & Push
new_dataset = datasets.Dataset.from_generator(gen)
new_dataset.push_to_hub(target_id, token=self.token)
# 2. Card
try:
card = self._generate_card(source_id, target_id, recipe, new_license or "unknown")
                card_repo = target_id if "/" in target_id else f"{self.username}/{target_id}"
                card.push_to_hub(card_repo, token=self.token)
except Exception as e:
logger.error(f"Failed to push Dataset Card: {e}")
return {"status": "success", "rows_processed": len(new_dataset)}
except Exception as e:
logger.error(f"Job Failed: {e}")
return {"status": "failed", "error": str(e)}
# ==========================================
# 5. PREVIEW
# ==========================================
def preview_transform(self, dataset_id, config, split, recipe):
conf = config if config != 'default' else None
try:
# Load dataset in streaming mode
ds_stream = load_dataset(dataset_id, name=conf, split=split, streaming=True, token=self.token)
processed = []
for i, row in enumerate(ds_stream):
# Stop after 5 successful rows
if len(processed) >= 5:
break
# CRITICAL: Force materialization from LazyRow to standard Dict.
# This fixes the interaction between Streaming datasets and JSON serialization.
row = dict(row)
# --- Filter Logic ---
passed = True
if recipe.get('filter_rule'):
try:
# Create context only for the filter check
ctx = row.copy()
ctx['row'] = row
ctx['json'] = json
ctx['re'] = re
if not eval(recipe['filter_rule'], {}, ctx):
passed = False
                    except Exception:
# If filter errors out (e.g. missing column), treat as filtered out
passed = False
if passed:
try:
# --- Projection Logic ---
new_row = self._apply_projection(row, recipe)
# --- Sanitization ---
# Convert NaNs, Infinity, and complex objects to prevent browser/Flask crash
clean_new_row = self._sanitize_for_json(new_row)
processed.append(clean_new_row)
except Exception as e:
# Capture specific row errors for the UI
processed.append({"_preview_error": f"Row {i} Error: {str(e)}"})
return processed
        except Exception:
            # Re-raise global errors (like 404 Dataset Not Found) so the UI sees them
            raise
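

# Minimal usage sketch (assumes a valid HF token; the dataset, config, and column
# names below are illustrative, not part of the original module):
if __name__ == "__main__":
    cc = DatasetCommandCenter(token="hf_xxx")  # hypothetical token
    print(cc.get_dataset_metadata("openai/gsm8k"))
    preview = cc.preview_transform(
        "openai/gsm8k", "main", "train",
        recipe={"columns": [{"type": "simple", "name": "prompt", "source": "question"}]},
    )
    print(json.dumps(preview, indent=2))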