# ane-kan-runtime / scripts/build_python_graphql_datasets.py
# Uploaded by: JohnGenetica
# Commit: "Deploy ANE KAN runtime Space" (201cf4d, verified)
"""Build Python code generation + GraphQL generation training datasets.
Each pair maps a task description (source) → gold output (target), with multi-dimensional reward signals.
Usage:
python3 scripts/build_python_graphql_datasets.py
"""
from __future__ import annotations

import json
import textwrap
from pathlib import Path
from typing import Dict, List
# Directory that main() writes python_graphql_dataset.json into.
RESULTS_DIR = Path("training", "kan_bench_results")
# Combined SOTA training data file; not used directly by this script.
SOTA_PATH = RESULTS_DIR.joinpath("sota_training_data.json")
def _reward(syntax=1.0, pythonic=1.0, type_correct=1.0, tests=1.0, complexity=1.0) -> Dict[str, float]:
return {"syntax_valid": syntax, "pythonic": pythonic, "type_correct": type_correct,
"passes_tests": tests, "complexity_appropriate": complexity,
"composite": round((syntax + pythonic + type_correct + tests + complexity) / 5, 3)}
def _gql_reward(schema=1.0, type_correct=1.0, efficiency=1.0, syntax=1.0) -> Dict[str, float]:
return {"schema_valid": schema, "type_correct": type_correct,
"efficiency": efficiency, "syntax_valid": syntax,
"composite": round((schema + type_correct + efficiency + syntax) / 4, 3)}
# ── Python Code Generation ──────────────────────────────────────────────────
def _build_python_data_processing() -> List[Dict]:
    """Build gold pairs for data-processing tasks (windows, merging, grouping, parsing).

    Each template is ``(question, gold_code, category)``. Gold code is written as an
    indented triple-quoted block and normalised with ``textwrap.dedent`` so the
    emitted target is properly indented, syntactically valid Python.

    Returns:
        Pair dicts with keys ``source``, ``target``, ``dialect``, ``category`` and
        ``reward_signals`` (all rewards at their 1.0 default).
    """
    templates = [
        (
            "Write a function to calculate moving average of a list",
            textwrap.dedent('''\
                def moving_average(data: list[float], window: int) -> list[float]:
                    """Calculate moving average with given window size."""
                    if window <= 0 or window > len(data):
                        return []
                    result = []
                    for i in range(len(data) - window + 1):
                        avg = sum(data[i:i + window]) / window
                        result.append(round(avg, 4))
                    return result'''),
            "data_processing",
        ),
        (
            "Write a function to merge two sorted lists",
            textwrap.dedent('''\
                def merge_sorted(a: list[int], b: list[int]) -> list[int]:
                    """Merge two sorted lists into one sorted list."""
                    result = []
                    i = j = 0
                    while i < len(a) and j < len(b):
                        if a[i] <= b[j]:
                            result.append(a[i])
                            i += 1
                        else:
                            result.append(b[j])
                            j += 1
                    result.extend(a[i:])
                    result.extend(b[j:])
                    return result'''),
            "data_processing",
        ),
        (
            "Write a function to group items by a key function",
            textwrap.dedent('''\
                from collections import defaultdict
                from typing import Callable, TypeVar, Hashable

                T = TypeVar("T")
                K = TypeVar("K", bound=Hashable)

                def group_by(items: list[T], key_fn: Callable[[T], K]) -> dict[K, list[T]]:
                    """Group items by the result of key_fn."""
                    groups: dict[K, list[T]] = defaultdict(list)
                    for item in items:
                        groups[key_fn(item)].append(item)
                    return dict(groups)'''),
            "data_processing",
        ),
        (
            "Write a function to flatten nested dictionaries",
            # typing.Any (not the builtin any()) and str(key) so every key matches dict[str, Any].
            textwrap.dedent('''\
                from typing import Any

                def flatten_dict(d: dict, prefix: str = "", sep: str = ".") -> dict[str, Any]:
                    """Flatten a nested dict, joining nested keys with sep."""
                    result = {}
                    for key, value in d.items():
                        new_key = f"{prefix}{sep}{key}" if prefix else str(key)
                        if isinstance(value, dict):
                            result.update(flatten_dict(value, new_key, sep))
                        else:
                            result[new_key] = value
                    return result'''),
            "data_processing",
        ),
        (
            "Write a function to deduplicate preserving order",
            textwrap.dedent('''\
                def deduplicate(items: list) -> list:
                    """Remove duplicates while preserving insertion order."""
                    seen = set()
                    result = []
                    for item in items:
                        if item not in seen:
                            seen.add(item)
                            result.append(item)
                    return result'''),
            "data_processing",
        ),
        (
            "Write a CSV parser that handles quoted fields",
            # Handles doubled quotes inside a quoted field (CSV's escape for a literal quote).
            textwrap.dedent('''\
                def parse_csv_line(line: str, delimiter: str = ",") -> list[str]:
                    """Parse a single CSV line handling quoted fields, embedded delimiters and doubled quotes."""
                    fields = []
                    current = []
                    in_quotes = False
                    i = 0
                    while i < len(line):
                        char = line[i]
                        if char == '"':
                            if in_quotes and line[i + 1:i + 2] == '"':
                                current.append('"')  # escaped quote inside a quoted field
                                i += 1
                            else:
                                in_quotes = not in_quotes
                        elif char == delimiter and not in_quotes:
                            fields.append("".join(current).strip())
                            current = []
                        else:
                            current.append(char)
                        i += 1
                    fields.append("".join(current).strip())
                    return fields'''),
            "data_processing",
        ),
    ]
    return [
        {"source": q, "target": code, "dialect": "python",
         "category": cat, "reward_signals": _reward()}
        for q, code, cat in templates
    ]
