github-actions[bot] committed
Commit f1eedd1 · 1 parent: 52419fe

Auto-sync from demo at Fri Nov 7 08:27:33 UTC 2025

app.py CHANGED
@@ -8,6 +8,7 @@ import gradio as gr
8
  import pandas as pd
9
  from dotenv import load_dotenv
10
 
 
11
  from graphgen.graphgen import GraphGen
12
  from graphgen.models import OpenAIClient, Tokenizer
13
  from graphgen.models.llm.limitter import RPM, TPM
@@ -97,26 +98,61 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
97
  "unit_sampling": params.ece_unit_sampling,
98
  }
99
 
100
  config = {
101
  "if_trainee_model": params.if_trainee_model,
102
  "read": {"input_file": params.upload_file},
103
- "split": {
104
- "chunk_size": params.chunk_size,
105
- "chunk_overlap": params.chunk_overlap,
106
- },
107
- "search": {"enabled": False},
108
- "quiz_and_judge": {
109
- "enabled": params.if_trainee_model,
110
- "quiz_samples": params.quiz_samples,
111
- },
112
- "partition": {
113
- "method": params.partition_method,
114
- "method_params": partition_params,
115
- },
116
- "generate": {
117
- "mode": params.mode,
118
- "data_format": params.data_format,
119
- },
120
  }
121
 
122
  env = {
@@ -145,20 +181,12 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
145
  # Initialize GraphGen
146
  graph_gen = init_graph_gen(config, env)
147
  graph_gen.clear()
148
-
149
  graph_gen.progress_bar = progress
150
 
151
  try:
152
- # Process the data
153
- graph_gen.insert(read_config=config["read"], split_config=config["split"])
154
-
155
- if config["if_trainee_model"]:
156
- graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
157
-
158
- graph_gen.generate(
159
- partition_config=config["partition"],
160
- generate_config=config["generate"],
161
- )
162
 
163
  # Save output
164
  output_data = graph_gen.qa_storage.data
 
8
  import pandas as pd
9
  from dotenv import load_dotenv
10
 
11
+ from graphgen.engine import Context, Engine, collect_ops
12
  from graphgen.graphgen import GraphGen
13
  from graphgen.models import OpenAIClient, Tokenizer
14
  from graphgen.models.llm.limitter import RPM, TPM
 
98
  "unit_sampling": params.ece_unit_sampling,
99
  }
100
 
101
+ pipeline = [
102
+ {
103
+ "name": "read",
104
+ "params": {
105
+ "input_file": params.upload_file,
106
+ "chunk_size": params.chunk_size,
107
+ "chunk_overlap": params.chunk_overlap,
108
+ },
109
+ },
110
+ {
111
+ "name": "build_kg",
112
+ },
113
+ ]
114
+
115
+ if params.if_trainee_model:
116
+ pipeline.append(
117
+ {
118
+ "name": "quiz_and_judge",
119
+ "params": {"quiz_samples": params.quiz_samples, "re_judge": True},
120
+ }
121
+ )
122
+ pipeline.append(
123
+ {
124
+ "name": "partition",
125
+ "deps": ["quiz_and_judge"],
126
+ "params": {
127
+ "method": params.partition_method,
128
+ "method_params": partition_params,
129
+ },
130
+ }
131
+ )
132
+ else:
133
+ pipeline.append(
134
+ {
135
+ "name": "partition",
136
+ "params": {
137
+ "method": params.partition_method,
138
+ "method_params": partition_params,
139
+ },
140
+ }
141
+ )
142
+ pipeline.append(
143
+ {
144
+ "name": "generate",
145
+ "params": {
146
+ "method": params.mode,
147
+ "data_format": params.data_format,
148
+ },
149
+ }
150
+ )
151
+
152
  config = {
153
  "if_trainee_model": params.if_trainee_model,
154
  "read": {"input_file": params.upload_file},
155
+ "pipeline": pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  }
157
 
158
  env = {
 
181
  # Initialize GraphGen
182
  graph_gen = init_graph_gen(config, env)
183
  graph_gen.clear()
 
184
  graph_gen.progress_bar = progress
185
 
186
  try:
187
+ ctx = Context(config=config, graph_gen=graph_gen)
188
+ ops = collect_ops(config, graph_gen)
189
+ Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
 
190
 
191
  # Save output
192
  output_data = graph_gen.qa_storage.data
graphgen/configs/aggregated_config.yaml CHANGED
@@ -1,22 +1,28 @@
1
- read:
2
- input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
3
- split:
4
- chunk_size: 1024 # chunk size for text splitting
5
- chunk_overlap: 100 # chunk overlap for text splitting
6
- search: # web search configuration
7
- enabled: false # whether to enable web search
8
- search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
- quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
10
- enabled: true
11
- quiz_samples: 2 # number of quiz samples to generate
12
- re_judge: false # whether to re-judge the existing quiz samples
13
- partition: # graph partition configuration
14
- method: ece # ece is a custom partition method based on comprehension loss
15
- method_params:
16
- max_units_per_community: 20 # max nodes and edges per community
17
- min_units_per_community: 5 # min nodes and edges per community
18
- max_tokens_per_community: 10240 # max tokens per community
19
- unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
20
- generate:
21
- mode: aggregated # atomic, aggregated, multi_hop, cot, vqa
22
- data_format: ChatML # Alpaca, Sharegpt, ChatML
 
1
+ pipeline:
2
+ - name: read
3
+ params:
4
+ input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
5
+ chunk_size: 1024 # chunk size for text splitting
6
+ chunk_overlap: 100 # chunk overlap for text splitting
7
+
8
+ - name: build_kg
9
+
10
+ - name: quiz_and_judge
11
+ params:
12
+ quiz_samples: 2 # number of quiz samples to generate
13
+ re_judge: false # whether to re-judge the existing quiz samples
14
+
15
+ - name: partition
16
+ deps: [quiz_and_judge] # ece depends on the quiz_and_judge step
17
+ params:
18
+ method: ece # ece is a custom partition method based on comprehension loss
19
+ method_params:
20
+ max_units_per_community: 20 # max nodes and edges per community
21
+ min_units_per_community: 5 # min nodes and edges per community
22
+ max_tokens_per_community: 10240 # max tokens per community
23
+ unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
24
+
25
+ - name: generate
26
+ params:
27
+ method: aggregated # atomic, aggregated, multi_hop, cot, vqa
28
+ data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/configs/atomic_config.yaml CHANGED
@@ -1,19 +1,18 @@
1
- read:
2
- input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
3
- split:
4
- chunk_size: 1024 # chunk size for text splitting
5
- chunk_overlap: 100 # chunk overlap for text splitting
6
- search: # web search configuration
7
- enabled: false # whether to enable web search
8
- search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
- quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
10
- enabled: true
11
- quiz_samples: 2 # number of quiz samples to generate
12
- re_judge: false # whether to re-judge the existing quiz samples
13
- partition: # graph partition configuration
14
- method: dfs # partition method, support: dfs, bfs, ece, leiden
15
- method_params:
16
- max_units_per_community: 1 # atomic partition, one node or edge per community
17
- generate:
18
- mode: atomic # atomic, aggregated, multi_hop, cot, vqa
19
- data_format: Alpaca # Alpaca, Sharegpt, ChatML
 
1
+ pipeline:
2
+ - name: read
3
+ params:
4
+ input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
5
+ chunk_size: 1024 # chunk size for text splitting
6
+ chunk_overlap: 100 # chunk overlap for text splitting
7
+
8
+ - name: build_kg
9
+
10
+ - name: partition
11
+ params:
12
+ method: dfs # partition method, support: dfs, bfs, ece, leiden
13
+ method_params:
14
+ max_units_per_community: 1 # atomic partition, one node or edge per community
15
+ - name: generate
16
+ params:
17
+ method: atomic # atomic, aggregated, multi_hop, cot, vqa
18
+ data_format: Alpaca # Alpaca, Sharegpt, ChatML
 
graphgen/configs/cot_config.yaml CHANGED
@@ -1,19 +1,21 @@
1
- read:
2
- input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
3
- split:
4
- chunk_size: 1024 # chunk size for text splitting
5
- chunk_overlap: 100 # chunk overlap for text splitting
6
- search: # web search configuration
7
- enabled: false # whether to enable web search
8
- search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
- quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
10
- enabled: false
11
- partition: # graph partition configuration
12
- method: leiden # leiden is a partitioner detection algorithm
13
- method_params:
14
- max_size: 20 # Maximum size of communities
15
- use_lcc: false # whether to use the largest connected component
16
- random_seed: 42 # random seed for partitioning
17
- generate:
18
- mode: cot # atomic, aggregated, multi_hop, cot, vqa
19
- data_format: Sharegpt # Alpaca, Sharegpt, ChatML
 
 
 
1
+ pipeline:
2
+ - name: read
3
+ params:
4
+ input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
5
+ chunk_size: 1024 # chunk size for text splitting
6
+ chunk_overlap: 100 # chunk overlap for text splitting
7
+
8
+ - name: build_kg
9
+
10
+ - name: partition
11
+ params:
12
+ method: leiden # leiden is a community detection algorithm
13
+ method_params:
14
+ max_size: 20 # Maximum size of communities
15
+ use_lcc: false # whether to use the largest connected component
16
+ random_seed: 42 # random seed for partitioning
17
+
18
+ - name: generate
19
+ params:
20
+ method: cot # atomic, aggregated, multi_hop, cot, vqa
21
+ data_format: Sharegpt # Alpaca, Sharegpt, ChatML
graphgen/configs/multi_hop_config.yaml CHANGED
@@ -1,22 +1,22 @@
1
- read:
2
- input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
3
- split:
4
- chunk_size: 1024 # chunk size for text splitting
5
- chunk_overlap: 100 # chunk overlap for text splitting
6
- search: # web search configuration
7
- enabled: false # whether to enable web search
8
- search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
- quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
10
- enabled: false
11
- quiz_samples: 2 # number of quiz samples to generate
12
- re_judge: false # whether to re-judge the existing quiz samples
13
- partition: # graph partition configuration
14
- method: ece # ece is a custom partition method based on comprehension loss
15
- method_params:
16
- max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
17
- min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
18
- max_tokens_per_community: 10240 # max tokens per community
19
- unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss
20
- generate:
21
- mode: multi_hop # atomic, aggregated, multi_hop, cot, vqa
22
- data_format: ChatML # Alpaca, Sharegpt, ChatML
 
1
+ pipeline:
2
+ - name: read
3
+ params:
4
+ input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
5
+ chunk_size: 1024 # chunk size for text splitting
6
+ chunk_overlap: 100 # chunk overlap for text splitting
7
+
8
+ - name: build_kg
9
+
10
+ - name: partition
11
+ params:
12
+ method: ece # ece is a custom partition method based on comprehension loss
13
+ method_params:
14
+ max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
15
+ min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
16
+ max_tokens_per_community: 10240 # max tokens per community
17
+ unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss
18
+
19
+ - name: generate
20
+ params:
21
+ method: multi_hop # atomic, aggregated, multi_hop, cot, vqa
22
+ data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/configs/vqa_config.yaml CHANGED
@@ -1,18 +1,20 @@
1
- read:
2
- input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
3
- split:
4
- chunk_size: 1024 # chunk size for text splitting
5
- chunk_overlap: 100 # chunk overlap for text splitting
6
- search: # web search configuration
7
- enabled: false # whether to enable web search
8
- search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
- quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
10
- enabled: false
11
- partition: # graph partition configuration
12
- method: anchor_bfs # partition method
13
- method_params:
14
- anchor_type: image # node type to select anchor nodes
15
- max_units_per_community: 10 # atomic partition, one node or edge per community
16
- generate:
17
- mode: vqa # atomic, aggregated, multi_hop, cot, vqa
18
- data_format: ChatML # Alpaca, Sharegpt, ChatML
 
 
 
1
+ pipeline:
2
+ - name: read
3
+ params:
4
+ input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
5
+ chunk_size: 1024 # chunk size for text splitting
6
+ chunk_overlap: 100 # chunk overlap for text splitting
7
+
8
+ - name: build_kg
9
+
10
+ - name: partition
11
+ params:
12
+ method: anchor_bfs # partition method
13
+ method_params:
14
+ anchor_type: image # node type to select anchor nodes
15
+ max_units_per_community: 10 # atomic partition, one node or edge per community
16
+
17
+ - name: generate
18
+ params:
19
+ method: vqa # atomic, aggregated, multi_hop, cot, vqa
20
+ data_format: ChatML # Alpaca, Sharegpt, ChatML
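
Taken together, the rewritten configs replace the fixed read/split/search/quiz_and_judge/partition/generate keys with a single `pipeline` list whose stage names map onto GraphGen methods. As a rough sketch of how such a file is consumed (mirroring the new graphgen/run.py; the working directory below is a placeholder, not part of this commit):

```python
# Hedged sketch: load one of the configs above and run its pipeline end to end.
import time
import yaml

from graphgen.engine import Context, Engine, collect_ops
from graphgen.graphgen import GraphGen

with open("graphgen/configs/aggregated_config.yaml", "r", encoding="utf-8") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

graph_gen = GraphGen(unique_id=int(time.time()), working_dir="cache")  # placeholder dir

ctx = Context(config=config, graph_gen=graph_gen)   # shared, thread-safe state
ops = collect_ops(config, graph_gen)                # one OpNode per pipeline stage
Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
```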
graphgen/engine.py ADDED
@@ -0,0 +1,121 @@
 
1
+ """
2
+ orchestration engine for GraphGen
3
+ """
4
+
5
+ import threading
6
+ import traceback
7
+ from functools import wraps
8
+ from typing import Any, Callable, List
9
+
10
+
11
+ class Context(dict):
12
+ _lock = threading.Lock()
13
+
14
+ def set(self, k, v):
15
+ with self._lock:
16
+ self[k] = v
17
+
18
+ def get(self, k, default=None):
19
+ with self._lock:
20
+ return super().get(k, default)
21
+
22
+
23
+ class OpNode:
24
+ def __init__(
25
+ self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
26
+ ):
27
+ self.name, self.deps, self.func = name, deps, func
28
+
29
+
30
+ def op(name: str, deps=None):
31
+ deps = deps or []
32
+
33
+ def decorator(func):
34
+ @wraps(func)
35
+ def _wrapper(*args, **kwargs):
36
+ return func(*args, **kwargs)
37
+
38
+ _wrapper.op_node = OpNode(name, deps, lambda self, ctx: func(self, **ctx))
39
+ return _wrapper
40
+
41
+ return decorator
42
+
43
+
44
+ class Engine:
45
+ def __init__(self, max_workers: int = 4):
46
+ self.max_workers = max_workers
47
+
48
+ def run(self, ops: List[OpNode], ctx: Context):
49
+ name2op = {operation.name: operation for operation in ops}
50
+
51
+ # topological sort
52
+ graph = {n: set(name2op[n].deps) for n in name2op}
53
+ topo = []
54
+ q = [n for n, d in graph.items() if not d]
55
+ while q:
56
+ cur = q.pop(0)
57
+ topo.append(cur)
58
+ for child in [c for c, d in graph.items() if cur in d]:
59
+ graph[child].remove(cur)
60
+ if not graph[child]:
61
+ q.append(child)
62
+
63
+ if len(topo) != len(ops):
64
+ raise ValueError(
65
+ "Cyclic dependencies detected among operations."
66
+ "Please check your configuration."
67
+ )
68
+
69
+ # semaphore for max_workers
70
+ sem = threading.Semaphore(self.max_workers)
71
+ done = {n: threading.Event() for n in name2op}
72
+ exc = {}
73
+
74
+ def _exec(n: str):
75
+ with sem:
76
+ for d in name2op[n].deps:
77
+ done[d].wait()
78
+ if any(d in exc for d in name2op[n].deps):
79
+ exc[n] = Exception("Skipped due to failed dependencies")
80
+ done[n].set()
81
+ return
82
+ try:
83
+ name2op[n].func(name2op[n], ctx)
84
+ except Exception: # pylint: disable=broad-except
85
+ exc[n] = traceback.format_exc()
86
+ done[n].set()
87
+
88
+ ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
89
+ for t in ts:
90
+ t.start()
91
+ for t in ts:
92
+ t.join()
93
+ if exc:
94
+ raise RuntimeError(
95
+ "Some operations failed:\n"
96
+ + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items())
97
+ )
98
+
99
+
100
+ def collect_ops(config: dict, graph_gen) -> List[OpNode]:
101
+ """
102
+ build operation nodes from yaml config
103
+ :param config: pipeline configuration dict containing a "pipeline" list of stages
104
+ :param graph_gen: GraphGen instance whose op-decorated methods implement the stages
105
+ """
106
+ ops: List[OpNode] = []
107
+ for stage in config["pipeline"]:
108
+ name = stage["name"]
109
+ method = getattr(graph_gen, name)
110
+ op_node = method.op_node
111
+
112
+ # if there are runtime dependencies, override them
113
+ runtime_deps = stage.get("deps", op_node.deps)
114
+ op_node.deps = runtime_deps
115
+
116
+ if "params" in stage:
117
+ op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {}))
118
+ else:
119
+ op_node.func = lambda self, ctx, m=method: m()
120
+ ops.append(op_node)
121
+ return ops
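
To make the engine's behavior concrete outside GraphGen, here is a self-contained sketch; the two step functions are invented for illustration, while Context, Engine and OpNode come from the module added above. Engine.run topologically sorts the ops, rejects cycles, runs each op in its own thread under a max_workers semaphore, and skips downstream ops when a dependency fails.

```python
# Standalone illustration of graphgen/engine.py (the step functions are hypothetical).
from graphgen.engine import Context, Engine, OpNode

def read_step(node, ctx):
    # each op receives its OpNode and the shared Context
    ctx.set("chunks", ["chunk-1", "chunk-2"])

def build_kg_step(node, ctx):
    print(f"[{node.name}] building KG from {ctx.get('chunks')}")

ops = [
    OpNode("read", deps=[], func=read_step),
    OpNode("build_kg", deps=["read"], func=build_kg_step),
]

Engine(max_workers=2).run(ops, Context())   # "read" finishes before "build_kg" runs
```

In normal use the OpNodes are not built by hand: the @op decorator attaches one to each GraphGen method, and collect_ops rebinds its func and deps from the YAML pipeline.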
graphgen/evaluate.py CHANGED
@@ -1,3 +1,4 @@
 
1
  """Evaluate the quality of the generated text using various metrics"""
2
 
3
  import argparse
 
1
+ # TODO: this module needs refactoring to merge into the GraphGen framework
2
  """Evaluate the quality of the generated text using various metrics"""
3
 
4
  import argparse
graphgen/graphgen.py CHANGED
@@ -1,16 +1,16 @@
1
- import asyncio
2
  import os
3
  import time
4
- from typing import Dict, cast
5
 
6
  import gradio as gr
7
 
8
  from graphgen.bases import BaseLLMWrapper
9
- from graphgen.bases.base_storage import StorageNameSpace
10
  from graphgen.bases.datatypes import Chunk
 
11
  from graphgen.models import (
12
  JsonKVStorage,
13
  JsonListStorage,
 
14
  NetworkXStorage,
15
  OpenAIClient,
16
  Tokenizer,
@@ -54,6 +54,10 @@ class GraphGen:
54
  )
55
  self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
56
 
 
 
 
 
57
  self.full_docs_storage: JsonKVStorage = JsonKVStorage(
58
  self.working_dir, namespace="full_docs"
59
  )
@@ -69,6 +73,9 @@ class GraphGen:
69
  self.rephrase_storage: JsonKVStorage = JsonKVStorage(
70
  self.working_dir, namespace="rephrase"
71
  )
 
 
 
72
  self.qa_storage: JsonListStorage = JsonListStorage(
73
  os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
74
  namespace="qa",
@@ -77,12 +84,12 @@ class GraphGen:
77
  # webui
78
  self.progress_bar: gr.Progress = progress_bar
79
 
 
80
  @async_to_sync_method
81
- async def insert(self, read_config: Dict, split_config: Dict):
82
  """
83
- insert chunks into the graph
84
  """
85
- # Step 1: Read files
86
  data = read_files(read_config["input_file"], self.working_dir)
87
  if len(data) == 0:
88
  logger.warning("No data to process")
@@ -102,8 +109,8 @@ class GraphGen:
102
 
103
  inserting_chunks = await chunk_documents(
104
  new_docs,
105
- split_config["chunk_size"],
106
- split_config["chunk_overlap"],
107
  self.tokenizer_instance,
108
  self.progress_bar,
109
  )
@@ -119,9 +126,25 @@ class GraphGen:
119
  logger.warning("All chunks are already in the storage")
120
  return
121
 
122
- logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
 
123
  await self.chunks_storage.upsert(inserting_chunks)
 
124
 
 
 
125
  _add_entities_and_relations = await build_kg(
126
  llm_client=self.synthesizer_llm_client,
127
  kg_instance=self.graph_storage,
@@ -132,22 +155,13 @@ class GraphGen:
132
  logger.warning("No entities or relations extracted from text chunks")
133
  return
134
 
135
- await self._insert_done()
136
- return _add_entities_and_relations
 
137
 
138
- async def _insert_done(self):
139
- tasks = []
140
- for storage_instance in [
141
- self.full_docs_storage,
142
- self.chunks_storage,
143
- self.graph_storage,
144
- self.search_storage,
145
- ]:
146
- if storage_instance is None:
147
- continue
148
- tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
149
- await asyncio.gather(*tasks)
150
 
 
151
  @async_to_sync_method
152
  async def search(self, search_config: Dict):
153
  logger.info(
@@ -181,15 +195,15 @@ class GraphGen:
181
  ]
182
  )
183
  # TODO: fix insert after search
184
- await self.insert()
185
 
 
186
  @async_to_sync_method
187
  async def quiz_and_judge(self, quiz_and_judge_config: Dict):
188
- if quiz_and_judge_config is None or not quiz_and_judge_config.get(
189
- "enabled", False
190
- ):
191
- logger.warning("Quiz and Judge is not used in this pipeline.")
192
- return
193
  max_samples = quiz_and_judge_config["quiz_samples"]
194
  await quiz(
195
  self.synthesizer_llm_client,
@@ -222,15 +236,26 @@ class GraphGen:
222
  logger.info("Restarting synthesizer LLM client.")
223
  self.synthesizer_llm_client.restart()
224
 
 
225
  @async_to_sync_method
226
- async def generate(self, partition_config: Dict, generate_config: Dict):
227
- # Step 1: partition the graph
228
  batches = await partition_kg(
229
  self.graph_storage,
230
  self.chunks_storage,
231
  self.tokenizer_instance,
232
  partition_config,
233
  )
 
234
 
235
  # Step 2: generate QA pairs
236
  results = await generate_qas(
@@ -258,3 +283,6 @@ class GraphGen:
258
  await self.qa_storage.drop()
259
 
260
  logger.info("All caches are cleared")
 
 
 
 
 
1
  import os
2
  import time
3
+ from typing import Dict
4
 
5
  import gradio as gr
6
 
7
  from graphgen.bases import BaseLLMWrapper
 
8
  from graphgen.bases.datatypes import Chunk
9
+ from graphgen.engine import op
10
  from graphgen.models import (
11
  JsonKVStorage,
12
  JsonListStorage,
13
+ MetaJsonKVStorage,
14
  NetworkXStorage,
15
  OpenAIClient,
16
  Tokenizer,
 
54
  )
55
  self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
56
 
57
+ self.meta_storage: MetaJsonKVStorage = MetaJsonKVStorage(
58
+ self.working_dir, namespace="_meta"
59
+ )
60
+
61
  self.full_docs_storage: JsonKVStorage = JsonKVStorage(
62
  self.working_dir, namespace="full_docs"
63
  )
 
73
  self.rephrase_storage: JsonKVStorage = JsonKVStorage(
74
  self.working_dir, namespace="rephrase"
75
  )
76
+ self.partition_storage: JsonListStorage = JsonListStorage(
77
+ self.working_dir, namespace="partition"
78
+ )
79
  self.qa_storage: JsonListStorage = JsonListStorage(
80
  os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
81
  namespace="qa",
 
84
  # webui
85
  self.progress_bar: gr.Progress = progress_bar
86
 
87
+ @op("read", deps=[])
88
  @async_to_sync_method
89
+ async def read(self, read_config: Dict):
90
  """
91
+ read files from input sources
92
  """
 
93
  data = read_files(read_config["input_file"], self.working_dir)
94
  if len(data) == 0:
95
  logger.warning("No data to process")
 
109
 
110
  inserting_chunks = await chunk_documents(
111
  new_docs,
112
+ read_config["chunk_size"],
113
+ read_config["chunk_overlap"],
114
  self.tokenizer_instance,
115
  self.progress_bar,
116
  )
 
126
  logger.warning("All chunks are already in the storage")
127
  return
128
 
129
+ await self.full_docs_storage.upsert(new_docs)
130
+ await self.full_docs_storage.index_done_callback()
131
  await self.chunks_storage.upsert(inserting_chunks)
132
+ await self.chunks_storage.index_done_callback()
133
+
134
+ @op("build_kg", deps=["read"])
135
+ @async_to_sync_method
136
+ async def build_kg(self):
137
+ """
138
+ build knowledge graph from text chunks
139
+ """
140
+ # Step 1: get new chunks according to meta and chunks storage
141
+ inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage)
142
+ if len(inserting_chunks) == 0:
143
+ logger.warning("All chunks are already in the storage")
144
+ return
145
 
146
+ logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
147
+ # Step 2: build knowledge graph from new chunks
148
  _add_entities_and_relations = await build_kg(
149
  llm_client=self.synthesizer_llm_client,
150
  kg_instance=self.graph_storage,
 
155
  logger.warning("No entities or relations extracted from text chunks")
156
  return
157
 
158
+ # Step 3: mark meta
159
+ await self.meta_storage.mark_done(self.chunks_storage)
160
+ await self.meta_storage.index_done_callback()
161
 
162
+ return _add_entities_and_relations
 
163
 
164
+ @op("search", deps=["read"])
165
  @async_to_sync_method
166
  async def search(self, search_config: Dict):
167
  logger.info(
 
195
  ]
196
  )
197
  # TODO: fix insert after search
198
+ # await self.insert()
199
 
200
+ @op("quiz_and_judge", deps=["build_kg"])
201
  @async_to_sync_method
202
  async def quiz_and_judge(self, quiz_and_judge_config: Dict):
203
+ logger.warning(
204
+ "Quiz and Judge operation needs trainee LLM client."
205
+ " Make sure to provide one."
206
+ )
 
207
  max_samples = quiz_and_judge_config["quiz_samples"]
208
  await quiz(
209
  self.synthesizer_llm_client,
 
236
  logger.info("Restarting synthesizer LLM client.")
237
  self.synthesizer_llm_client.restart()
238
 
239
+ @op("partition", deps=["build_kg"])
240
  @async_to_sync_method
241
+ async def partition(self, partition_config: Dict):
 
242
  batches = await partition_kg(
243
  self.graph_storage,
244
  self.chunks_storage,
245
  self.tokenizer_instance,
246
  partition_config,
247
  )
248
+ await self.partition_storage.upsert(batches)
249
+ return batches
250
+
251
+ @op("generate", deps=["partition"])
252
+ @async_to_sync_method
253
+ async def generate(self, generate_config: Dict):
254
+
255
+ batches = self.partition_storage.data
256
+ if not batches:
257
+ logger.warning("No partitions found for QA generation")
258
+ return
259
 
260
  # Step 2: generate QA pairs
261
  results = await generate_qas(
 
283
  await self.qa_storage.drop()
284
 
285
  logger.info("All caches are cleared")
286
+
287
+ # TODO: add data filtering step here in the future
288
+ # graph_gen.filter(filter_config=config["filter"])
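
Since every stage is still wrapped by async_to_sync_method, the op-decorated methods remain plain synchronous calls and can be driven without the engine, e.g. to debug a single step. A hedged sketch (unique_id, paths and parameters below are placeholders):

```python
# Calling the pipeline stages directly, in dependency order, without the Engine.
from graphgen.graphgen import GraphGen

graph_gen = GraphGen(unique_id=0, working_dir="cache")  # placeholder values
graph_gen.read({"input_file": "input.jsonl", "chunk_size": 1024, "chunk_overlap": 100})
graph_gen.build_kg()
graph_gen.partition({"method": "dfs", "method_params": {"max_units_per_community": 1}})
graph_gen.generate({"method": "atomic", "data_format": "Alpaca"})
# QA pairs end up in graph_gen.qa_storage, as in the webui flow
```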
graphgen/models/__init__.py CHANGED
@@ -30,5 +30,5 @@ from .search.kg.wiki_search import WikiSearch
30
  from .search.web.bing_search import BingSearch
31
  from .search.web.google_search import GoogleSearch
32
  from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
33
- from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
34
  from .tokenizer import Tokenizer
 
30
  from .search.web.bing_search import BingSearch
31
  from .search.web.google_search import GoogleSearch
32
  from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
33
+ from .storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage, NetworkXStorage
34
  from .tokenizer import Tokenizer
graphgen/models/storage/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from .json_storage import JsonKVStorage, JsonListStorage
2
  from .networkx_storage import NetworkXStorage
 
1
+ from .json_storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage
2
  from .networkx_storage import NetworkXStorage
graphgen/models/storage/json_storage.py CHANGED
@@ -44,11 +44,13 @@ class JsonKVStorage(BaseKVStorage):
44
 
45
  async def upsert(self, data: dict):
46
  left_data = {k: v for k, v in data.items() if k not in self._data}
47
- self._data.update(left_data)
 
48
  return left_data
49
 
50
  async def drop(self):
51
- self._data = {}
 
52
 
53
 
54
  @dataclass
@@ -87,3 +89,23 @@ class JsonListStorage(BaseListStorage):
87
 
88
  async def drop(self):
89
  self._data = []
 
44
 
45
  async def upsert(self, data: dict):
46
  left_data = {k: v for k, v in data.items() if k not in self._data}
47
+ if left_data:
48
+ self._data.update(left_data)
49
  return left_data
50
 
51
  async def drop(self):
52
+ if self._data:
53
+ self._data.clear()
54
 
55
 
56
  @dataclass
 
89
 
90
  async def drop(self):
91
  self._data = []
92
+
93
+
94
+ @dataclass
95
+ class MetaJsonKVStorage(JsonKVStorage):
96
+ def __post_init__(self):
97
+ self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
98
+ self._data = load_json(self._file_name) or {}
99
+ logger.info("Load KV %s with %d data", self.namespace, len(self._data))
100
+
101
+ async def get_new_data(self, storage_instance: "JsonKVStorage") -> dict:
102
+ new_data = {}
103
+ for k, v in storage_instance.data.items():
104
+ if k not in self._data:
105
+ new_data[k] = v
106
+ return new_data
107
+
108
+ async def mark_done(self, storage_instance: "JsonKVStorage"):
109
+ new_data = await self.get_new_data(storage_instance)
110
+ if new_data:
111
+ self._data.update(new_data)
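
MetaJsonKVStorage is what makes build_kg incremental: the "_meta" namespace records which chunk keys have already been folded into the graph, so a re-run only processes new chunks. A minimal sketch of the get_new_data / mark_done handshake (the chunk payloads and the "chunks" namespace here are illustrative):

```python
# Hedged sketch of incremental bookkeeping with MetaJsonKVStorage.
import asyncio
import tempfile

from graphgen.models import JsonKVStorage, MetaJsonKVStorage

async def demo(working_dir: str):
    chunks = JsonKVStorage(working_dir, namespace="chunks")
    meta = MetaJsonKVStorage(working_dir, namespace="_meta")

    await chunks.upsert({"chunk-1": {"content": "..."}, "chunk-2": {"content": "..."}})

    new = await meta.get_new_data(chunks)        # keys not yet recorded in _meta
    print("to process:", sorted(new))            # ['chunk-1', 'chunk-2'] on first run

    # ... build the KG from `new` here ...

    await meta.mark_done(chunks)                 # remember the processed keys
    await meta.index_done_callback()             # persist _meta.json
    print("left:", sorted(await meta.get_new_data(chunks)))   # [] on the next pass

asyncio.run(demo(tempfile.mkdtemp()))
```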
graphgen/models/storage/networkx_storage.py CHANGED
@@ -75,7 +75,8 @@ class NetworkXStorage(BaseGraphStorage):
75
 
76
  def __post_init__(self):
77
  """
78
- 如果图文件存在,则加载图文件,否则创建一个新图 (If the graph file exists, load it; otherwise create a new graph)
 
79
  """
80
  self._graphml_xml_file = os.path.join(
81
  self.working_dir, f"{self.namespace}.graphml"
 
75
 
76
  def __post_init__(self):
77
  """
78
+ Initialize the NetworkX graph storage by loading an existing graph from a GraphML file,
79
+ if it exists, or creating a new empty graph otherwise.
80
  """
81
  self._graphml_xml_file = os.path.join(
82
  self.working_dir, f"{self.namespace}.graphml"
graphgen/operators/generate/generate_qas.py CHANGED
@@ -29,21 +29,21 @@ async def generate_qas(
29
  :param progress_bar
30
  :return: QA pairs
31
  """
32
- mode = generation_config["mode"]
33
- logger.info("[Generation] mode: %s, batches: %d", mode, len(batches))
34
 
35
- if mode == "atomic":
36
  generator = AtomicGenerator(llm_client)
37
- elif mode == "aggregated":
38
  generator = AggregatedGenerator(llm_client)
39
- elif mode == "multi_hop":
40
  generator = MultiHopGenerator(llm_client)
41
- elif mode == "cot":
42
  generator = CoTGenerator(llm_client)
43
- elif mode in ["vqa"]:
44
  generator = VQAGenerator(llm_client)
45
  else:
46
- raise ValueError(f"Unsupported generation mode: {mode}")
47
 
48
  results = await run_concurrent(
49
  generator.generate,
 
29
  :param progress_bar
30
  :return: QA pairs
31
  """
32
+ method = generation_config["method"]
33
+ logger.info("[Generation] mode: %s, batches: %d", method, len(batches))
34
 
35
+ if method == "atomic":
36
  generator = AtomicGenerator(llm_client)
37
+ elif method == "aggregated":
38
  generator = AggregatedGenerator(llm_client)
39
+ elif method == "multi_hop":
40
  generator = MultiHopGenerator(llm_client)
41
+ elif method == "cot":
42
  generator = CoTGenerator(llm_client)
43
+ elif method in ["vqa"]:
44
  generator = VQAGenerator(llm_client)
45
  else:
46
+ raise ValueError(f"Unsupported generation mode: {method}")
47
 
48
  results = await run_concurrent(
49
  generator.generate,
graphgen/{generate.py → run.py} RENAMED
@@ -6,6 +6,7 @@ from importlib.resources import files
6
  import yaml
7
  from dotenv import load_dotenv
8
 
 
9
  from graphgen.graphgen import GraphGen
10
  from graphgen.utils import logger, set_logger
11
 
@@ -50,38 +51,29 @@ def main():
50
  with open(args.config_file, "r", encoding="utf-8") as f:
51
  config = yaml.load(f, Loader=yaml.FullLoader)
52
 
53
- mode = config["generate"]["mode"]
54
  unique_id = int(time.time())
55
 
56
  output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
57
  set_working_dir(output_path)
58
 
59
  set_logger(
60
- os.path.join(output_path, f"{unique_id}_{mode}.log"),
61
  if_stream=True,
62
  )
63
  logger.info(
64
  "GraphGen with unique ID %s logging to %s",
65
  unique_id,
66
- os.path.join(working_dir, f"{unique_id}_{mode}.log"),
67
  )
68
 
69
  graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
70
 
71
- graph_gen.insert(read_config=config["read"], split_config=config["split"])
 
 
72
 
73
- graph_gen.search(search_config=config["search"])
74
-
75
- if config.get("quiz_and_judge", {}).get("enabled"):
76
- graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
77
-
78
- # TODO: add data filtering step here in the future
79
- # graph_gen.filter(filter_config=config["filter"])
80
-
81
- graph_gen.generate(
82
- partition_config=config["partition"],
83
- generate_config=config["generate"],
84
- )
85
 
86
  save_config(os.path.join(output_path, "config.yaml"), config)
87
  logger.info("GraphGen completed successfully. Data saved to %s", output_path)
 
6
  import yaml
7
  from dotenv import load_dotenv
8
 
9
+ from graphgen.engine import Context, Engine, collect_ops
10
  from graphgen.graphgen import GraphGen
11
  from graphgen.utils import logger, set_logger
12
 
 
51
  with open(args.config_file, "r", encoding="utf-8") as f:
52
  config = yaml.load(f, Loader=yaml.FullLoader)
53
 
 
54
  unique_id = int(time.time())
55
 
56
  output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
57
  set_working_dir(output_path)
58
 
59
  set_logger(
60
+ os.path.join(output_path, f"{unique_id}.log"),
61
  if_stream=True,
62
  )
63
  logger.info(
64
  "GraphGen with unique ID %s logging to %s",
65
  unique_id,
66
+ os.path.join(working_dir, f"{unique_id}.log"),
67
  )
68
 
69
  graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
70
 
71
+ # share context between different steps
72
+ ctx = Context(config=config, graph_gen=graph_gen)
73
+ ops = collect_ops(config, graph_gen)
74
 
75
+ # run operations
76
+ Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
 
77
 
78
  save_config(os.path.join(output_path, "config.yaml"), config)
79
  logger.info("GraphGen completed successfully. Data saved to %s", output_path)
webui/app.py CHANGED
@@ -8,6 +8,7 @@ import gradio as gr
8
  import pandas as pd
9
  from dotenv import load_dotenv
10
 
 
11
  from graphgen.graphgen import GraphGen
12
  from graphgen.models import OpenAIClient, Tokenizer
13
  from graphgen.models.llm.limitter import RPM, TPM
@@ -97,26 +98,61 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
97
  "unit_sampling": params.ece_unit_sampling,
98
  }
99
 
100
  config = {
101
  "if_trainee_model": params.if_trainee_model,
102
  "read": {"input_file": params.upload_file},
103
- "split": {
104
- "chunk_size": params.chunk_size,
105
- "chunk_overlap": params.chunk_overlap,
106
- },
107
- "search": {"enabled": False},
108
- "quiz_and_judge": {
109
- "enabled": params.if_trainee_model,
110
- "quiz_samples": params.quiz_samples,
111
- },
112
- "partition": {
113
- "method": params.partition_method,
114
- "method_params": partition_params,
115
- },
116
- "generate": {
117
- "mode": params.mode,
118
- "data_format": params.data_format,
119
- },
120
  }
121
 
122
  env = {
@@ -145,20 +181,12 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
145
  # Initialize GraphGen
146
  graph_gen = init_graph_gen(config, env)
147
  graph_gen.clear()
148
-
149
  graph_gen.progress_bar = progress
150
 
151
  try:
152
- # Process the data
153
- graph_gen.insert(read_config=config["read"], split_config=config["split"])
154
-
155
- if config["if_trainee_model"]:
156
- graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
157
-
158
- graph_gen.generate(
159
- partition_config=config["partition"],
160
- generate_config=config["generate"],
161
- )
162
 
163
  # Save output
164
  output_data = graph_gen.qa_storage.data
 
8
  import pandas as pd
9
  from dotenv import load_dotenv
10
 
11
+ from graphgen.engine import Context, Engine, collect_ops
12
  from graphgen.graphgen import GraphGen
13
  from graphgen.models import OpenAIClient, Tokenizer
14
  from graphgen.models.llm.limitter import RPM, TPM
 
98
  "unit_sampling": params.ece_unit_sampling,
99
  }
100
 
101
+ pipeline = [
102
+ {
103
+ "name": "read",
104
+ "params": {
105
+ "input_file": params.upload_file,
106
+ "chunk_size": params.chunk_size,
107
+ "chunk_overlap": params.chunk_overlap,
108
+ },
109
+ },
110
+ {
111
+ "name": "build_kg",
112
+ },
113
+ ]
114
+
115
+ if params.if_trainee_model:
116
+ pipeline.append(
117
+ {
118
+ "name": "quiz_and_judge",
119
+ "params": {"quiz_samples": params.quiz_samples, "re_judge": True},
120
+ }
121
+ )
122
+ pipeline.append(
123
+ {
124
+ "name": "partition",
125
+ "deps": ["quiz_and_judge"],
126
+ "params": {
127
+ "method": params.partition_method,
128
+ "method_params": partition_params,
129
+ },
130
+ }
131
+ )
132
+ else:
133
+ pipeline.append(
134
+ {
135
+ "name": "partition",
136
+ "params": {
137
+ "method": params.partition_method,
138
+ "method_params": partition_params,
139
+ },
140
+ }
141
+ )
142
+ pipeline.append(
143
+ {
144
+ "name": "generate",
145
+ "params": {
146
+ "method": params.mode,
147
+ "data_format": params.data_format,
148
+ },
149
+ }
150
+ )
151
+
152
  config = {
153
  "if_trainee_model": params.if_trainee_model,
154
  "read": {"input_file": params.upload_file},
155
+ "pipeline": pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  }
157
 
158
  env = {
 
181
  # Initialize GraphGen
182
  graph_gen = init_graph_gen(config, env)
183
  graph_gen.clear()
 
184
  graph_gen.progress_bar = progress
185
 
186
  try:
187
+ ctx = Context(config=config, graph_gen=graph_gen)
188
+ ops = collect_ops(config, graph_gen)
189
+ Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
 
190
 
191
  # Save output
192
  output_data = graph_gen.qa_storage.data