# Commit c00ac2c (verified) by rtferraz:
# Add predefined schemas (FINANCE, ECOMMERCE, HEALTHCARE)
"""
Predefined domain schemas for common use cases.
Each schema follows the validated patterns from the research:
- FINANCE_SCHEMA: Based on Nubank nuFormer (arXiv:2507.23267) — 97 special tokens
- ECOMMERCE_SCHEMA: Adapted from ActionPiece (arXiv:2502.13581) + nuFormer patterns
- HEALTHCARE_SCHEMA: Clinical event sequences
"""
from ..schema import DomainSchema, FieldSpec, FieldType
# =============================================================================
# FINANCE SCHEMA — Based on Nubank nuFormer
# sign(2) + amount_bucket(21) + month(12) + dow(7) + dom(31) + hour(24) = 97
# =============================================================================
FINANCE_SCHEMA = DomainSchema(
    name="finance",
    description=(
        "Financial transaction schema following Nubank nuFormer "
        "(arXiv:2507.23267). "
        "Each transaction = sign + amount bucket + calendar features "
        "+ text description. "
        "~14 tokens per transaction, 2048 context = ~146 transactions."
    ),
    fields=[
        # Sign of the amount, carried as its own SIGN-typed field.
        FieldSpec(
            name="amount_sign",
            field_type=FieldType.SIGN,
            prefix="AMT_SIGN",
        ),
        # Amount, declared with 21 bins.
        FieldSpec(
            name="amount",
            field_type=FieldType.NUMERICAL_CONTINUOUS,
            prefix="AMT",
            n_bins=21,
        ),
        # Calendar fields requested for the timestamp; no prefix is passed.
        FieldSpec(
            name="timestamp",
            field_type=FieldType.TEMPORAL,
            calendar_fields=[
                "month",
                "dow",
                "dom",
                "hour",
            ],
        ),
        # Free-text description of the transaction.
        FieldSpec(
            name="description",
            field_type=FieldType.TEXT,
            prefix="DESC",
        ),
    ],
)
# =============================================================================
# E-COMMERCE SCHEMA — Adapted from ActionPiece + nuFormer patterns
# =============================================================================
ECOMMERCE_SCHEMA = DomainSchema(
    name="ecommerce",
    description=(
        "E-commerce event schema adapted from ActionPiece "
        "(arXiv:2502.13581) and nuFormer patterns. "
        "Events: view/cart/purchase/return/wishlist. "
        "~16 tokens per event, 2048 context = ~128 events."
    ),
    fields=[
        # Kind of shopper interaction; closed set of five values.
        FieldSpec(
            name="event_type",
            field_type=FieldType.CATEGORICAL_FIXED,
            prefix="EVT",
            categories=[
                "view",
                "add_to_cart",
                "purchase",
                "return",
                "wishlist",
            ],
        ),
        # Price, declared with 21 bins.
        FieldSpec(
            name="price",
            field_type=FieldType.NUMERICAL_CONTINUOUS,
            prefix="PRICE",
            n_bins=21,
        ),
        # Quantity as a discrete number with max_value=10.
        FieldSpec(
            name="quantity",
            field_type=FieldType.NUMERICAL_DISCRETE,
            prefix="QTY",
            max_value=10,
        ),
        # Product category; fixed vocabulary of 20 entries ending in "other".
        FieldSpec(
            name="category",
            field_type=FieldType.CATEGORICAL_FIXED,
            prefix="CAT",
            categories=[
                "electronics",
                "clothing",
                "home_garden",
                "books",
                "sports",
                "toys",
                "food_grocery",
                "health_beauty",
                "automotive",
                "office",
                "pet_supplies",
                "jewelry",
                "music",
                "movies",
                "games",
                "baby",
                "tools",
                "arts_crafts",
                "industrial",
                "other",
            ],
        ),
        # Calendar fields requested for the timestamp; no prefix is passed.
        FieldSpec(
            name="timestamp",
            field_type=FieldType.TEMPORAL,
            calendar_fields=[
                "month",
                "dow",
                "dom",
                "hour",
            ],
        ),
        # Free-text product title.
        FieldSpec(
            name="product_title",
            field_type=FieldType.TEXT,
            prefix="TITLE",
        ),
    ],
)
# =============================================================================
# HEALTHCARE SCHEMA — Clinical event sequences
# =============================================================================
HEALTHCARE_SCHEMA = DomainSchema(
    name="healthcare",
    description=(
        "Clinical event schema "
        "for healthcare sequences. "
        "Events: diagnosis/procedure/lab/medication/visit."
    ),
    fields=[
        # Kind of clinical event; fixed vocabulary of ten values.
        FieldSpec(
            name="event_type",
            field_type=FieldType.CATEGORICAL_FIXED,
            prefix="CLIN",
            categories=[
                "diagnosis",
                "procedure",
                "lab_result",
                "medication",
                "visit_inpatient",
                "visit_outpatient",
                "visit_er",
                "imaging",
                "referral",
                "discharge",
            ],
        ),
        # Cost, declared with 21 bins.
        FieldSpec(
            name="cost",
            field_type=FieldType.NUMERICAL_CONTINUOUS,
            prefix="COST",
            n_bins=21,
        ),
        # Severity level; four fixed values listed from "low" to "critical".
        FieldSpec(
            name="severity",
            field_type=FieldType.CATEGORICAL_FIXED,
            prefix="SEV",
            categories=[
                "low",
                "moderate",
                "high",
                "critical",
            ],
        ),
        # Provider type; fixed vocabulary of eight entries ending in "other".
        FieldSpec(
            name="provider_type",
            field_type=FieldType.CATEGORICAL_FIXED,
            prefix="PROV",
            categories=[
                "pcp",
                "specialist",
                "surgeon",
                "er_physician",
                "nurse_practitioner",
                "therapist",
                "pharmacist",
                "other",
            ],
        ),
        # Calendar fields requested for the timestamp; this list has no
        # "hour" entry and no prefix is passed.
        FieldSpec(
            name="timestamp",
            field_type=FieldType.TEMPORAL,
            calendar_fields=[
                "month",
                "dow",
                "dom",
            ],
        ),
        # Free-text description of the clinical event.
        FieldSpec(
            name="description",
            field_type=FieldType.TEXT,
            prefix="DESC",
        ),
    ],
)