def _build_python_async() -> List[Dict]:
    """Build gold pairs for asyncio concurrency patterns.

    Each template is ``(question, gold_code, category)``; gold code is dedented from
    an indented triple-quoted block so the target is valid Python.

    Returns:
        Pair dicts with ``dialect="python"`` and full (1.0) reward signals.
    """
    templates = [
        (
            "Write an async function to fetch multiple URLs concurrently",
            # Catch only network/timeout failures instead of every Exception.
            textwrap.dedent('''\
                import asyncio
                import aiohttp

                async def fetch_all(urls: list[str], timeout: int = 30) -> list[dict]:
                    """Fetch multiple URLs concurrently and return results."""
                    async def fetch_one(session: aiohttp.ClientSession, url: str) -> dict:
                        try:
                            async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as resp:
                                return {"url": url, "status": resp.status, "body": await resp.text()}
                        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                            return {"url": url, "status": -1, "error": str(e)}

                    async with aiohttp.ClientSession() as session:
                        tasks = [fetch_one(session, url) for url in urls]
                        return await asyncio.gather(*tasks)'''),
            "async",
        ),
        (
            "Write a rate limiter using asyncio semaphore",
            textwrap.dedent('''\
                import asyncio
                from typing import Callable, Awaitable, TypeVar

                T = TypeVar("T")

                class RateLimiter:
                    """Limit concurrent async operations."""

                    def __init__(self, max_concurrent: int = 10):
                        self._semaphore = asyncio.Semaphore(max_concurrent)

                    async def execute(self, fn: Callable[..., Awaitable[T]], *args, **kwargs) -> T:
                        async with self._semaphore:
                            return await fn(*args, **kwargs)'''),
            "async",
        ),
        (
            "Write a producer-consumer pattern with asyncio queue",
            textwrap.dedent('''\
                import asyncio
                from typing import Any, Callable, Awaitable

                async def producer_consumer(
                    items: list[Any],
                    process_fn: Callable[[Any], Awaitable[Any]],
                    n_consumers: int = 5,
                ) -> list[Any]:
                    """Process items with N concurrent consumers.

                    Results are returned in completion order, not input order.
                    """
                    queue: asyncio.Queue = asyncio.Queue()
                    results: list[Any] = []

                    for item in items:
                        await queue.put(item)

                    async def consumer() -> None:
                        while not queue.empty():
                            try:
                                item = queue.get_nowait()
                            except asyncio.QueueEmpty:
                                break
                            result = await process_fn(item)
                            results.append(result)
                            queue.task_done()

                    consumers = [asyncio.create_task(consumer()) for _ in range(n_consumers)]
                    await asyncio.gather(*consumers)
                    return results'''),
            "async",
        ),
    ]
    return [
        {"source": q, "target": code, "dialect": "python",
         "category": cat, "reward_signals": _reward()}
        for q, code, cat in templates
    ]
