Commit 817f16e
github-actions[bot] committed · 1 parent: 3a3b216

Auto-sync from demo at Tue Sep 30 07:59:12 UTC 2025
Files changed:
- app.py +40 -31
- graphgen/configs/aggregated_config.yaml +15 -13
- graphgen/configs/atomic_config.yaml +15 -13
- graphgen/configs/cot_config.yaml +11 -8
- graphgen/configs/multi_hop_config.yaml +15 -13
- graphgen/generate.py +26 -25
- graphgen/graphgen.py +56 -76
- graphgen/models/__init__.py +0 -1
- graphgen/models/strategy/__init__.py +0 -0
- graphgen/models/strategy/travserse_strategy.py +0 -28
- graphgen/models/tokenizer/__init__.py +2 -0
- graphgen/operators/build_kg/split_kg.py +16 -15
- graphgen/operators/traverse_graph.py +8 -14
- webui/app.py +40 -31
app.py
CHANGED
@@ -39,27 +39,32 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     set_logger(log_file, if_stream=True)
     os.environ.update({k: str(v) for k, v in env.items()})
 
-    …
-    …
-    graph_gen.synthesizer_llm_client = OpenAIClient(
+    tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
+    synthesizer_llm_client = OpenAIClient(
         model_name=env.get("SYNTHESIZER_MODEL", ""),
         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
         api_key=env.get("SYNTHESIZER_API_KEY", ""),
         request_limit=True,
         rpm=RPM(env.get("RPM", 1000)),
         tpm=TPM(env.get("TPM", 50000)),
+        tokenizer=tokenizer_instance,
     )
-    …
-    graph_gen.trainee_llm_client = OpenAIClient(
+    trainee_llm_client = OpenAIClient(
         model_name=env.get("TRAINEE_MODEL", ""),
         base_url=env.get("TRAINEE_BASE_URL", ""),
         api_key=env.get("TRAINEE_API_KEY", ""),
         request_limit=True,
         rpm=RPM(env.get("RPM", 1000)),
         tpm=TPM(env.get("TPM", 50000)),
+        tokenizer=tokenizer_instance,
    )
 
-    graph_gen…
+    graph_gen = GraphGen(
+        working_dir=working_dir,
+        tokenizer_instance=tokenizer_instance,
+        synthesizer_llm_client=synthesizer_llm_client,
+        trainee_llm_client=trainee_llm_client,
+    )
 
     return graph_gen

@@ -78,27 +83,32 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
             "chunk_size": params.chunk_size,
             "chunk_overlap": params.chunk_overlap,
         },
-        "output_data_type": params.output_data_type,
-        "output_data_format": params.output_data_format,
-        "tokenizer": params.tokenizer,
         "search": {"enabled": False},
-        "…
+        "quiz_and_judge": {
             "enabled": params.if_trainee_model,
             "quiz_samples": params.quiz_samples,
         },
-        "…
-        "…
-        "…
-        …
+        "partition": {
+            "method": "ece",
+            "method_params": {
+                "bidirectional": params.bidirectional,
+                "expand_method": params.expand_method,
+                "max_extra_edges": params.max_extra_edges,
+                "max_tokens": params.max_tokens,
+                "max_depth": params.max_depth,
+                "edge_sampling": params.edge_sampling,
+                "isolated_node_strategy": params.isolated_node_strategy,
+                "loss_strategy": params.loss_strategy,
+            },
+        },
+        "generate": {
+            "mode": params.output_data_type,
+            "data_format": params.output_data_format,
        },
    }
 
     env = {
+        "TOKENIZER_MODEL": params.tokenizer,
         "SYNTHESIZER_BASE_URL": params.synthesizer_url,
         "SYNTHESIZER_MODEL": params.synthesizer_model,
         "TRAINEE_BASE_URL": params.trainee_url,

@@ -128,19 +138,18 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
 
     try:
         # Process the data
-        graph_gen.insert()
+        graph_gen.insert(read_config=config["read"], split_config=config["split"])
 
         if config["if_trainee_model"]:
-            # …
-            graph_gen.…
-            # Judge statements
-            graph_gen.judge()
+            # Quiz and Judge
+            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
-            …
+            config["partition"]["method_params"]["edge_sampling"] = "random"
 
-        …
+        graph_gen.generate(
+            partition_config=config["partition"],
+            generate_config=config["generate"],
+        )
 
         # Save output
         output_data = graph_gen.qa_storage.data

@@ -249,6 +258,9 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
         )
 
         with gr.Accordion(label=_("Model Config"), open=False):
+            tokenizer = gr.Textbox(
+                label="Tokenizer", value="cl100k_base", interactive=True
+            )
             synthesizer_url = gr.Textbox(
                 label="Synthesizer URL",
                 value="https://api.siliconflow.cn/v1",

@@ -300,9 +312,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                 step=100,
                 interactive=True,
             )
-            tokenizer = gr.Textbox(
-                label="Tokenizer", value="cl100k_base", interactive=True
-            )
             output_data_type = gr.Radio(
                 choices=["atomic", "multi_hop", "aggregated"],
                 label="Output Data Type",
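
Taken together, the WebUI changes mean run_graphgen() now builds the same nested config that the YAML files below use, instead of the old flat keys. A minimal sketch of the resulting dict (values are illustrative, not defaults; the "read" entry is assumed from the insert() call above):

    # Shape of the config dict run_graphgen() assembles from WebuiParams
    config = {
        "if_trainee_model": True,
        "read": {"input_file": "demo.jsonl"},  # assumed example value
        "split": {"chunk_size": 1024, "chunk_overlap": 100},
        "search": {"enabled": False},
        "quiz_and_judge": {"enabled": True, "quiz_samples": 2},
        "partition": {
            "method": "ece",
            "method_params": {
                "bidirectional": True,
                "expand_method": "max_width",
                "max_extra_edges": 5,
                "max_tokens": 256,
                "max_depth": 3,
                "edge_sampling": "max_loss",
                "isolated_node_strategy": "ignore",
                "loss_strategy": "only_edge",
            },
        },
        "generate": {"mode": "atomic", "data_format": "Alpaca"},
    }
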
graphgen/configs/aggregated_config.yaml
CHANGED
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-…
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: true
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-…
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_tokens
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 5 # maximum depth for graph traversal
+    max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: aggregated # atomic, aggregated, multi_hop, cot
+  data_format: ChatML # Alpaca, Sharegpt, ChatML
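
The flat output_data_type / output_data_format / tokenizer keys are gone: the first two move under the new generate section, and the tokenizer now arrives via the TOKENIZER_MODEL environment variable. A short sketch of consuming the restructured file (PyYAML loading as in generate.py; the path is just an example):

    import yaml

    with open("graphgen/configs/aggregated_config.yaml", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    mode = config["generate"]["mode"]            # "aggregated"
    fmt = config["generate"]["data_format"]      # "ChatML"
    ece = config["partition"]["method_params"]   # dict consumed by split_kg.py
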
graphgen/configs/atomic_config.yaml
CHANGED
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-…
-output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: true
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-…
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_tokens
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 3 # maximum depth for graph traversal
+    max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: atomic # atomic, aggregated, multi_hop, cot
+  data_format: Alpaca # Alpaca, Sharegpt, ChatML
graphgen/configs/cot_config.yaml
CHANGED
@@ -6,11 +6,14 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-…
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
+  enabled: false
+partition: # graph partition configuration
+  method: leiden # leiden is a community detection algorithm
+  method_params:
+    max_size: 20 # Maximum size of communities
+    use_lcc: false
+    random_seed: 42
+generate:
+  mode: cot # atomic, aggregated, multi_hop, cot
+  data_format: Sharegpt # Alpaca, Sharegpt, ChatML
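
Unlike the three traversal-based configs, the CoT pipeline partitions with Leiden community detection and disables quiz-and-judge entirely, so no trainee model is needed. graphgen.py (below) forwards partition.method_params to generate_cot() unchanged, so for this config the params dict is simply:

    # method_params reaching generate_cot() when mode == "cot"
    leiden_params = {"max_size": 20, "use_lcc": False, "random_seed": 42}
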
graphgen/configs/multi_hop_config.yaml
CHANGED
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-…
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: false
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-…
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  method_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_tokens
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 1 # maximum depth for graph traversal
+    max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: multi_hop # strategy for generating multi-hop QA pairs
+  data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/generate.py
CHANGED
@@ -6,8 +6,8 @@ from importlib.resources import files
 import yaml
 from dotenv import load_dotenv
 
-from .graphgen import GraphGen
-from .utils import logger, set_logger
+from graphgen.graphgen import GraphGen
+from graphgen.utils import logger, set_logger
 
 sys_path = os.path.abspath(os.path.dirname(__file__))

@@ -50,50 +50,51 @@ def main():
     with open(args.config_file, "r", encoding="utf-8") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
-    …
+    mode = config["generate"]["mode"]
     unique_id = int(time.time())
 
-    output_path = os.path.join(
-        working_dir, "data", "graphgen", f"{unique_id}_{output_data_type}"
-    )
+    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
     set_working_dir(output_path)
 
     set_logger(
-        os.path.join(output_path, f"{unique_id}.log"),
+        os.path.join(output_path, f"{unique_id}_{mode}.log"),
         if_stream=True,
     )
     logger.info(
         "GraphGen with unique ID %s logging to %s",
         unique_id,
-        os.path.join(
-            working_dir, "logs", f"{unique_id}_graphgen_{output_data_type}.log"
-        ),
+        os.path.join(working_dir, f"{unique_id}_{mode}.log"),
     )
 
-    graph_gen = GraphGen(…
+    graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
 
-    graph_gen.insert()
+    graph_gen.insert(read_config=config["read"], split_config=config["split"])
 
-    …
-    graph_gen.search()
+    graph_gen.search(search_config=config["search"])
 
     # Use pipeline according to the output data type
-    if …
-        …
-        graph_gen.quiz()
-        graph_gen.judge()
+    if mode in ["atomic", "aggregated", "multi_hop"]:
+        logger.info("Generation mode set to '%s'. Start generation.", mode)
+        if "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]:
+            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
             logger.warning(
                 "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
             )
-            …
+            assert (
+                config["partition"]["method"] == "ece"
+                and "method_params" in config["partition"]
+            ), "Only ECE partition with edge sampling is supported."
+            config["partition"]["method_params"]["edge_sampling"] = "random"
+    elif mode == "cot":
+        logger.info("Generation mode set to 'cot'. Start generation.")
     else:
-        raise ValueError(f"Unsupported output data type: {…
+        raise ValueError(f"Unsupported output data type: {mode}")
+
+    graph_gen.generate(
+        partition_config=config["partition"],
+        generate_config=config["generate"],
+    )
 
     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
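
With main() reorganized this way, a single YAML file drives the whole run, e.g. python -m graphgen.generate --config_file graphgen/configs/atomic_config.yaml (the exact flag spelling is an assumption; the diff only shows args.config_file). A hedged sketch of the equivalent programmatic flow, using only calls visible in this commit:

    import yaml
    from graphgen.graphgen import GraphGen

    with open("graphgen/configs/atomic_config.yaml", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    gg = GraphGen(working_dir="cache")  # tokenizer/clients fall back to env vars
    gg.insert(read_config=config["read"], split_config=config["split"])
    gg.search(search_config=config["search"])
    if config["quiz_and_judge"]["enabled"]:
        gg.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
    gg.generate(
        partition_config=config["partition"],
        generate_config=config["generate"],
    )
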
graphgen/graphgen.py
CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 import os
 import time
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Dict, cast
 
 import gradio as gr

@@ -14,7 +14,6 @@ from graphgen.models import (
     NetworkXStorage,
     OpenAIClient,
     Tokenizer,
-    TraverseStrategy,
 )
 from graphgen.operators import (
     chunk_documents,

@@ -42,46 +41,36 @@ sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 class GraphGen:
     unique_id: int = int(time.time())
     working_dir: str = os.path.join(sys_path, "cache")
-    config: Dict = field(default_factory=dict)
 
     # llm
     tokenizer_instance: Tokenizer = None
     synthesizer_llm_client: OpenAIClient = None
     trainee_llm_client: OpenAIClient = None
 
-    # search
-    search_config: dict = field(
-        default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
-    )
-
-    # traversal
-    traverse_strategy: TraverseStrategy = None
-
     # webui
     progress_bar: gr.Progress = None
 
     def __post_init__(self):
-        self.tokenizer_instance: Tokenizer = Tokenizer(
-            model_name=…
+        self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
+            model_name=os.getenv("TOKENIZER_MODEL")
         )
-        …
+
+        self.synthesizer_llm_client: OpenAIClient = (
+            self.synthesizer_llm_client
+            or OpenAIClient(
+                model_name=os.getenv("SYNTHESIZER_MODEL"),
+                api_key=os.getenv("SYNTHESIZER_API_KEY"),
+                base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+                tokenizer=self.tokenizer_instance,
+            )
         )
-        …
+
+        self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
             model_name=os.getenv("TRAINEE_MODEL"),
             api_key=os.getenv("TRAINEE_API_KEY"),
             base_url=os.getenv("TRAINEE_BASE_URL"),
             tokenizer=self.tokenizer_instance,
         )
-        self.search_config = self.config["search"]
-
-        if "traverse_strategy" in self.config:
-            self.traverse_strategy = TraverseStrategy(
-                **self.config["traverse_strategy"]
-            )
 
         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"

@@ -99,24 +88,17 @@ class GraphGen:
             self.working_dir, namespace="rephrase"
         )
         self.qa_storage: JsonListStorage = JsonListStorage(
-            os.path.join(
-                self.working_dir,
-                "data",
-                "graphgen",
-                f"{self.unique_id}_{self.config['output_data_type']}",
-            ),
+            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
         )
 
     @async_to_sync_method
-    async def insert(self):
+    async def insert(self, read_config: Dict, split_config: Dict):
         """
         insert chunks into the graph
         """
-        input_file = self.config["read"]["input_file"]
-
         # Step 1: Read files
-        data = read_files(input_file)
+        data = read_files(read_config["input_file"])
         if len(data) == 0:
             logger.warning("No data to process")
             return

@@ -141,8 +123,8 @@ class GraphGen:
 
         inserting_chunks = await chunk_documents(
             new_docs,
-            …
-            …
+            split_config["chunk_size"],
+            split_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
         )

@@ -178,6 +160,7 @@ class GraphGen:
             return
 
         await self._insert_done()
+        return _add_entities_and_relations
 
     async def _insert_done(self):
         tasks = []

@@ -193,14 +176,12 @@ class GraphGen:
         await asyncio.gather(*tasks)
 
     @async_to_sync_method
-    async def search(self):
+    async def search(self, search_config: Dict):
         logger.info(
-            "Search is %s", "enabled" if …
+            "Search is %s", "enabled" if search_config["enabled"] else "disabled"
         )
-        if …
-            logger.info(
-                "[Search] %s ...", ", ".join(self.search_config["search_types"])
-            )
+        if search_config["enabled"]:
+            logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
         all_nodes = await self.graph_storage.get_all_nodes()
         all_nodes_names = [node[0] for node in all_nodes]
         new_search_entities = await self.full_docs_storage.filter_keys(

@@ -210,7 +191,7 @@ class GraphGen:
             "[Search] Found %d entities to search", len(new_search_entities)
         )
         _add_search_data = await search_all(
-            search_types=…,
+            search_types=search_config["search_types"],
             search_entities=new_search_entities,
         )
         if _add_search_data:

@@ -230,78 +211,77 @@ class GraphGen:
             await self.insert()
 
     @async_to_sync_method
-    async def …
-        …
+    async def quiz_and_judge(self, quiz_and_judge_config: Dict):
+        if quiz_and_judge_config is None or not quiz_and_judge_config.get(
+            "enabled", False
+        ):
+            logger.warning("Quiz and Judge is not used in this pipeline.")
+            return
+        max_samples = quiz_and_judge_config["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             max_samples,
         )
-        await self.rephrase_storage.index_done_callback()
 
-        …
-        re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
+        # TODO: assert trainee_llm_client is valid before judge
+        re_judge = quiz_and_judge_config["re_judge"]
         _update_relations = await judge_statement(
             self.trainee_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             re_judge,
         )
+        await self.rephrase_storage.index_done_callback()
         await _update_relations.index_done_callback()
 
     @async_to_sync_method
-    async def …
-        …
+    async def generate(self, partition_config: Dict, generate_config: Dict):
+        # Step 1: partition the graph
+        # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
+        mode = generate_config["mode"]
+        if mode == "atomic":
             results = await traverse_graph_for_atomic(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                …
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif …
+        elif mode == "multi_hop":
             results = await traverse_graph_for_multi_hop(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                …
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif …
+        elif mode == "aggregated":
             results = await traverse_graph_for_aggregated(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                …
+                partition_config["method_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
+        elif mode == "cot":
+            results = await generate_cot(
+                self.graph_storage,
+                self.synthesizer_llm_client,
+                method_params=partition_config["method_params"],
+            )
         else:
-            raise ValueError(f"Unknown …
-        …
-            results, output_data_format=self.config["output_data_format"]
-        )
-
-        await self.qa_storage.upsert(results)
-        await self.qa_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def generate_reasoning(self, method_params):
-        results = await generate_cot(
-            self.graph_storage,
-            self.synthesizer_llm_client,
-            method_params=method_params,
-        )
+            raise ValueError(f"Unknown generation mode: {mode}")
+        # Step 2: generate QA pairs
+        # TODO
 
+        # Step 3: format
         results = format_generation_results(
-            results, output_data_format=…
+            results, output_data_format=generate_config["data_format"]
         )
 
         await self.qa_storage.upsert(results)
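
The "or" fallbacks in __post_init__ are what let both entry points share one class: the CLI constructs GraphGen bare and everything comes from the environment, while the WebUI injects pre-built instances that take precedence. A sketch of both styles (env-var names as in this commit; argument values are placeholders):

    import os
    from graphgen.graphgen import GraphGen
    from graphgen.models import OpenAIClient, Tokenizer

    # CLI path: tokenizer and clients are built from TOKENIZER_MODEL,
    # SYNTHESIZER_* and TRAINEE_* environment variables.
    os.environ["TOKENIZER_MODEL"] = "cl100k_base"
    gg_cli = GraphGen(working_dir="cache")

    # WebUI path: explicitly passed instances win over the env fallback.
    tok = Tokenizer(model_name="cl100k_base")
    client = OpenAIClient(
        model_name="<model>", api_key="<key>", base_url="<url>", tokenizer=tok
    )
    gg_web = GraphGen(
        working_dir="cache",
        tokenizer_instance=tok,
        synthesizer_llm_client=client,
        trainee_llm_client=client,
    )
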
graphgen/models/__init__.py
CHANGED
@@ -13,5 +13,4 @@ from .search.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
 from .storage.json_storage import JsonKVStorage, JsonListStorage
 from .storage.networkx_storage import NetworkXStorage
-from .strategy.travserse_strategy import TraverseStrategy
 from .tokenizer import Tokenizer
graphgen/models/strategy/__init__.py
DELETED
File without changes
graphgen/models/strategy/travserse_strategy.py
DELETED
@@ -1,28 +0,0 @@
-from dataclasses import dataclass, fields
-
-
-@dataclass
-class TraverseStrategy:
-    # Form of the generated QA pairs: atomic, multi-hop, or aggregated
-    qa_form: str = "atomic"  # "atomic" or "multi_hop" or "aggregated"
-    # Exactly one of the max-edges / max-tokens methods takes effect
-    expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
-    # Expand in one direction or in both directions
-    bidirectional: bool = True
-    # Maximum number of edges to expand per direction
-    max_extra_edges: int = 5
-    # Maximum number of tokens
-    max_tokens: int = 256
-    # Maximum expansion depth per direction
-    max_depth: int = 2
-    # Edge-selection strategy within a level (for bidirectional expansion,
-    # a level is the set of edges connected on both sides)
-    edge_sampling: str = "max_loss"  # "max_loss" or "min_loss" or "random"
-    # Strategy for handling isolated nodes
-    isolated_node_strategy: str = "add"  # "add" or "ignore"
-    loss_strategy: str = "only_edge"  # only_edge, both
-
-    def to_yaml(self):
-        strategy_dict = {}
-        for f in fields(self):
-            strategy_dict[f.name] = getattr(self, f.name)
-        return {"traverse_strategy": strategy_dict}
graphgen/models/tokenizer/__init__.py
CHANGED
@@ -39,6 +39,8 @@ class Tokenizer(BaseTokenizer):
     _impl: BaseTokenizer = field(init=False, repr=False)
 
     def __post_init__(self):
+        if not self.model_name:
+            raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
         self._impl = get_tokenizer_impl(self.model_name)
 
     def encode(self, text: str) -> List[int]:
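
The new guard turns a missing tokenizer name into an immediate, descriptive failure instead of an obscure error inside get_tokenizer_impl(). A minimal sketch of the behavior:

    from graphgen.models import Tokenizer

    tok = Tokenizer(model_name="cl100k_base")  # fine

    try:
        Tokenizer(model_name=None)  # e.g. TOKENIZER_MODEL was never set
    except ValueError as err:
        print(err)  # TOKENIZER_MODEL must be specified in the ENV variables.
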
graphgen/operators/build_kg/split_kg.py
CHANGED
@@ -1,9 +1,10 @@
 import random
 from collections import defaultdict
+from typing import Dict
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import NetworkXStorage, TraverseStrategy
+from graphgen.models import NetworkXStorage
 from graphgen.utils import logger

@@ -247,9 +248,9 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     nodes: list,
     edges: list,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
 ):
-    expand_method = traverse_strategy.expand_method
+    expand_method = traverse_strategy["expand_method"]
     if expand_method == "max_width":
         logger.info("Using max width strategy")
     elif expand_method == "max_tokens":

@@ -257,8 +258,8 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     else:
         raise ValueError(f"Invalid expand method: {expand_method}")
 
-    max_depth = traverse_strategy.max_depth
-    edge_sampling = traverse_strategy.edge_sampling
+    max_depth = traverse_strategy["max_depth"]
+    edge_sampling = traverse_strategy["edge_sampling"]
 
     # Build the adjacency list
     edge_adj_list = defaultdict(list)

@@ -275,16 +276,16 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     for i, (node_name, _) in enumerate(nodes):
         node_dict[node_name] = i
 
-    if traverse_strategy.loss_strategy == "both":
+    if traverse_strategy["loss_strategy"] == "both":
         er_tuples = [
             ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
             for edge in edges
         ]
         edges = _sort_tuples(er_tuples, edge_sampling)
-    elif traverse_strategy.loss_strategy == "only_edge":
+    elif traverse_strategy["loss_strategy"] == "only_edge":
         edges = _sort_edges(edges, edge_sampling)
     else:
-        raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
+        raise ValueError(f"Invalid loss strategy: {traverse_strategy['loss_strategy']}")
 
     for i, (src, tgt, _) in enumerate(edges):
         edge_adj_list[src].append(i)

@@ -315,10 +316,10 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
                 nodes,
                 edge,
                 max_depth,
-                traverse_strategy.bidirectional,
-                traverse_strategy.max_extra_edges,
+                traverse_strategy["bidirectional"],
+                traverse_strategy["max_extra_edges"],
                 edge_sampling,
-                traverse_strategy.loss_strategy,
+                traverse_strategy["loss_strategy"],
             )
         else:
             level_n_edges = _get_level_n_edges_by_max_tokens(

@@ -328,10 +329,10 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
                 nodes,
                 edge,
                 max_depth,
-                traverse_strategy.bidirectional,
-                traverse_strategy.max_tokens,
+                traverse_strategy["bidirectional"],
+                traverse_strategy["max_tokens"],
                 edge_sampling,
-                traverse_strategy.loss_strategy,
+                traverse_strategy["loss_strategy"],
             )
 
         for _edge in level_n_edges:

@@ -352,7 +353,7 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     logger.info("Processing batches: %d", len(processing_batches))
 
     # isolated nodes
-    isolated_node_strategy = traverse_strategy.isolated_node_strategy
+    isolated_node_strategy = traverse_strategy["isolated_node_strategy"]
     if isolated_node_strategy == "add":
         processing_batches = await _add_isolated_nodes(
             nodes, processing_batches, graph_storage
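
Every traverse_strategy.<field> access becomes a dict lookup, so the deleted TraverseStrategy dataclass survives as a plain dict whose keys mirror partition.method_params in the YAML configs. The shape get_batches_with_strategy() now expects:

    # Keys match partition.method_params; values here echo atomic_config.yaml.
    traverse_strategy = {
        "expand_method": "max_width",        # "max_width" or "max_tokens"
        "bidirectional": True,
        "max_depth": 3,
        "max_extra_edges": 5,                # used when expand_method == "max_width"
        "max_tokens": 256,                   # used when expand_method == "max_tokens"
        "edge_sampling": "max_loss",         # "random" | "max_loss" | "min_loss"
        "isolated_node_strategy": "ignore",  # "ignore" | "add"
        "loss_strategy": "only_edge",        # "only_edge" | "both"
    }
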
graphgen/operators/traverse_graph.py
CHANGED
@@ -1,15 +1,10 @@
 import asyncio
+from typing import Dict
 
 import gradio as gr
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import (
-    JsonKVStorage,
-    NetworkXStorage,
-    OpenAIClient,
-    Tokenizer,
-    TraverseStrategy,
-)
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer
 from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
 from graphgen.templates import (
     ANSWER_REPHRASING_PROMPT,

@@ -164,7 +159,7 @@ async def traverse_graph_for_aggregated(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -240,7 +235,7 @@ async def traverse_graph_for_aggregated(
                     "question": question,
                     "answer": context,
                     "loss": get_average_loss(
-                        _process_batch, traverse_strategy.loss_strategy
+                        _process_batch, traverse_strategy["loss_strategy"]
                     ),
                 }
             }

@@ -272,7 +267,7 @@ async def traverse_graph_for_aggregated(
                 "question": qa["question"],
                 "answer": qa["answer"],
                 "loss": get_average_loss(
-                    _process_batch, traverse_strategy.loss_strategy
+                    _process_batch, traverse_strategy["loss_strategy"]
                 ),
             }
     return final_results

@@ -313,7 +308,7 @@ async def traverse_graph_for_atomic(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -331,7 +326,6 @@ async def traverse_graph_for_atomic(
     :return: question and answer
     """
 
-    assert traverse_strategy.qa_form == "atomic"
     semaphore = asyncio.Semaphore(max_concurrent)
 
     def _parse_qa(qa: str) -> tuple:

@@ -429,7 +423,7 @@ async def traverse_graph_for_multi_hop(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -517,7 +511,7 @@ async def traverse_graph_for_multi_hop(
                 "question": question,
                 "answer": answer,
                 "loss": get_average_loss(
-                    _process_batch, traverse_strategy.loss_strategy
+                    _process_batch, traverse_strategy["loss_strategy"]
                 ),
             }
         }
webui/app.py
CHANGED
(Same diff as app.py above; the two files receive identical changes in this auto-sync commit.)