github-actions[bot] committed
Commit 817f16e · 1 Parent(s): 3a3b216

Auto-sync from demo at Tue Sep 30 07:59:12 UTC 2025

app.py CHANGED
@@ -39,27 +39,32 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     set_logger(log_file, if_stream=True)
     os.environ.update({k: str(v) for k, v in env.items()})
 
-    graph_gen = GraphGen(working_dir=working_dir, config=config)
-    # Set up LLM clients
-    graph_gen.synthesizer_llm_client = OpenAIClient(
+    tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
+    synthesizer_llm_client = OpenAIClient(
         model_name=env.get("SYNTHESIZER_MODEL", ""),
         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
         api_key=env.get("SYNTHESIZER_API_KEY", ""),
         request_limit=True,
         rpm=RPM(env.get("RPM", 1000)),
         tpm=TPM(env.get("TPM", 50000)),
+        tokenizer=tokenizer_instance,
     )
-
-    graph_gen.trainee_llm_client = OpenAIClient(
+    trainee_llm_client = OpenAIClient(
         model_name=env.get("TRAINEE_MODEL", ""),
         base_url=env.get("TRAINEE_BASE_URL", ""),
         api_key=env.get("TRAINEE_API_KEY", ""),
         request_limit=True,
         rpm=RPM(env.get("RPM", 1000)),
         tpm=TPM(env.get("TPM", 50000)),
+        tokenizer=tokenizer_instance,
    )
 
-    graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
+    graph_gen = GraphGen(
+        working_dir=working_dir,
+        tokenizer_instance=tokenizer_instance,
+        synthesizer_llm_client=synthesizer_llm_client,
+        trainee_llm_client=trainee_llm_client,
+    )
 
     return graph_gen
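Note on the hunk above: construction is inverted. Instead of handing GraphGen a config dict and letting it build its own clients, the tokenizer and both OpenAIClient instances are created first and injected, so the webui no longer mutates graph_gen after the fact. A minimal sketch of the same pattern (model name, endpoint, and key are placeholders, not values from this commit):

    tokenizer = Tokenizer(model_name="cl100k_base")   # one shared tokenizer
    client = OpenAIClient(
        model_name="my-synthesizer-model",            # placeholder
        base_url="https://api.example.com/v1",        # placeholder
        api_key="sk-placeholder",                     # placeholder
        tokenizer=tokenizer,                          # shared instance keeps token counts consistent
    )
    graph_gen = GraphGen(
        working_dir="cache",
        tokenizer_instance=tokenizer,
        synthesizer_llm_client=client,
        trainee_llm_client=client,                    # reuse when no separate trainee model is configured
    )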
@@ -78,27 +83,32 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
             "chunk_size": params.chunk_size,
             "chunk_overlap": params.chunk_overlap,
         },
-        "output_data_type": params.output_data_type,
-        "output_data_format": params.output_data_format,
-        "tokenizer": params.tokenizer,
         "search": {"enabled": False},
-        "quiz_and_judge_strategy": {
+        "quiz_and_judge": {
             "enabled": params.if_trainee_model,
             "quiz_samples": params.quiz_samples,
         },
-        "traverse_strategy": {
-            "bidirectional": params.bidirectional,
-            "expand_method": params.expand_method,
-            "max_extra_edges": params.max_extra_edges,
-            "max_tokens": params.max_tokens,
-            "max_depth": params.max_depth,
-            "edge_sampling": params.edge_sampling,
-            "isolated_node_strategy": params.isolated_node_strategy,
-            "loss_strategy": params.loss_strategy,
+        "partition": {
+            "method": "ece",
+            "method_params": {
+                "bidirectional": params.bidirectional,
+                "expand_method": params.expand_method,
+                "max_extra_edges": params.max_extra_edges,
+                "max_tokens": params.max_tokens,
+                "max_depth": params.max_depth,
+                "edge_sampling": params.edge_sampling,
+                "isolated_node_strategy": params.isolated_node_strategy,
+                "loss_strategy": params.loss_strategy,
+            },
+        },
+        "generate": {
+            "mode": params.output_data_type,
+            "data_format": params.output_data_format,
         },
     }
 
     env = {
+        "TOKENIZER_MODEL": params.tokenizer,
         "SYNTHESIZER_BASE_URL": params.synthesizer_url,
         "SYNTHESIZER_MODEL": params.synthesizer_model,
         "TRAINEE_BASE_URL": params.trainee_url,
@@ -128,19 +138,18 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
 
     try:
         # Process the data
-        graph_gen.insert()
+        graph_gen.insert(read_config=config["read"], split_config=config["split"])
 
         if config["if_trainee_model"]:
-            # Generate quiz
-            graph_gen.quiz()
-
-            # Judge statements
-            graph_gen.judge()
+            # Quiz and Judge
+            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
-            graph_gen.traverse_strategy.edge_sampling = "random"
+            config["partition"]["method_params"]["edge_sampling"] = "random"
 
-        # Traverse graph
-        graph_gen.traverse()
+        graph_gen.generate(
+            partition_config=config["partition"],
+            generate_config=config["generate"],
+        )
 
         # Save output
         output_data = graph_gen.qa_storage.data
@@ -249,6 +258,9 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
            )
 
            with gr.Accordion(label=_("Model Config"), open=False):
+                tokenizer = gr.Textbox(
+                    label="Tokenizer", value="cl100k_base", interactive=True
+                )
                 synthesizer_url = gr.Textbox(
                     label="Synthesizer URL",
                     value="https://api.siliconflow.cn/v1",
@@ -300,9 +312,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                     step=100,
                     interactive=True,
                 )
-                tokenizer = gr.Textbox(
-                    label="Tokenizer", value="cl100k_base", interactive=True
-                )
                 output_data_type = gr.Radio(
                     choices=["atomic", "multi_hop", "aggregated"],
                     label="Output Data Type",
 
graphgen/configs/aggregated_config.yaml CHANGED
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-output_data_type: aggregated # atomic, aggregated, multi_hop, cot
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: true
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 5 # maximum depth for graph traversal
-  max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_depth
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 5 # maximum depth for graph traversal
+    max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: aggregated # atomic, aggregated, multi_hop, cot
+  data_format: ChatML # Alpaca, Sharegpt, ChatML
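For reference, a config in the new schema can be loaded and dispatched on partition.method; a sketch using standard PyYAML against the file above:

    import yaml

    with open("graphgen/configs/aggregated_config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    method = config["partition"]["method"]
    params = config["partition"]["method_params"]
    if method == "ece":
        print(params["edge_sampling"])   # -> "max_loss"
    elif method == "leiden":
        print(params["max_size"])        # community-size cap, see cot_config.yaml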
graphgen/configs/atomic_config.yaml CHANGED
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-output_data_type: atomic # atomic, aggregated, multi_hop, cot
-output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: true
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 3 # maximum depth for graph traversal
-  max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_depth
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 3 # maximum depth for graph traversal
+    max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: atomic # atomic, aggregated, multi_hop, cot
+  data_format: Alpaca # Alpaca, Sharegpt, ChatML
graphgen/configs/cot_config.yaml CHANGED
@@ -6,11 +6,14 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-output_data_type: cot # atomic, aggregated, multi_hop, cot
-output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-method_params:
-  method: leiden
-  max_size: 20 # Maximum size of communities
-  use_lcc: false
-  random_seed: 42
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
+  enabled: false
+partition: # graph partition configuration
+  method: leiden # leiden is a community detection algorithm
+  method_params:
+    max_size: 20 # Maximum size of communities
+    use_lcc: false
+    random_seed: 42
+generate:
+  mode: cot # atomic, aggregated, multi_hop, cot
+  data_format: Sharegpt # Alpaca, Sharegpt, ChatML
graphgen/configs/multi_hop_config.yaml CHANGED
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: false
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 1 # maximum depth for graph traversal
-  max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_depth
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 1 # maximum depth for graph traversal
+    max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: multi_hop # strategy for generating multi-hop QA pairs
+  data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/generate.py CHANGED
@@ -6,8 +6,8 @@ from importlib.resources import files
 import yaml
 from dotenv import load_dotenv
 
-from .graphgen import GraphGen
-from .utils import logger, set_logger
+from graphgen.graphgen import GraphGen
+from graphgen.utils import logger, set_logger
 
 sys_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -50,50 +50,51 @@ def main():
     with open(args.config_file, "r", encoding="utf-8") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
-    output_data_type = config["output_data_type"]
+    mode = config["generate"]["mode"]
     unique_id = int(time.time())
 
-    output_path = os.path.join(
-        working_dir, "data", "graphgen", f"{unique_id}_{output_data_type}"
-    )
+    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
     set_working_dir(output_path)
 
     set_logger(
-        os.path.join(output_path, f"{unique_id}.log"),
+        os.path.join(output_path, f"{unique_id}_{mode}.log"),
         if_stream=True,
     )
     logger.info(
         "GraphGen with unique ID %s logging to %s",
         unique_id,
-        os.path.join(
-            working_dir, "logs", f"{unique_id}_graphgen_{output_data_type}.log"
-        ),
+        os.path.join(working_dir, f"{unique_id}_{mode}.log"),
     )
 
-    graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
+    graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
 
-    graph_gen.insert()
+    graph_gen.insert(read_config=config["read"], split_config=config["split"])
 
-    if config["search"]["enabled"]:
-        graph_gen.search()
+    graph_gen.search(search_config=config["search"])
 
     # Use pipeline according to the output data type
-    if output_data_type in ["atomic", "aggregated", "multi_hop"]:
-        if "quiz_and_judge_strategy" in config and config[
-            "quiz_and_judge_strategy"
-        ].get("enabled", False):
-            graph_gen.quiz()
-            graph_gen.judge()
+    if mode in ["atomic", "aggregated", "multi_hop"]:
+        logger.info("Generation mode set to '%s'. Start generation.", mode)
+        if "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]:
+            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
             logger.warning(
                 "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
             )
-            graph_gen.traverse_strategy.edge_sampling = "random"
-        graph_gen.traverse()
-    elif output_data_type == "cot":
-        graph_gen.generate_reasoning(method_params=config["method_params"])
+            assert (
+                config["partition"]["method"] == "ece"
+                and "ece_params" in config["partition"]
+            ), "Only ECE partition with edge sampling is supported."
+            config["partition"]["method_params"]["edge_sampling"] = "random"
+    elif mode == "cot":
+        logger.info("Generation mode set to 'cot'. Start generation.")
     else:
-        raise ValueError(f"Unsupported output data type: {output_data_type}")
+        raise ValueError(f"Unsupported output data type: {mode}")
+
+    graph_gen.generate(
+        partition_config=config["partition"],
+        generate_config=config["generate"],
+    )
 
     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
 
graphgen/graphgen.py CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 import os
 import time
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Dict, cast
 
 import gradio as gr
@@ -14,7 +14,6 @@ from graphgen.models import (
     NetworkXStorage,
     OpenAIClient,
     Tokenizer,
-    TraverseStrategy,
 )
 from graphgen.operators import (
     chunk_documents,
@@ -42,46 +41,36 @@ sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 class GraphGen:
     unique_id: int = int(time.time())
     working_dir: str = os.path.join(sys_path, "cache")
-    config: Dict = field(default_factory=dict)
 
     # llm
     tokenizer_instance: Tokenizer = None
     synthesizer_llm_client: OpenAIClient = None
     trainee_llm_client: OpenAIClient = None
 
-    # search
-    search_config: dict = field(
-        default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
-    )
-
-    # traversal
-    traverse_strategy: TraverseStrategy = None
-
     # webui
     progress_bar: gr.Progress = None
 
     def __post_init__(self):
-        self.tokenizer_instance: Tokenizer = Tokenizer(
-            model_name=self.config["tokenizer"]
+        self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
+            model_name=os.getenv("TOKENIZER_MODEL")
         )
-        self.synthesizer_llm_client: OpenAIClient = OpenAIClient(
-            model_name=os.getenv("SYNTHESIZER_MODEL"),
-            api_key=os.getenv("SYNTHESIZER_API_KEY"),
-            base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-            tokenizer=self.tokenizer_instance,
+
+        self.synthesizer_llm_client: OpenAIClient = (
+            self.synthesizer_llm_client
+            or OpenAIClient(
+                model_name=os.getenv("SYNTHESIZER_MODEL"),
+                api_key=os.getenv("SYNTHESIZER_API_KEY"),
+                base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+                tokenizer=self.tokenizer_instance,
+            )
         )
-        self.trainee_llm_client: OpenAIClient = OpenAIClient(
+
+        self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
             model_name=os.getenv("TRAINEE_MODEL"),
             api_key=os.getenv("TRAINEE_API_KEY"),
             base_url=os.getenv("TRAINEE_BASE_URL"),
             tokenizer=self.tokenizer_instance,
         )
-        self.search_config = self.config["search"]
-
-        if "traverse_strategy" in self.config:
-            self.traverse_strategy = TraverseStrategy(
-                **self.config["traverse_strategy"]
-            )
 
         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"
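The `value = value or fallback` pattern in __post_init__ means injected instances win and the environment is only a fallback, so both call sites keep working. A sketch of the two construction paths (my_tokenizer and my_client are hypothetical prebuilt objects):

    import os

    # Path A (CLI): no injection; clients come from SYNTHESIZER_*/TRAINEE_* env vars,
    # and the tokenizer now requires TOKENIZER_MODEL to be set.
    os.environ["TOKENIZER_MODEL"] = "cl100k_base"
    gg = GraphGen(working_dir="cache")

    # Path B (webui): everything injected; the env vars are never consulted.
    gg = GraphGen(
        working_dir="cache",
        tokenizer_instance=my_tokenizer,
        synthesizer_llm_client=my_client,
        trainee_llm_client=my_client,
    )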
@@ -99,24 +88,17 @@ class GraphGen:
             self.working_dir, namespace="rephrase"
         )
         self.qa_storage: JsonListStorage = JsonListStorage(
-            os.path.join(
-                self.working_dir,
-                "data",
-                "graphgen",
-                f"{self.unique_id}_{self.config['output_data_type']}",
-            ),
+            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
         )
 
     @async_to_sync_method
-    async def insert(self):
+    async def insert(self, read_config: Dict, split_config: Dict):
         """
         insert chunks into the graph
         """
-        input_file = self.config["read"]["input_file"]
-
         # Step 1: Read files
-        data = read_files(input_file)
+        data = read_files(read_config["input_file"])
         if len(data) == 0:
             logger.warning("No data to process")
             return
@@ -141,8 +123,8 @@ class GraphGen:
 
         inserting_chunks = await chunk_documents(
             new_docs,
-            self.config["split"]["chunk_size"],
-            self.config["split"]["chunk_overlap"],
+            split_config["chunk_size"],
+            split_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
         )
@@ -178,6 +160,7 @@ class GraphGen:
             return
 
         await self._insert_done()
+        return _add_entities_and_relations
 
     async def _insert_done(self):
         tasks = []
@@ -193,14 +176,12 @@ class GraphGen:
         await asyncio.gather(*tasks)
 
     @async_to_sync_method
-    async def search(self):
+    async def search(self, search_config: Dict):
         logger.info(
-            "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
+            "Search is %s", "enabled" if search_config["enabled"] else "disabled"
         )
-        if self.search_config["enabled"]:
-            logger.info(
-                "[Search] %s ...", ", ".join(self.search_config["search_types"])
-            )
+        if search_config["enabled"]:
+            logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
             all_nodes = await self.graph_storage.get_all_nodes()
             all_nodes_names = [node[0] for node in all_nodes]
             new_search_entities = await self.full_docs_storage.filter_keys(
@@ -210,7 +191,7 @@ class GraphGen:
                 "[Search] Found %d entities to search", len(new_search_entities)
             )
             _add_search_data = await search_all(
-                search_types=self.search_config["search_types"],
+                search_types=search_config["search_types"],
                 search_entities=new_search_entities,
             )
             if _add_search_data:
@@ -230,78 +211,77 @@ class GraphGen:
         await self.insert()
 
     @async_to_sync_method
-    async def quiz(self):
-        max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
+    async def quiz_and_judge(self, quiz_and_judge_config: Dict):
+        if quiz_and_judge_config is None or not quiz_and_judge_config.get(
+            "enabled", False
+        ):
+            logger.warning("Quiz and Judge is not used in this pipeline.")
+            return
+        max_samples = quiz_and_judge_config["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             max_samples,
         )
-        await self.rephrase_storage.index_done_callback()
 
-    @async_to_sync_method
-    async def judge(self):
-        re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
+        # TODO: assert trainee_llm_client is valid before judge
+        re_judge = quiz_and_judge_config["re_judge"]
         _update_relations = await judge_statement(
             self.trainee_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             re_judge,
         )
+        await self.rephrase_storage.index_done_callback()
         await _update_relations.index_done_callback()
 
     @async_to_sync_method
-    async def traverse(self):
-        output_data_type = self.config["output_data_type"]
-
-        if output_data_type == "atomic":
+    async def generate(self, partition_config: Dict, generate_config: Dict):
+        # Step 1: partition the graph
+        # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
+        mode = generate_config["mode"]
+        if mode == "atomic":
             results = await traverse_graph_for_atomic(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif output_data_type == "multi_hop":
+        elif mode == "multi_hop":
             results = await traverse_graph_for_multi_hop(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif output_data_type == "aggregated":
+        elif mode == "aggregated":
             results = await traverse_graph_for_aggregated(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
+        elif mode == "cot":
+            results = await generate_cot(
+                self.graph_storage,
+                self.synthesizer_llm_client,
+                method_params=partition_config["method_params"],
+            )
         else:
-            raise ValueError(f"Unknown qa_form: {output_data_type}")
-
-        results = format_generation_results(
-            results, output_data_format=self.config["output_data_format"]
-        )
-
-        await self.qa_storage.upsert(results)
-        await self.qa_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def generate_reasoning(self, method_params):
-        results = await generate_cot(
-            self.graph_storage,
-            self.synthesizer_llm_client,
-            method_params=method_params,
-        )
+            raise ValueError(f"Unknown generation mode: {mode}")
+        # Step 2: generate QA pairs
+        # TODO
 
+        # Step 3: format
         results = format_generation_results(
-            results, output_data_format=self.config["output_data_format"]
+            results, output_data_format=generate_config["data_format"]
         )
 
         await self.qa_storage.upsert(results)
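quiz() and judge() collapse into quiz_and_judge(), and traverse()/generate_reasoning() into a single generate() keyed on generate.mode. A sketch of a caller, reusing the leiden/cot values from cot_config.yaml:

    gg.quiz_and_judge(
        quiz_and_judge_config={"enabled": True, "quiz_samples": 2, "re_judge": False}
    )
    gg.generate(
        partition_config={
            "method": "leiden",
            "method_params": {"max_size": 20, "use_lcc": False, "random_seed": 42},
        },
        generate_config={"mode": "cot", "data_format": "Sharegpt"},
    )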
 
graphgen/models/__init__.py CHANGED
@@ -13,5 +13,4 @@ from .search.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
 from .storage.json_storage import JsonKVStorage, JsonListStorage
 from .storage.networkx_storage import NetworkXStorage
-from .strategy.travserse_strategy import TraverseStrategy
 from .tokenizer import Tokenizer
graphgen/models/strategy/__init__.py DELETED
File without changes
graphgen/models/strategy/travserse_strategy.py DELETED
@@ -1,28 +0,0 @@
-from dataclasses import dataclass, fields
-
-
-@dataclass
-class TraverseStrategy:
-    # QA form to generate: atomic, multi-hop, or aggregated
-    qa_form: str = "atomic"  # "atomic" or "multi_hop" or "aggregated"
-    # Only one of the max-edges / max-tokens limits takes effect, selected by expand_method
-    expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
-    # Expand in one direction or in both directions
-    bidirectional: bool = True
-    # Maximum number of edges to expand per direction
-    max_extra_edges: int = 5
-    # Maximum number of tokens
-    max_tokens: int = 256
-    # Maximum expansion depth per direction
-    max_depth: int = 2
-    # Edge selection strategy within a level (for bidirectional expansion, a level is the set of edges connected on either side)
-    edge_sampling: str = "max_loss"  # "max_loss" or "min_loss" or "random"
-    # Strategy for handling isolated nodes
-    isolated_node_strategy: str = "add"  # "add" or "ignore"
-    loss_strategy: str = "only_edge"  # only_edge, both
-
-    def to_yaml(self):
-        strategy_dict = {}
-        for f in fields(self):
-            strategy_dict[f.name] = getattr(self, f.name)
-        return {"traverse_strategy": strategy_dict}
graphgen/models/tokenizer/__init__.py CHANGED
@@ -39,6 +39,8 @@ class Tokenizer(BaseTokenizer):
     _impl: BaseTokenizer = field(init=False, repr=False)
 
     def __post_init__(self):
+        if not self.model_name:
+            raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
         self._impl = get_tokenizer_impl(self.model_name)
 
     def encode(self, text: str) -> List[int]:
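With no default model name left anywhere, a missing TOKENIZER_MODEL now fails fast instead of silently falling back to cl100k_base:

    import os

    os.environ.pop("TOKENIZER_MODEL", None)
    try:
        Tokenizer(model_name=os.getenv("TOKENIZER_MODEL"))   # model_name is None here
    except ValueError as err:
        print(err)   # TOKENIZER_MODEL must be specified in the ENV variables.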
graphgen/operators/build_kg/split_kg.py CHANGED
@@ -1,9 +1,10 @@
 import random
 from collections import defaultdict
+from typing import Dict
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import NetworkXStorage, TraverseStrategy
+from graphgen.models import NetworkXStorage
 from graphgen.utils import logger
 
@@ -247,9 +248,9 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     nodes: list,
     edges: list,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
 ):
-    expand_method = traverse_strategy.expand_method
+    expand_method = traverse_strategy["expand_method"]
     if expand_method == "max_width":
         logger.info("Using max width strategy")
     elif expand_method == "max_tokens":
@@ -257,8 +258,8 @@
     else:
         raise ValueError(f"Invalid expand method: {expand_method}")
 
-    max_depth = traverse_strategy.max_depth
-    edge_sampling = traverse_strategy.edge_sampling
+    max_depth = traverse_strategy["max_depth"]
+    edge_sampling = traverse_strategy["edge_sampling"]
 
     # Build the adjacency list
     edge_adj_list = defaultdict(list)
@@ -275,16 +276,16 @@
     for i, (node_name, _) in enumerate(nodes):
         node_dict[node_name] = i
 
-    if traverse_strategy.loss_strategy == "both":
+    if traverse_strategy["loss_strategy"] == "both":
         er_tuples = [
             ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
             for edge in edges
         ]
         edges = _sort_tuples(er_tuples, edge_sampling)
-    elif traverse_strategy.loss_strategy == "only_edge":
+    elif traverse_strategy["loss_strategy"] == "only_edge":
         edges = _sort_edges(edges, edge_sampling)
     else:
-        raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
+        raise ValueError(f"Invalid loss strategy: {traverse_strategy['loss_strategy']}")
 
     for i, (src, tgt, _) in enumerate(edges):
         edge_adj_list[src].append(i)
@@ -315,10 +316,10 @@
                 nodes,
                 edge,
                 max_depth,
-                traverse_strategy.bidirectional,
-                traverse_strategy.max_extra_edges,
+                traverse_strategy["bidirectional"],
+                traverse_strategy["max_extra_edges"],
                 edge_sampling,
-                traverse_strategy.loss_strategy,
+                traverse_strategy["loss_strategy"],
             )
         else:
             level_n_edges = _get_level_n_edges_by_max_tokens(
@@ -328,10 +329,10 @@
                 nodes,
                 edge,
                 max_depth,
-                traverse_strategy.bidirectional,
-                traverse_strategy.max_tokens,
+                traverse_strategy["bidirectional"],
+                traverse_strategy["max_tokens"],
                 edge_sampling,
-                traverse_strategy.loss_strategy,
+                traverse_strategy["loss_strategy"],
             )
 
         for _edge in level_n_edges:
@@ -352,7 +353,7 @@
     logger.info("Processing batches: %d", len(processing_batches))
 
     # isolate nodes
-    isolated_node_strategy = traverse_strategy.isolated_node_strategy
+    isolated_node_strategy = traverse_strategy["isolated_node_strategy"]
     if isolated_node_strategy == "add":
         processing_batches = await _add_isolated_nodes(
             nodes, processing_batches, graph_storage
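Attribute access becomes key lookup, so a missing setting now surfaces as a KeyError at use time rather than a TypeError when the dataclass was built; callers must pass a complete dict. The shape get_batches_with_strategy expects (values are the aggregated-config defaults):

    traverse_strategy = {
        "expand_method": "max_width",          # or "max_tokens"
        "bidirectional": True,
        "max_extra_edges": 20,                 # used when expand_method == "max_width"
        "max_tokens": 256,                     # used when expand_method == "max_tokens"
        "max_depth": 5,
        "edge_sampling": "max_loss",
        "isolated_node_strategy": "ignore",
        "loss_strategy": "only_edge",
    }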
 
graphgen/operators/traverse_graph.py CHANGED
@@ -1,15 +1,10 @@
 import asyncio
+from typing import Dict
 
 import gradio as gr
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import (
-    JsonKVStorage,
-    NetworkXStorage,
-    OpenAIClient,
-    Tokenizer,
-    TraverseStrategy,
-)
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer
 from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
 from graphgen.templates import (
     ANSWER_REPHRASING_PROMPT,
@@ -164,7 +159,7 @@ async def traverse_graph_for_aggregated(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,
@@ -240,7 +235,7 @@
                         "question": question,
                         "answer": context,
                         "loss": get_average_loss(
-                            _process_batch, traverse_strategy.loss_strategy
+                            _process_batch, traverse_strategy["loss_strategy"]
                         ),
                     }
                 }
@@ -272,7 +267,7 @@
                 "question": qa["question"],
                 "answer": qa["answer"],
                 "loss": get_average_loss(
-                    _process_batch, traverse_strategy.loss_strategy
+                    _process_batch, traverse_strategy["loss_strategy"]
                 ),
             }
         return final_results
@@ -313,7 +308,7 @@ async def traverse_graph_for_atomic(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,
@@ -331,7 +326,6 @@ async def traverse_graph_for_atomic(
     :return: question and answer
     """
 
-    assert traverse_strategy.qa_form == "atomic"
     semaphore = asyncio.Semaphore(max_concurrent)
 
     def _parse_qa(qa: str) -> tuple:
@@ -429,7 +423,7 @@ async def traverse_graph_for_multi_hop(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,
@@ -517,7 +511,7 @@
                         "question": question,
                         "answer": answer,
                         "loss": get_average_loss(
-                            _process_batch, traverse_strategy.loss_strategy
+                            _process_batch, traverse_strategy["loss_strategy"]
                         ),
                     }
                 }
 
webui/app.py CHANGED
@@ -39,27 +39,32 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     set_logger(log_file, if_stream=True)
     os.environ.update({k: str(v) for k, v in env.items()})
 
-    graph_gen = GraphGen(working_dir=working_dir, config=config)
-    # Set up LLM clients
-    graph_gen.synthesizer_llm_client = OpenAIClient(
+    tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
+    synthesizer_llm_client = OpenAIClient(
         model_name=env.get("SYNTHESIZER_MODEL", ""),
         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
         api_key=env.get("SYNTHESIZER_API_KEY", ""),
         request_limit=True,
         rpm=RPM(env.get("RPM", 1000)),
         tpm=TPM(env.get("TPM", 50000)),
+        tokenizer=tokenizer_instance,
     )
-
-    graph_gen.trainee_llm_client = OpenAIClient(
+    trainee_llm_client = OpenAIClient(
         model_name=env.get("TRAINEE_MODEL", ""),
         base_url=env.get("TRAINEE_BASE_URL", ""),
         api_key=env.get("TRAINEE_API_KEY", ""),
         request_limit=True,
         rpm=RPM(env.get("RPM", 1000)),
         tpm=TPM(env.get("TPM", 50000)),
+        tokenizer=tokenizer_instance,
     )
 
-    graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
+    graph_gen = GraphGen(
+        working_dir=working_dir,
+        tokenizer_instance=tokenizer_instance,
+        synthesizer_llm_client=synthesizer_llm_client,
+        trainee_llm_client=trainee_llm_client,
+    )
 
     return graph_gen
@@ -78,27 +83,32 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
             "chunk_size": params.chunk_size,
             "chunk_overlap": params.chunk_overlap,
         },
-        "output_data_type": params.output_data_type,
-        "output_data_format": params.output_data_format,
-        "tokenizer": params.tokenizer,
         "search": {"enabled": False},
-        "quiz_and_judge_strategy": {
+        "quiz_and_judge": {
             "enabled": params.if_trainee_model,
             "quiz_samples": params.quiz_samples,
         },
-        "traverse_strategy": {
-            "bidirectional": params.bidirectional,
-            "expand_method": params.expand_method,
-            "max_extra_edges": params.max_extra_edges,
-            "max_tokens": params.max_tokens,
-            "max_depth": params.max_depth,
-            "edge_sampling": params.edge_sampling,
-            "isolated_node_strategy": params.isolated_node_strategy,
-            "loss_strategy": params.loss_strategy,
+        "partition": {
+            "method": "ece",
+            "method_params": {
+                "bidirectional": params.bidirectional,
+                "expand_method": params.expand_method,
+                "max_extra_edges": params.max_extra_edges,
+                "max_tokens": params.max_tokens,
+                "max_depth": params.max_depth,
+                "edge_sampling": params.edge_sampling,
+                "isolated_node_strategy": params.isolated_node_strategy,
+                "loss_strategy": params.loss_strategy,
+            },
+        },
+        "generate": {
+            "mode": params.output_data_type,
+            "data_format": params.output_data_format,
         },
     }
 
     env = {
+        "TOKENIZER_MODEL": params.tokenizer,
         "SYNTHESIZER_BASE_URL": params.synthesizer_url,
         "SYNTHESIZER_MODEL": params.synthesizer_model,
         "TRAINEE_BASE_URL": params.trainee_url,
@@ -128,19 +138,18 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
 
     try:
         # Process the data
-        graph_gen.insert()
+        graph_gen.insert(read_config=config["read"], split_config=config["split"])
 
         if config["if_trainee_model"]:
-            # Generate quiz
-            graph_gen.quiz()
-
-            # Judge statements
-            graph_gen.judge()
+            # Quiz and Judge
+            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
-            graph_gen.traverse_strategy.edge_sampling = "random"
+            config["partition"]["method_params"]["edge_sampling"] = "random"
 
-        # Traverse graph
-        graph_gen.traverse()
+        graph_gen.generate(
+            partition_config=config["partition"],
+            generate_config=config["generate"],
+        )
 
         # Save output
         output_data = graph_gen.qa_storage.data
@@ -249,6 +258,9 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
            )
 
            with gr.Accordion(label=_("Model Config"), open=False):
+                tokenizer = gr.Textbox(
+                    label="Tokenizer", value="cl100k_base", interactive=True
+                )
                 synthesizer_url = gr.Textbox(
                     label="Synthesizer URL",
                     value="https://api.siliconflow.cn/v1",
@@ -300,9 +312,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                     step=100,
                     interactive=True,
                 )
-                tokenizer = gr.Textbox(
-                    label="Tokenizer", value="cl100k_base", interactive=True
-                )
                 output_data_type = gr.Radio(
                     choices=["atomic", "multi_hop", "aggregated"],
                     label="Output Data Type",
 