def _build_python_design_patterns() -> List[Dict]:
    """Build gold pairs for classic design patterns (observer, strategy, builder, factory).

    Each template is ``(question, gold_code, category)``; gold code is dedented from
    an indented triple-quoted block so the target is valid Python.

    Returns:
        Pair dicts with ``dialect="python"`` and full (1.0) reward signals.
    """
    templates = [
        (
            "Implement the Observer pattern in Python",
            # Iterate a snapshot so an observer may detach itself inside update().
            textwrap.dedent('''\
                from abc import ABC, abstractmethod
                from typing import Any

                class Observer(ABC):
                    @abstractmethod
                    def update(self, event: str, data: Any) -> None: ...

                class Subject:
                    def __init__(self):
                        self._observers: list[Observer] = []

                    def attach(self, observer: Observer) -> None:
                        self._observers.append(observer)

                    def detach(self, observer: Observer) -> None:
                        self._observers.remove(observer)

                    def notify(self, event: str, data: Any = None) -> None:
                        for observer in list(self._observers):  # copy: observers may detach
                            observer.update(event, data)'''),
            "design_pattern",
        ),
        (
            "Implement the Strategy pattern in Python",
            textwrap.dedent('''\
                from abc import ABC, abstractmethod

                class Strategy(ABC):
                    @abstractmethod
                    def execute(self, data: list[float]) -> float: ...

                class MeanStrategy(Strategy):
                    def execute(self, data: list[float]) -> float:
                        return sum(data) / len(data) if data else 0.0

                class MedianStrategy(Strategy):
                    def execute(self, data: list[float]) -> float:
                        if not data:
                            return 0.0
                        s = sorted(data)
                        n = len(s)
                        return (s[n // 2] + s[(n - 1) // 2]) / 2

                class Aggregator:
                    def __init__(self, strategy: Strategy):
                        self._strategy = strategy

                    def aggregate(self, data: list[float]) -> float:
                        return self._strategy.execute(data)'''),
            "design_pattern",
        ),
        (
            "Implement a builder pattern for configuration objects",
            # Every Config field now has a setter (timeout/ssl were unreachable before).
            textwrap.dedent('''\
                from dataclasses import dataclass

                @dataclass(frozen=True)
                class Config:
                    host: str
                    port: int
                    database: str
                    user: str
                    password: str
                    pool_size: int = 5
                    timeout: int = 30
                    ssl: bool = True

                class ConfigBuilder:
                    def __init__(self):
                        self._host = "localhost"
                        self._port = 5432
                        self._database = "default"
                        self._user = "admin"
                        self._password = ""
                        self._pool_size = 5
                        self._timeout = 30
                        self._ssl = True

                    def host(self, h: str) -> "ConfigBuilder":
                        self._host = h
                        return self

                    def port(self, p: int) -> "ConfigBuilder":
                        self._port = p
                        return self

                    def database(self, d: str) -> "ConfigBuilder":
                        self._database = d
                        return self

                    def credentials(self, user: str, password: str) -> "ConfigBuilder":
                        self._user = user
                        self._password = password
                        return self

                    def pool_size(self, n: int) -> "ConfigBuilder":
                        self._pool_size = n
                        return self

                    def timeout(self, seconds: int) -> "ConfigBuilder":
                        self._timeout = seconds
                        return self

                    def ssl(self, enabled: bool) -> "ConfigBuilder":
                        self._ssl = enabled
                        return self

                    def build(self) -> Config:
                        return Config(
                            host=self._host, port=self._port, database=self._database,
                            user=self._user, password=self._password,
                            pool_size=self._pool_size, timeout=self._timeout, ssl=self._ssl,
                        )'''),
            "design_pattern",
        ),
        (
            "Implement a factory pattern for creating database connections",
            textwrap.dedent('''\
                from abc import ABC, abstractmethod

                class Connection(ABC):
                    @abstractmethod
                    def execute(self, query: str) -> list[dict]: ...

                    @abstractmethod
                    def close(self) -> None: ...

                class PostgresConnection(Connection):
                    def __init__(self, dsn: str):
                        self._dsn = dsn

                    def execute(self, query: str) -> list[dict]:
                        return []  # placeholder

                    def close(self) -> None:
                        pass

                class SnowflakeConnection(Connection):
                    def __init__(self, account: str, user: str, password: str):
                        self._account = account
                        self._user = user
                        self._password = password

                    def execute(self, query: str) -> list[dict]:
                        return []  # placeholder

                    def close(self) -> None:
                        pass

                class ConnectionFactory:
                    _registry: dict[str, type[Connection]] = {
                        "postgres": PostgresConnection,
                        "snowflake": SnowflakeConnection,
                    }

                    @classmethod
                    def create(cls, db_type: str, **kwargs) -> Connection:
                        conn_class = cls._registry.get(db_type)
                        if conn_class is None:
                            raise ValueError(f"Unknown db type: {db_type}")
                        return conn_class(**kwargs)'''),
            "design_pattern",
        ),
    ]
    return [
        {"source": q, "target": code, "dialect": "python",
         "category": cat, "reward_signals": _reward()}
        for q, code, cat in templates
    ]
def _build_python_algorithms() -> List[Dict]:
    """Build gold pairs for core algorithms and data structures.

    Each template is ``(question, gold_code, category)``; gold code is dedented from
    an indented triple-quoted block so the target is valid Python.

    Returns:
        Pair dicts with ``dialect="python"`` and full (1.0) reward signals.
    """
    templates = [
        (
            "Implement binary search",
            textwrap.dedent('''\
                def binary_search(arr: list[int], target: int) -> int:
                    """Return index of target in sorted array, or -1 if not found."""
                    lo, hi = 0, len(arr) - 1
                    while lo <= hi:
                        mid = (lo + hi) // 2
                        if arr[mid] == target:
                            return mid
                        elif arr[mid] < target:
                            lo = mid + 1
                        else:
                            hi = mid - 1
                    return -1'''),
            "algorithm",
        ),
        (
            "Implement topological sort using DFS",
            # Tracks the current DFS path so a cycle raises instead of yielding a bogus order.
            textwrap.dedent('''\
                def topological_sort(graph: dict[str, list[str]]) -> list[str]:
                    """Topological sort of a DAG represented as adjacency list.

                    Raises ValueError if the graph contains a cycle.
                    """
                    visited: set[str] = set()
                    on_path: set[str] = set()
                    result: list[str] = []

                    def dfs(node: str) -> None:
                        if node in visited:
                            return
                        if node in on_path:
                            raise ValueError(f"Graph contains a cycle through {node!r}")
                        on_path.add(node)
                        for neighbor in graph.get(node, []):
                            dfs(neighbor)
                        on_path.discard(node)
                        visited.add(node)
                        result.append(node)

                    for node in graph:
                        dfs(node)
                    result.reverse()
                    return result'''),
            "algorithm",
        ),
        (
            "Implement LRU cache from scratch",
            # Generic[K, V] so the K/V TypeVars are actually bound to the class.
            textwrap.dedent('''\
                from collections import OrderedDict
                from typing import Generic, Hashable, TypeVar

                K = TypeVar("K", bound=Hashable)
                V = TypeVar("V")

                class LRUCache(Generic[K, V]):
                    """Least Recently Used cache with O(1) get/put."""

                    def __init__(self, capacity: int):
                        self._capacity = capacity
                        self._cache: OrderedDict[K, V] = OrderedDict()

                    def get(self, key: K) -> V | None:
                        if key not in self._cache:
                            return None
                        self._cache.move_to_end(key)
                        return self._cache[key]

                    def put(self, key: K, value: V) -> None:
                        if key in self._cache:
                            self._cache.move_to_end(key)
                        self._cache[key] = value
                        if len(self._cache) > self._capacity:
                            self._cache.popitem(last=False)'''),
            "algorithm",
        ),
        (
            "Implement Dijkstra's shortest path",
            textwrap.dedent('''\
                import heapq

                def dijkstra(graph: dict[str, list[tuple[str, float]]], start: str) -> dict[str, float]:
                    """Shortest paths from start using Dijkstra. graph: {node: [(neighbor, weight)]}."""
                    dist: dict[str, float] = {start: 0.0}
                    pq: list[tuple[float, str]] = [(0.0, start)]

                    while pq:
                        d, u = heapq.heappop(pq)
                        if d > dist.get(u, float("inf")):
                            continue
                        for v, w in graph.get(u, []):
                            new_dist = d + w
                            if new_dist < dist.get(v, float("inf")):
                                dist[v] = new_dist
                                heapq.heappush(pq, (new_dist, v))

                    return dist'''),
            "algorithm",
        ),
        (
            "Implement trie data structure",
            textwrap.dedent('''\
                class TrieNode:
                    def __init__(self):
                        self.children: dict[str, "TrieNode"] = {}
                        self.is_end: bool = False

                class Trie:
                    def __init__(self):
                        self.root = TrieNode()

                    def insert(self, word: str) -> None:
                        node = self.root
                        for ch in word:
                            if ch not in node.children:
                                node.children[ch] = TrieNode()
                            node = node.children[ch]
                        node.is_end = True

                    def search(self, word: str) -> bool:
                        node = self._find(word)
                        return node is not None and node.is_end

                    def starts_with(self, prefix: str) -> bool:
                        return self._find(prefix) is not None

                    def _find(self, prefix: str) -> TrieNode | None:
                        node = self.root
                        for ch in prefix:
                            if ch not in node.children:
                                return None
                            node = node.children[ch]
                        return node'''),
            "algorithm",
        ),
    ]
    return [
        {"source": q, "target": code, "dialect": "python",
         "category": cat, "reward_signals": _reward()}
        for q, code, cat in templates
    ]
def _build_python_testing() -> List[Dict]:
    """Build gold pairs for pytest and unittest.mock testing patterns.

    Each template is ``(question, gold_code, category)``; gold code is dedented from
    an indented triple-quoted block so the target is valid Python.

    Returns:
        Pair dicts with ``dialect="python"`` and full (1.0) reward signals.
    """
    templates = [
        (
            "Write pytest tests for a calculator class",
            textwrap.dedent('''\
                import pytest

                class Calculator:
                    def add(self, a: float, b: float) -> float:
                        return a + b

                    def divide(self, a: float, b: float) -> float:
                        if b == 0:
                            raise ValueError("Cannot divide by zero")
                        return a / b

                class TestCalculator:
                    @pytest.fixture
                    def calc(self) -> Calculator:
                        return Calculator()

                    def test_add(self, calc: Calculator) -> None:
                        assert calc.add(2, 3) == 5
                        assert calc.add(-1, 1) == 0
                        assert calc.add(0, 0) == 0

                    def test_divide(self, calc: Calculator) -> None:
                        assert calc.divide(10, 2) == 5.0
                        assert calc.divide(7, 2) == 3.5

                    def test_divide_by_zero(self, calc: Calculator) -> None:
                        with pytest.raises(ValueError, match="Cannot divide by zero"):
                            calc.divide(1, 0)

                    @pytest.mark.parametrize("a,b,expected", [(1, 1, 2), (0, 0, 0), (-1, -1, -2)])
                    def test_add_parametrized(self, calc: Calculator, a, b, expected) -> None:
                        assert calc.add(a, b) == expected'''),
            "testing",
        ),
        (
            "Write a mock-based test for an API client",
            # session.get() is entered with `async with`, not awaited, so it must be a
            # MagicMock; an AsyncMock would hand `async with` a coroutine (TypeError).
            textwrap.dedent('''\
                from unittest.mock import AsyncMock, MagicMock, patch
                import pytest

                class APIClient:
                    def __init__(self, base_url: str):
                        self.base_url = base_url

                    async def get_user(self, user_id: str) -> dict:
                        import aiohttp
                        async with aiohttp.ClientSession() as session:
                            async with session.get(f"{self.base_url}/users/{user_id}") as resp:
                                return await resp.json()

                @pytest.mark.asyncio
                async def test_get_user():
                    client = APIClient("https://api.example.com")
                    mock_response = {"id": "123", "name": "Alice"}

                    with patch("aiohttp.ClientSession") as mock_session:
                        mock_resp = AsyncMock()
                        mock_resp.json = AsyncMock(return_value=mock_response)
                        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
                        mock_resp.__aexit__ = AsyncMock(return_value=False)

                        # get() is used as an async context manager, not awaited.
                        mock_get = MagicMock(return_value=mock_resp)
                        mock_session_inst = AsyncMock()
                        mock_session_inst.get = mock_get
                        mock_session_inst.__aenter__ = AsyncMock(return_value=mock_session_inst)
                        mock_session_inst.__aexit__ = AsyncMock(return_value=False)
                        mock_session.return_value = mock_session_inst

                        result = await client.get_user("123")
                        assert result == mock_response
                        mock_get.assert_called_once_with("https://api.example.com/users/123")'''),
            "testing",
        ),
    ]
    return [
        {"source": q, "target": code, "dialect": "python",
         "category": cat, "reward_signals": _reward()}
        for q, code, cat in templates
    ]
def _build_python_mistakes() -> List[Dict]:
    """Build contrastive pairs for common Python mistakes.

    Each entry is ``(question, bad_code, good_code, explanation)`` and yields two
    pairs in this order:

    * the corrected code, category ``mistake_correction``, full rewards;
    * the flawed code, category ``common_mistake``, carrying ``mistake_explanation``
      and reduced rewards (syntax 0.8, pythonic 0.0, complexity 0.3).
    """
    mistakes = [
        (
            "Write function with default list parameter",
            textwrap.dedent('''\
                def append_to(item, target=[]):
                    target.append(item)
                    return target'''),
            textwrap.dedent('''\
                def append_to(item, target: list | None = None) -> list:
                    if target is None:
                        target = []
                    target.append(item)
                    return target'''),
            "Mutable default arguments are shared across calls",
        ),
        (
            "Write bare except handler",
            textwrap.dedent('''\
                try:
                    result = process(data)
                except:
                    pass'''),
            textwrap.dedent('''\
                try:
                    result = process(data)
                except ValueError as e:
                    logger.warning("Invalid data: %s", e)
                    result = default_value'''),
            "Never use bare except — catch specific exceptions",
        ),
        (
            "String concatenation in a loop",
            textwrap.dedent('''\
                def build_report(items):
                    result = ""
                    for item in items:
                        result += str(item) + "\\n"
                    return result'''),
            # Keep the newline after every item so the fix is output-identical to the bad version.
            textwrap.dedent('''\
                def build_report(items: list) -> str:
                    return "".join(str(item) + "\\n" for item in items)'''),
            "Use str.join() instead of += in loops for O(n) vs O(n²)",
        ),
        (
            "Not using context manager for file",
            textwrap.dedent('''\
                def read_file(path):
                    f = open(path)
                    data = f.read()
                    f.close()
                    return data'''),
            textwrap.dedent('''\
                def read_file(path: str) -> str:
                    with open(path) as f:
                        return f.read()'''),
            "Always use context managers (with statement) for file I/O",
        ),
    ]
    pairs: List[Dict] = []
    for q, bad, good, explanation in mistakes:
        pairs.append({"source": q, "target": good, "dialect": "python",
                      "category": "mistake_correction",
                      "reward_signals": _reward()})
        pairs.append({"source": q, "target": bad, "dialect": "python",
                      "category": "common_mistake", "mistake_explanation": explanation,
                      "reward_signals": _reward(syntax=0.8, pythonic=0.0, complexity=0.3)})
    return pairs
# ── GraphQL Dataset ──────────────────────────────────────────────────────────
def _build_graphql_queries() -> List[Dict]:
    """Build gold pairs for GraphQL read queries.

    Each template is ``(question, graphql_document, category)``. The Neo4j example
    shows ``@cypher`` on the schema, because it is a type-definition directive and
    cannot appear inside an executable query.

    Returns:
        Pair dicts with ``dialect="graphql"`` and full (1.0) GraphQL rewards.
    """
    templates = [
        ("Get user by ID with their posts",
         'query GetUser($userId: ID!) {\n user(id: $userId) {\n id\n name\n email\n posts(first: 10, orderBy: CREATED_AT_DESC) {\n edges {\n node {\n id\n title\n content\n createdAt\n }\n }\n pageInfo {\n hasNextPage\n endCursor\n }\n }\n }\n}',
         "query"),
        ("Search products with filtering and pagination",
         'query SearchProducts($query: String!, $category: Category, $first: Int = 20, $after: String) {\n searchProducts(query: $query, filter: { category: $category }, first: $first, after: $after) {\n edges {\n node {\n id\n name\n price\n category\n rating\n reviewCount\n }\n }\n totalCount\n pageInfo {\n hasNextPage\n endCursor\n }\n }\n}',
         "query"),
        ("Get dashboard analytics data",
         'query DashboardAnalytics($dateRange: DateRangeInput!) {\n analytics(dateRange: $dateRange) {\n totalRevenue\n orderCount\n averageOrderValue\n conversionRate\n topProducts(limit: 5) {\n product {\n id\n name\n }\n revenue\n unitsSold\n }\n revenueByDay {\n date\n amount\n }\n }\n}',
         "query"),
        ("Get Neo4j graph data with Cypher resolver",
         textwrap.dedent('''\
             # Schema: @cypher is a type-definition directive, so the custom
             # resolver is declared on the Movie type, not in the query.
             type Movie {
               title: String!
               released: Int
               actors: [Person!]! @relationship(type: "ACTED_IN", direction: IN)
               directors: [Person!]! @relationship(type: "DIRECTED", direction: IN)
               similarMovies: [Movie!]!
                 @cypher(
                   statement: """
                   MATCH (this)<-[:ACTED_IN]-(:Person)-[:ACTED_IN]->(other:Movie)
                   WHERE other <> this
                   RETURN DISTINCT other
                   LIMIT 5
                   """
                   columnName: "other"
                 )
             }

             type Person {
               name: String!
               born: Int
             }

             # Query: the @cypher-backed field is selected like any other field.
             query GetMovieNetwork($movieTitle: String!) {
               movies(where: { title: $movieTitle }) {
                 title
                 released
                 actors {
                   name
                   born
                 }
                 directors {
                   name
                 }
                 similarMovies {
                   title
                   released
                 }
               }
             }'''),
         "cypher_resolver"),
    ]
    return [
        {"source": q, "target": gql, "dialect": "graphql",
         "category": cat, "reward_signals": _gql_reward()}
        for q, gql, cat in templates
    ]
def _build_graphql_mutations() -> List[Dict]:
    """Build gold pairs for GraphQL mutations.

    Covers creation, multi-item orders, versioned updates and relationship
    creation. Every pair is tagged ``"mutation"`` and given full GraphQL rewards.
    """
    category = "mutation"
    examples = (
        ("Create a new user account",
         'mutation CreateUser($input: CreateUserInput!) {\n createUser(input: $input) {\n user {\n id\n name\n email\n createdAt\n }\n errors {\n field\n message\n }\n }\n}'),
        ("Place an order with multiple items",
         'mutation PlaceOrder($input: PlaceOrderInput!) {\n placeOrder(input: $input) {\n order {\n id\n status\n totalAmount\n items {\n product {\n id\n name\n }\n quantity\n unitPrice\n }\n shippingAddress {\n street\n city\n state\n zipCode\n }\n }\n errors {\n field\n message\n }\n }\n}'),
        ("Update user profile with optimistic locking",
         'mutation UpdateProfile($id: ID!, $input: UpdateProfileInput!, $version: Int!) {\n updateProfile(id: $id, input: $input, expectedVersion: $version) {\n profile {\n id\n displayName\n bio\n avatarUrl\n version\n }\n errors {\n field\n message\n code\n }\n }\n}'),
        ("Create Neo4j relationship via GraphQL",
         'mutation ConnectActorToMovie($actorName: String!, $movieTitle: String!, $role: String!) {\n createActedInRelationship(\n input: {\n actor: { where: { name: $actorName } }\n movie: { where: { title: $movieTitle } }\n edge: { role: $role }\n }\n ) {\n actors {\n name\n }\n movies {\n title\n }\n }\n}'),
    )
    return [
        {
            "source": question,
            "target": document,
            "dialect": "graphql",
            "category": category,
            "reward_signals": _gql_reward(),
        }
        for question, document in examples
    ]
def _build_graphql_subscriptions() -> List[Dict]:
    """Build gold pairs for GraphQL subscriptions (order status, sensor alerts).

    Every pair is tagged ``"subscription"`` and given full GraphQL rewards.
    """
    category = "subscription"
    examples = (
        ("Subscribe to order status updates",
         'subscription OrderUpdates($orderId: ID!) {\n orderStatusChanged(orderId: $orderId) {\n order {\n id\n status\n updatedAt\n estimatedDelivery\n }\n previousStatus\n newStatus\n }\n}'),
        ("Subscribe to real-time sensor alerts",
         'subscription SensorAlerts($deviceIds: [ID!]!, $minSeverity: AlertSeverity = WARNING) {\n sensorAlert(deviceIds: $deviceIds, minSeverity: $minSeverity) {\n alert {\n id\n deviceId\n severity\n message\n reading {\n sensorType\n value\n unit\n timestamp\n }\n }\n }\n}'),
    )
    return [
        {
            "source": question,
            "target": document,
            "dialect": "graphql",
            "category": category,
            "reward_signals": _gql_reward(),
        }
        for question, document in examples
    ]
def _build_graphql_fragments() -> List[Dict]:
    """Build gold pairs for fragments and executable directives.

    Only ``@include``/``@skip`` are used in the directive example. ``@deprecated``
    is a schema-definition directive and is not valid inside a query document.

    Returns:
        Pair dicts with ``dialect="graphql"`` and full (1.0) GraphQL rewards.
    """
    templates = [
        ("Use fragments for reusable user fields",
         'fragment UserFields on User {\n id\n name\n email\n avatarUrl\n}\n\nfragment UserWithPosts on User {\n ...UserFields\n posts(first: 5) {\n edges {\n node {\n id\n title\n createdAt\n }\n }\n }\n}\n\nquery GetUsers {\n users(first: 20) {\n edges {\n node {\n ...UserWithPosts\n }\n }\n }\n}',
         "fragment"),
        ("Conditional fields with directives",
         'query GetProduct($id: ID!, $includeReviews: Boolean!, $includeInventory: Boolean!, $compact: Boolean = false) {\n product(id: $id) {\n id\n name\n price\n description @skip(if: $compact)\n reviews @include(if: $includeReviews) {\n edges {\n node {\n rating\n comment\n author {\n name\n }\n }\n }\n }\n inventory @include(if: $includeInventory) {\n warehouse\n quantity\n lastUpdated\n }\n }\n}',
         "directive"),
    ]
    return [
        {"source": q, "target": gql, "dialect": "graphql",
         "category": cat, "reward_signals": _gql_reward()}
        for q, gql, cat in templates
    ]
def _build_graphql_federation() -> List[Dict]:
    """Build gold pairs for Apollo Federation.

    Covers an owning service's entity, an extending service, and a query that
    spans both. Each example carries its own federation_* category.
    """
    examples = (
        ("Define federated product type with key",
         'type Product @key(fields: "id") {\n id: ID!\n name: String!\n price: Float!\n category: Category!\n}\n\nextend type Query {\n product(id: ID!): Product\n products(first: Int, after: String, filter: ProductFilter): ProductConnection!\n}',
         "federation_schema"),
        ("Extend product type from another service",
         'type Product @key(fields: "id") @extends {\n id: ID! @external\n reviews: [Review!]!\n averageRating: Float!\n reviewCount: Int!\n}\n\ntype Review {\n id: ID!\n rating: Int!\n comment: String\n author: User!\n createdAt: DateTime!\n}',
         "federation_extend"),
        ("Query across federated services",
         'query GetProductWithReviews($productId: ID!) {\n product(id: $productId) {\n id\n name\n price\n category\n reviews {\n rating\n comment\n author {\n name\n avatarUrl\n }\n }\n averageRating\n inventory {\n warehouse\n quantity\n }\n }\n}',
         "federation_query"),
    )
    return [
        {
            "source": question,
            "target": document,
            "dialect": "graphql",
            "category": label,
            "reward_signals": _gql_reward(),
        }
        for question, document, label in examples
    ]
def _build_graphql_mistakes() -> List[Dict]:
    """Build contrastive pairs for common GraphQL mistakes.

    Each entry is ``(question, bad_doc, good_doc, explanation)`` and yields two
    pairs: the corrected document (``mistake_correction``, full rewards), then
    the flawed one (``common_mistake``, with ``mistake_explanation`` and reduced
    schema/efficiency rewards).
    """
    mistakes = [
        ("Query user without required argument",
         '{ user { name email } }',
         'query GetUser($userId: ID!) {\n user(id: $userId) {\n name\n email\n }\n}',
         "Missing required arguments — user needs id parameter"),
        ("N+1 query pattern",
         'query { users { name posts { comments { author { name } } } } }',
         'query GetUsersWithPosts {\n users(first: 20) {\n edges {\n node {\n name\n posts(first: 10) {\n edges {\n node {\n title\n commentCount\n }\n }\n }\n }\n }\n }\n}',
         "Deeply nested queries cause N+1 — limit depth, use pagination"),
        ("Mutation without error handling",
         'mutation { createUser(name: "Alice") { id } }',
         'mutation CreateUser($input: CreateUserInput!) {\n createUser(input: $input) {\n user {\n id\n name\n }\n errors {\n field\n message\n }\n }\n}',
         # The gold answer returns a payload object (result + errors), not a union.
         "Mutations should take an input type and return a payload with both the result and user errors"),
    ]
    pairs: List[Dict] = []
    for q, bad, good, explanation in mistakes:
        pairs.append({"source": q, "target": good, "dialect": "graphql",
                      "category": "mistake_correction",
                      "reward_signals": _gql_reward()})
        pairs.append({"source": q, "target": bad, "dialect": "graphql",
                      "category": "common_mistake", "mistake_explanation": explanation,
                      "reward_signals": _gql_reward(schema=0.3, efficiency=0.2)})
    return pairs
# ── Main ─────────────────────────────────────────────────────────────────────
def build_all() -> List[Dict]:
    """Run every dataset builder and concatenate their pairs.

    Prints one progress line per builder with its pair count and the sorted set
    of categories it produced. The *_mistakes and federation builders emit more
    than one category, so reporting only the first pair's category was misleading.

    Returns:
        All pairs, Python builders first, then GraphQL builders, in list order.
    """
    builders = [
        _build_python_data_processing,
        _build_python_async,
        _build_python_design_patterns,
        _build_python_algorithms,
        _build_python_testing,
        _build_python_mistakes,
        _build_graphql_queries,
        _build_graphql_mutations,
        _build_graphql_subscriptions,
        _build_graphql_fragments,
        _build_graphql_federation,
        _build_graphql_mistakes,
    ]
    all_pairs: List[Dict] = []
    for builder in builders:
        pairs = builder()
        categories = ", ".join(sorted({p["category"] for p in pairs})) or "unknown"
        print(f"  {builder.__name__}: {len(pairs)} pairs ({categories})")
        all_pairs.extend(pairs)
    return all_pairs
def main() -> List[Dict]:
    """Build all pairs, report per-dialect counts and write them to JSON.

    The output file is ``RESULTS_DIR / "python_graphql_dataset.json"``; the
    directory is created if it does not exist.

    Returns:
        The list of pairs that was written.
    """
    print("=== Building Python + GraphQL Generation Datasets ===\n")
    pairs = build_all()
    py_count = sum(1 for p in pairs if p["dialect"] == "python")
    gql_count = sum(1 for p in pairs if p["dialect"] == "graphql")
    print(f"\nTotal: {len(pairs)} pairs (Python: {py_count}, GraphQL: {gql_count})")
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    out_path = RESULTS_DIR / "python_graphql_dataset.json"
    # Explicit encoding so the file does not depend on the platform's locale default.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(pairs, f, indent=2)
    print(f"Saved → {out_path}")
    print("Run scripts/combine_and_push_datasets.py to merge into SOTA data")
    return pairs
if __name__ == "__main__":
    # Script entry point: build both datasets and write the JSON file (see main()).
    main()