github-actions[bot] committed on Tue Sep 30 03:30:14 UTC 2025
Commit 3a3b216 · 1 Parent: bda6eda

Auto-sync from demo at Tue Sep 30 03:30:14 UTC 2025

Files changed (41)
  1. app.py +3 -3
  2. graphgen/bases/__init__.py +12 -0
  3. graphgen/bases/base_kg_builder.py +41 -0
  4. graphgen/bases/base_llm_client.py +74 -0
  5. graphgen/bases/base_tokenizer.py +44 -0
  6. graphgen/bases/datatypes.py +14 -0
  7. graphgen/graphgen.py +53 -98
  8. graphgen/models/__init__.py +5 -5
  9. graphgen/models/evaluate/length_evaluator.py +2 -2
  10. graphgen/models/kg_builder/NetworkXKGBuilder.py +18 -0
  11. graphgen/{operators/kg → models/kg_builder}/__init__.py +0 -0
  12. graphgen/models/llm/ollama_client.py +21 -0
  13. graphgen/models/llm/{openai_model.py → openai_client.py} +43 -40
  14. graphgen/models/llm/tokenizer.py +0 -73
  15. graphgen/models/llm/topk_token_model.py +9 -16
  16. graphgen/models/reader/__init__.py +0 -18
  17. graphgen/models/splitter/__init__.py +0 -27
  18. graphgen/models/tokenizer/__init__.py +51 -0
  19. graphgen/models/tokenizer/hf_tokenizer.py +18 -0
  20. graphgen/models/tokenizer/tiktoken_tokenizer.py +18 -0
  21. graphgen/operators/__init__.py +3 -12
  22. graphgen/operators/build_kg/__init__.py +0 -0
  23. graphgen/operators/build_kg/extract_kg.py +127 -0
  24. graphgen/operators/{kg → build_kg}/merge_kg.py +7 -7
  25. graphgen/operators/{kg → build_kg}/split_kg.py +0 -0
  26. graphgen/operators/generate/generate_cot.py +2 -2
  27. graphgen/operators/judge.py +2 -2
  28. graphgen/operators/kg/extract_kg.py +0 -152
  29. graphgen/operators/preprocess/resolute_coreference.py +2 -2
  30. graphgen/operators/quiz.py +48 -35
  31. graphgen/operators/read/__init__.py +1 -0
  32. graphgen/operators/read/read_files.py +19 -0
  33. graphgen/operators/split/__init__.py +1 -0
  34. graphgen/operators/split/split_chunks.py +76 -0
  35. graphgen/operators/traverse_graph.py +7 -7
  36. graphgen/utils/__init__.py +2 -0
  37. graphgen/utils/calculate_confidence.py +12 -2
  38. graphgen/utils/run_concurrent.py +38 -0
  39. graphgen/utils/wrap.py +13 -0
  40. webui/app.py +3 -3
  41. webui/utils/count_tokens.py +1 -1
app.py CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
 from dotenv import load_dotenv
 
 from graphgen.graphgen import GraphGen
-from graphgen.models import OpenAIModel, Tokenizer
+from graphgen.models import OpenAIClient, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
 from graphgen.utils import set_logger
 from webui.base import WebuiParams
@@ -41,7 +41,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
 
     graph_gen = GraphGen(working_dir=working_dir, config=config)
     # Set up LLM clients
-    graph_gen.synthesizer_llm_client = OpenAIModel(
+    graph_gen.synthesizer_llm_client = OpenAIClient(
         model_name=env.get("SYNTHESIZER_MODEL", ""),
         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
         api_key=env.get("SYNTHESIZER_API_KEY", ""),
@@ -50,7 +50,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
         tpm=TPM(env.get("TPM", 50000)),
     )
 
-    graph_gen.trainee_llm_client = OpenAIModel(
+    graph_gen.trainee_llm_client = OpenAIClient(
         model_name=env.get("TRAINEE_MODEL", ""),
        base_url=env.get("TRAINEE_BASE_URL", ""),
         api_key=env.get("TRAINEE_API_KEY", ""),
graphgen/bases/__init__.py CHANGED
@@ -0,0 +1,12 @@
+from .base_kg_builder import BaseKGBuilder
+from .base_llm_client import BaseLLMClient
+from .base_reader import BaseReader
+from .base_splitter import BaseSplitter
+from .base_storage import (
+    BaseGraphStorage,
+    BaseKVStorage,
+    BaseListStorage,
+    StorageNameSpace,
+)
+from .base_tokenizer import BaseTokenizer
+from .datatypes import Chunk, QAPair, Token
graphgen/bases/base_kg_builder.py ADDED
@@ -0,0 +1,41 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple
+
+from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+
+
+@dataclass
+class BaseKGBuilder(ABC):
+    kg_instance: BaseGraphStorage
+    llm_client: BaseLLMClient
+
+    _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
+    _edges: Dict[Tuple[str, str], List[dict]] = field(
+        default_factory=lambda: defaultdict(list)
+    )
+
+    def build(self, chunks: List[Chunk]) -> None:
+        pass
+
+    @abstractmethod
+    async def extract_all(self, chunks: List[Chunk]) -> None:
+        """Extract nodes and edges from all chunks."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def extract(
+        self, chunk: Chunk
+    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
+        """Extract nodes and edges from a single chunk."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def merge_nodes(
+        self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
+    ) -> None:
+        """Merge extracted nodes into the knowledge graph."""
+        raise NotImplementedError
graphgen/bases/base_llm_client.py ADDED
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import abc
+import re
+from typing import Any, List, Optional
+
+from graphgen.bases.base_tokenizer import BaseTokenizer
+from graphgen.bases.datatypes import Token
+
+
+class BaseLLMClient(abc.ABC):
+    """
+    LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
+    """
+
+    def __init__(
+        self,
+        *,
+        system_prompt: str = "",
+        temperature: float = 0.0,
+        max_tokens: int = 4096,
+        repetition_penalty: float = 1.05,
+        top_p: float = 0.95,
+        top_k: int = 50,
+        tokenizer: Optional[BaseTokenizer] = None,
+        **kwargs: Any,
+    ):
+        self.system_prompt = system_prompt
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.repetition_penalty = repetition_penalty
+        self.top_p = top_p
+        self.top_k = top_k
+        self.tokenizer = tokenizer
+
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    @abc.abstractmethod
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        """Generate answer from the model."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate top-k tokens for the next token prediction."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate probabilities for each token in the input."""
+        raise NotImplementedError
+
+    def count_tokens(self, text: str) -> int:
+        """Count the number of tokens in the text."""
+        if self.tokenizer is None:
+            raise ValueError(
+                "Tokenizer is not set. Please provide a tokenizer to use count_tokens."
+            )
+        return len(self.tokenizer.encode(text))
+
+    @staticmethod
+    def filter_think_tags(text: str, think_tag: str = "think") -> str:
+        """
+        Remove <think> tags from the text.
+        If the text contains <think> and </think>, it removes everything
+        between them and the tags themselves.
+        """
+        think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
+        filtered_text = think_pattern.sub("", text).strip()
+        return filtered_text if filtered_text else text.strip()
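
Note: filter_think_tags is now a static method shared by every backend, so reasoning traces can be stripped without instantiating a client. A minimal usage sketch (illustrative, not part of the commit):

from graphgen.bases import BaseLLMClient

# The regex removes everything between <think> and </think>, then strips whitespace.
text = "<think>internal chain of thought</think>The answer is 42."
assert BaseLLMClient.filter_think_tags(text) == "The answer is 42."

# A custom tag name can be passed for models that emit a different block, e.g. <reasoning>.
text = "<reasoning>...</reasoning>Done."
assert BaseLLMClient.filter_think_tags(text, think_tag="reasoning") == "Done."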
graphgen/bases/base_tokenizer.py ADDED
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class BaseTokenizer(ABC):
+    model_name: str = "cl100k_base"
+
+    @abstractmethod
+    def encode(self, text: str) -> List[int]:
+        """Encode text -> token ids."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def decode(self, token_ids: List[int]) -> str:
+        """Decode token ids -> text."""
+        raise NotImplementedError
+
+    def count_tokens(self, text: str) -> int:
+        return len(self.encode(text))
+
+    def chunk_by_token_size(
+        self,
+        content: str,
+        *,
+        overlap_token_size: int = 128,
+        max_token_size: int = 1024,
+    ) -> List[dict]:
+        tokens = self.encode(content)
+        results = []
+        step = max_token_size - overlap_token_size
+        for index, start in enumerate(range(0, len(tokens), step)):
+            chunk_ids = tokens[start : start + max_token_size]
+            results.append(
+                {
+                    "tokens": len(chunk_ids),
+                    "content": self.decode(chunk_ids).strip(),
+                    "chunk_order_index": index,
+                }
+            )
+        return results
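
Note: chunk_by_token_size slides a window of max_token_size tokens with a stride of max_token_size - overlap_token_size, so consecutive chunks share overlap_token_size tokens. A toy subclass makes the arithmetic visible (CharTokenizer is a stand-in for illustration only, not part of the commit):

from typing import List

from graphgen.bases import BaseTokenizer

class CharTokenizer(BaseTokenizer):
    """Toy tokenizer: one token per character."""

    def encode(self, text: str) -> List[int]:
        return [ord(c) for c in text]

    def decode(self, token_ids: List[int]) -> str:
        return "".join(chr(i) for i in token_ids)

tok = CharTokenizer()
chunks = tok.chunk_by_token_size("a" * 10, overlap_token_size=2, max_token_size=4)
# Windows start every 4 - 2 = 2 tokens, so the chunk sizes are [4, 4, 4, 4, 2]
# and each adjacent pair of chunks shares 2 tokens.
print([c["tokens"] for c in chunks])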
graphgen/bases/datatypes.py CHANGED
@@ -1,4 +1,6 @@
+import math
 from dataclasses import dataclass, field
+from typing import List, Union
 
 
 @dataclass
@@ -16,3 +18,15 @@ class QAPair:
 
     question: str
     answer: str
+
+
+@dataclass
+class Token:
+    text: str
+    prob: float
+    top_candidates: List = field(default_factory=list)
+    ppl: Union[float, None] = field(default=None)
+
+    @property
+    def logprob(self) -> float:
+        return math.log(self.prob)
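
Note: Token moves in with the other shared datatypes, and logprob is derived on access, so prob stays the single source of truth. For example:

import math

from graphgen.bases.datatypes import Token

t = Token(text="graph", prob=0.25)
# logprob is computed from prob each time, so the two can never drift apart.
assert math.isclose(t.logprob, math.log(0.25))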
graphgen/graphgen.py CHANGED
@@ -2,10 +2,9 @@ import asyncio
 import os
 import time
 from dataclasses import dataclass, field
-from typing import Dict, List, Union, cast
+from typing import Dict, cast
 
 import gradio as gr
-from tqdm.asyncio import tqdm as tqdm_async
 
 from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
@@ -13,27 +12,25 @@ from graphgen.models import (
     JsonKVStorage,
     JsonListStorage,
     NetworkXStorage,
-    OpenAIModel,
+    OpenAIClient,
     Tokenizer,
     TraverseStrategy,
-    read_file,
-    split_chunks,
 )
-
-from .operators import (
+from graphgen.operators import (
+    chunk_documents,
     extract_kg,
     generate_cot,
     judge_statement,
     quiz,
+    read_files,
     search_all,
     traverse_graph_for_aggregated,
     traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
-from .utils import (
+from graphgen.utils import (
+    async_to_sync_method,
     compute_content_hash,
-    create_event_loop,
-    detect_main_language,
     format_generation_results,
     logger,
 )
@@ -49,8 +46,8 @@ class GraphGen:
 
     # llm
    tokenizer_instance: Tokenizer = None
-    synthesizer_llm_client: OpenAIModel = None
-    trainee_llm_client: OpenAIModel = None
+    synthesizer_llm_client: OpenAIClient = None
+    trainee_llm_client: OpenAIClient = None
 
     # search
     search_config: dict = field(
@@ -67,17 +64,17 @@ class GraphGen:
         self.tokenizer_instance: Tokenizer = Tokenizer(
             model_name=self.config["tokenizer"]
         )
-        self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
+        self.synthesizer_llm_client: OpenAIClient = OpenAIClient(
             model_name=os.getenv("SYNTHESIZER_MODEL"),
             api_key=os.getenv("SYNTHESIZER_API_KEY"),
             base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-            tokenizer_instance=self.tokenizer_instance,
+            tokenizer=self.tokenizer_instance,
         )
-        self.trainee_llm_client: OpenAIModel = OpenAIModel(
+        self.trainee_llm_client: OpenAIClient = OpenAIClient(
             model_name=os.getenv("TRAINEE_MODEL"),
             api_key=os.getenv("TRAINEE_API_KEY"),
             base_url=os.getenv("TRAINEE_BASE_URL"),
-            tokenizer_instance=self.tokenizer_instance,
+            tokenizer=self.tokenizer_instance,
         )
         self.search_config = self.config["search"]
 
@@ -111,15 +108,23 @@ class GraphGen:
             namespace="qa",
         )
 
-    async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict:
-        # TODO: configurable whether to use coreference resolution
+    @async_to_sync_method
+    async def insert(self):
+        """
+        insert chunks into the graph
+        """
+        input_file = self.config["read"]["input_file"]
+
+        # Step 1: Read files
+        data = read_files(input_file)
         if len(data) == 0:
-            return {}
+            logger.warning("No data to process")
+            return
 
-        inserting_chunks = {}
-        assert isinstance(data, list) and isinstance(data[0], dict)
+        # TODO: configurable whether to use coreference resolution
 
-        # compute hash for each document
+        # Step 2: Split chunks and filter existing ones
+        assert isinstance(data, list) and isinstance(data[0], dict)
         new_docs = {
             compute_content_hash(doc["content"], prefix="doc-"): {
                 "content": doc["content"]
@@ -128,38 +133,19 @@ class GraphGen:
         }
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+
         if len(new_docs) == 0:
             logger.warning("All docs are already in the storage")
-            return {}
+            return
         logger.info("[New Docs] inserting %d docs", len(new_docs))
 
-        cur_index = 1
-        doc_number = len(new_docs)
-        async for doc_key, doc in tqdm_async(
-            new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
-        ):
-            doc_language = detect_main_language(doc["content"])
-            text_chunks = split_chunks(
-                doc["content"],
-                language=doc_language,
-                chunk_size=self.config["split"]["chunk_size"],
-                chunk_overlap=self.config["split"]["chunk_overlap"],
-            )
-
-            chunks = {
-                compute_content_hash(txt, prefix="chunk-"): {
-                    "content": txt,
-                    "full_doc_id": doc_key,
-                    "length": len(self.tokenizer_instance.encode_string(txt)),
-                    "language": doc_language,
-                }
-                for txt in text_chunks
-            }
-            inserting_chunks.update(chunks)
-
-            if self.progress_bar is not None:
-                self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
-            cur_index += 1
+        inserting_chunks = await chunk_documents(
+            new_docs,
+            self.config["split"]["chunk_size"],
+            self.config["split"]["chunk_overlap"],
+            self.tokenizer_instance,
+            self.progress_bar,
+        )
 
         _add_chunk_keys = await self.text_chunks_storage.filter_keys(
             list(inserting_chunks.keys())
@@ -167,29 +153,16 @@ class GraphGen:
         inserting_chunks = {
             k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
         }
-        await self.full_docs_storage.upsert(new_docs)
-        await self.text_chunks_storage.upsert(inserting_chunks)
-
-        return inserting_chunks
-
-    def insert(self):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_insert())
-
-    async def async_insert(self):
-        """
-        insert chunks into the graph
-        """
-
-        input_file = self.config["read"]["input_file"]
-        data = read_file(input_file)
-        inserting_chunks = await self.async_split_chunks(data)
 
         if len(inserting_chunks) == 0:
             logger.warning("All chunks are already in the storage")
             return
+
         logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        await self.full_docs_storage.upsert(new_docs)
+        await self.text_chunks_storage.upsert(inserting_chunks)
 
+        # Step 3: Extract entities and relations from chunks
         logger.info("[Entity and Relation Extraction]...")
         _add_entities_and_relations = await extract_kg(
             llm_client=self.synthesizer_llm_client,
@@ -219,11 +192,8 @@ class GraphGen:
         tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
         await asyncio.gather(*tasks)
 
-    def search(self):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_search())
-
-    async def async_search(self):
+    @async_to_sync_method
+    async def search(self):
         logger.info(
             "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
         )
@@ -257,13 +227,10 @@ class GraphGen:
             ]
         )
         # TODO: fix insert after search
-        await self.async_insert()
-
-    def quiz(self):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_quiz())
+        await self.insert()
 
-    async def async_quiz(self):
+    @async_to_sync_method
+    async def quiz(self):
         max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
@@ -273,11 +240,8 @@ class GraphGen:
         )
         await self.rephrase_storage.index_done_callback()
 
-    def judge(self):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_judge())
-
-    async def async_judge(self):
+    @async_to_sync_method
+    async def judge(self):
         re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
         _update_relations = await judge_statement(
             self.trainee_llm_client,
@@ -287,11 +251,8 @@ class GraphGen:
         )
         await _update_relations.index_done_callback()
 
-    def traverse(self):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_traverse())
-
-    async def async_traverse(self):
+    @async_to_sync_method
+    async def traverse(self):
         output_data_type = self.config["output_data_type"]
 
         if output_data_type == "atomic":
@@ -331,11 +292,8 @@ class GraphGen:
         await self.qa_storage.upsert(results)
         await self.qa_storage.index_done_callback()
 
-    def generate_reasoning(self, method_params):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_generate_reasoning(method_params))
-
-    async def async_generate_reasoning(self, method_params):
+    @async_to_sync_method
+    async def generate_reasoning(self, method_params):
         results = await generate_cot(
             self.graph_storage,
             self.synthesizer_llm_client,
@@ -349,11 +307,8 @@ class GraphGen:
         await self.qa_storage.upsert(results)
         await self.qa_storage.index_done_callback()
 
-    def clear(self):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_clear())
-
-    async def async_clear(self):
+    @async_to_sync_method
+    async def clear(self):
         await self.full_docs_storage.drop()
         await self.text_chunks_storage.drop()
         await self.search_storage.drop()
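
Note: each def x / async def async_x pair is collapsed into one coroutine wrapped by async_to_sync_method, replacing the repeated create_event_loop() boilerplate. wrap.py itself is not expanded in this view; a plausible minimal sketch of such a decorator (an assumption, not the actual implementation):

import asyncio
import functools

def async_to_sync_method(func):
    """Run a coroutine method to completion on a fresh event loop,
    so callers can invoke e.g. graph_gen.insert() without an async context."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(func(self, *args, **kwargs))
        finally:
            loop.close()
    return wrapper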
graphgen/models/__init__.py CHANGED
@@ -3,15 +3,15 @@ from .evaluate.length_evaluator import LengthEvaluator
 from .evaluate.mtld_evaluator import MTLDEvaluator
 from .evaluate.reward_evaluator import RewardEvaluator
 from .evaluate.uni_evaluator import UniEvaluator
-from .llm.openai_model import OpenAIModel
-from .llm.tokenizer import Tokenizer
-from .llm.topk_token_model import Token, TopkTokenModel
-from .reader import read_file
+from .llm.openai_client import OpenAIClient
+from .llm.topk_token_model import TopkTokenModel
+from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
 from .search.db.uniprot_search import UniProtSearch
 from .search.kg.wiki_search import WikiSearch
 from .search.web.bing_search import BingSearch
 from .search.web.google_search import GoogleSearch
-from .splitter import split_chunks
+from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
 from .storage.json_storage import JsonKVStorage, JsonListStorage
 from .storage.networkx_storage import NetworkXStorage
 from .strategy.travserse_strategy import TraverseStrategy
+from .tokenizer import Tokenizer
graphgen/models/evaluate/length_evaluator.py CHANGED
@@ -2,7 +2,7 @@ from dataclasses import dataclass
 
 from graphgen.bases.datatypes import QAPair
 from graphgen.models.evaluate.base_evaluator import BaseEvaluator
-from graphgen.models.llm.tokenizer import Tokenizer
+from graphgen.models.tokenizer import Tokenizer
 from graphgen.utils import create_event_loop
 
 
@@ -18,5 +18,5 @@ class LengthEvaluator(BaseEvaluator):
         return await loop.run_in_executor(None, self._calculate_length, pair.answer)
 
     def _calculate_length(self, text: str) -> float:
-        tokens = self.tokenizer.encode_string(text)
+        tokens = self.tokenizer.encode(text)
         return len(tokens)
graphgen/models/kg_builder/NetworkXKGBuilder.py ADDED
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+
+from graphgen.bases import BaseKGBuilder
+
+
+@dataclass
+class NetworkXKGBuilder(BaseKGBuilder):
+    def build(self, chunks):
+        pass
+
+    async def extract_all(self, chunks):
+        pass
+
+    async def extract(self, chunk):
+        pass
+
+    async def merge_nodes(self, nodes_data, kg_instance, llm):
+        pass
graphgen/{operators/kg → models/kg_builder}/__init__.py RENAMED
File without changes
graphgen/models/llm/ollama_client.py ADDED
@@ -0,0 +1,21 @@
+# TODO: implement ollama client
+from typing import Any, List, Optional
+
+from graphgen.bases import BaseLLMClient, Token
+
+
+class OllamaClient(BaseLLMClient):
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        pass
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        pass
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        pass
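
Note: the Ollama client is still a stub (see the TODO). If it is later backed by the official ollama Python package, generate_answer could look roughly like the sketch below; the model_name attribute (supplied via **kwargs at construction) and the message format are assumptions, not part of this commit:

import ollama

from graphgen.bases import BaseLLMClient

class OllamaClient(BaseLLMClient):  # sketch of one possible implementation
    async def generate_answer(self, text, history=None, **extra):
        client = ollama.AsyncClient()  # assumes a local Ollama server
        messages = [{"role": "system", "content": self.system_prompt}]
        messages += history or []
        messages.append({"role": "user", "content": text})
        # self.model_name is assumed to be passed through BaseLLMClient kwargs
        response = await client.chat(model=self.model_name, messages=messages)
        return self.filter_think_tags(response["message"]["content"])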
graphgen/models/llm/{openai_model.py → openai_client.py} RENAMED
@@ -1,7 +1,5 @@
 import math
-import re
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import openai
 from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
@@ -12,9 +10,9 @@ from tenacity import (
     wait_exponential,
 )
 
+from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.datatypes import Token
 from graphgen.models.llm.limitter import RPM, TPM
-from graphgen.models.llm.tokenizer import Tokenizer
-from graphgen.models.llm.topk_token_model import Token, TopkTokenModel
 
 
 def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
@@ -30,32 +28,33 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
     return tokens
 
 
-def filter_think_tags(text: str) -> str:
-    """
-    Remove <think> tags from the text.
-    If the text contains <think> and </think>, it removes everything between them and the tags themselves.
-    """
-    think_pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
-    filtered_text = think_pattern.sub("", text).strip()
-    return filtered_text if filtered_text else text.strip()
-
-
-@dataclass
-class OpenAIModel(TopkTokenModel):
-    model_name: str = "gpt-4o-mini"
-    api_key: str = None
-    base_url: str = None
-
-    system_prompt: str = ""
-    json_mode: bool = False
-    seed: int = None
-
-    token_usage: list = field(default_factory=list)
-    request_limit: bool = False
-    rpm: RPM = field(default_factory=lambda: RPM(rpm=1000))
-    tpm: TPM = field(default_factory=lambda: TPM(tpm=50000))
-
-    tokenizer_instance: Tokenizer = field(default_factory=Tokenizer)
+class OpenAIClient(BaseLLMClient):
+    def __init__(
+        self,
+        *,
+        model_name: str = "gpt-4o-mini",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        json_mode: bool = False,
+        seed: Optional[int] = None,
+        topk_per_token: int = 5,  # number of topk tokens to generate for each token
+        request_limit: bool = False,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.model_name = model_name
+        self.api_key = api_key
+        self.base_url = base_url
+        self.json_mode = json_mode
+        self.seed = seed
+        self.topk_per_token = topk_per_token
+
+        self.token_usage: list = []
+        self.request_limit = request_limit
+        self.rpm = RPM(rpm=1000)
+        self.tpm = TPM(tpm=50000)
+
+        self.__post_init__()
 
     def __post_init__(self):
         assert self.api_key is not None, "Please provide api key to access openai api."
@@ -66,7 +65,7 @@ class OpenAIModel(TopkTokenModel):
     def _pre_generate(self, text: str, history: List[str]) -> Dict:
         kwargs = {
             "temperature": self.temperature,
-            "top_p": self.topp,
+            "top_p": self.top_p,
             "max_tokens": self.max_tokens,
         }
         if self.seed:
@@ -94,7 +93,10 @@ class OpenAIModel(TopkTokenModel):
         ),
     )
     async def generate_topk_per_token(
-        self, text: str, history: Optional[List[str]] = None
+        self,
+        text: str,
+        history: Optional[List[str]] = None,
+        **extra: Any,
     ) -> List[Token]:
         kwargs = self._pre_generate(text, history)
         if self.topk_per_token > 0:
@@ -120,16 +122,16 @@ class OpenAIModel(TopkTokenModel):
         ),
     )
     async def generate_answer(
-        self, text: str, history: Optional[List[str]] = None, temperature: int = 0
+        self,
+        text: str,
+        history: Optional[List[str]] = None,
+        **extra: Any,
     ) -> str:
         kwargs = self._pre_generate(text, history)
-        kwargs["temperature"] = temperature
 
         prompt_tokens = 0
         for message in kwargs["messages"]:
-            prompt_tokens += len(
-                self.tokenizer_instance.encode_string(message["content"])
-            )
+            prompt_tokens += len(self.tokenizer.encode(message["content"]))
         estimated_tokens = prompt_tokens + kwargs["max_tokens"]
 
         if self.request_limit:
@@ -147,9 +149,10 @@ class OpenAIModel(TopkTokenModel):
                 "total_tokens": completion.usage.total_tokens,
             }
         )
-        return filter_think_tags(completion.choices[0].message.content)
+        return self.filter_think_tags(completion.choices[0].message.content)
 
     async def generate_inputs_prob(
-        self, text: str, history: Optional[List[str]] = None
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
+        """Generate probabilities for each token in the input."""
         raise NotImplementedError
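
Note: construction now splits into two layers: backend-specific options stay on OpenAIClient, while shared sampling options (temperature, top_p, tokenizer, ...) flow through **kwargs into BaseLLMClient. A usage sketch (assumes the environment variables are set):

import asyncio
import os

from graphgen.models import OpenAIClient, Tokenizer

client = OpenAIClient(
    model_name="gpt-4o-mini",
    api_key=os.getenv("SYNTHESIZER_API_KEY"),      # must be non-None
    base_url=os.getenv("SYNTHESIZER_BASE_URL"),
    temperature=0.0,                               # handled by BaseLLMClient.__init__
    tokenizer=Tokenizer(model_name="cl100k_base"), # used for prompt-token estimation
)
answer = asyncio.run(client.generate_answer("Define a knowledge graph in one sentence."))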
graphgen/models/llm/tokenizer.py DELETED
@@ -1,73 +0,0 @@
-from dataclasses import dataclass
-from typing import List
-import tiktoken
-
-try:
-    from transformers import AutoTokenizer
-    TRANSFORMERS_AVAILABLE = True
-except ImportError:
-    AutoTokenizer = None
-    TRANSFORMERS_AVAILABLE = False
-
-
-def get_tokenizer(tokenizer_name: str = "cl100k_base"):
-    """
-    Get a tokenizer instance by name.
-
-    :param tokenizer_name: tokenizer name, tiktoken encoding name or Hugging Face model name
-    :return: tokenizer instance
-    """
-    if tokenizer_name in tiktoken.list_encoding_names():
-        return tiktoken.get_encoding(tokenizer_name)
-    if TRANSFORMERS_AVAILABLE:
-        try:
-            return AutoTokenizer.from_pretrained(tokenizer_name)
-        except Exception as e:
-            raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}") from e
-    else:
-        raise ValueError("Hugging Face Transformers is not available, please install it first.")
-
-@dataclass
-class Tokenizer:
-    model_name: str = "cl100k_base"
-
-    def __post_init__(self):
-        self.tokenizer = get_tokenizer(self.model_name)
-
-    def encode_string(self, text: str) -> List[int]:
-        """
-        Encode text to tokens
-
-        :param text
-        :return: tokens
-        """
-        return self.tokenizer.encode(text)
-
-    def decode_tokens(self, tokens: List[int]) -> str:
-        """
-        Decode tokens to text
-
-        :param tokens
-        :return: text
-        """
-        return self.tokenizer.decode(tokens)
-
-    def chunk_by_token_size(
-        self, content: str, overlap_token_size=128, max_token_size=1024
-    ):
-        tokens = self.encode_string(content)
-        results = []
-        for index, start in enumerate(
-            range(0, len(tokens), max_token_size - overlap_token_size)
-        ):
-            chunk_content = self.decode_tokens(
-                tokens[start : start + max_token_size]
-            )
-            results.append(
-                {
-                    "tokens": min(max_token_size, len(tokens) - start),
-                    "content": chunk_content.strip(),
-                    "chunk_order_index": index,
-                }
-            )
-        return results
graphgen/models/llm/topk_token_model.py CHANGED
@@ -1,18 +1,7 @@
-import math
-from dataclasses import dataclass, field
-from typing import List, Union, Optional
+from dataclasses import dataclass
+from typing import List, Optional
 
+from graphgen.bases import Token
 
-@dataclass
-class Token:
-    text: str
-    prob: float
-    top_candidates: List = field(default_factory=list)
-    ppl: Union[float, None] = field(default=None)
-
-    @property
-    def logprob(self) -> float:
-        return math.log(self.prob)
-
 
 @dataclass
@@ -34,14 +23,18 @@ class TopkTokenModel:
     """
     raise NotImplementedError
 
-    async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None
+    ) -> List[Token]:
         """
         Generate prob and text for each token of the input text.
         This function is used to visualize the ppl.
         """
         raise NotImplementedError
 
-    async def generate_answer(self, text: str, history: Optional[List[str]] = None) -> str:
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None
+    ) -> str:
         """
         Generate answer from the model.
         """
graphgen/models/reader/__init__.py CHANGED
@@ -2,21 +2,3 @@ from .csv_reader import CsvReader
 from .json_reader import JsonReader
 from .jsonl_reader import JsonlReader
 from .txt_reader import TxtReader
-
-_MAPPING = {
-    "jsonl": JsonlReader,
-    "json": JsonReader,
-    "txt": TxtReader,
-    "csv": CsvReader,
-}
-
-
-def read_file(file_path: str):
-    suffix = file_path.split(".")[-1]
-    if suffix in _MAPPING:
-        reader = _MAPPING[suffix]()
-    else:
-        raise ValueError(
-            f"Unsupported file format: {suffix}. Supported formats are: {list(_MAPPING.keys())}"
-        )
-    return reader.read(file_path)
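
Note: the suffix-to-reader dispatch removed here moves into the operators layer as read_files (graphgen/operators/read/read_files.py, +19 lines, not expanded in this view). Assuming it keeps the old contract, the equivalent logic would be roughly:

from graphgen.models import CsvReader, JsonlReader, JsonReader, TxtReader

_MAPPING = {"jsonl": JsonlReader, "json": JsonReader, "txt": TxtReader, "csv": CsvReader}

def read_files(file_path: str):
    # hypothetical reconstruction: dispatch on the file suffix as read_file did
    suffix = file_path.split(".")[-1]
    if suffix not in _MAPPING:
        raise ValueError(f"Unsupported file format: {suffix}")
    return _MAPPING[suffix]().read(file_path)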
graphgen/models/splitter/__init__.py CHANGED
@@ -1,31 +1,4 @@
-from functools import lru_cache
-from typing import Union
-
 from .recursive_character_splitter import (
     ChineseRecursiveTextSplitter,
     RecursiveCharacterSplitter,
 )
-
-_MAPPING = {
-    "en": RecursiveCharacterSplitter,
-    "zh": ChineseRecursiveTextSplitter,
-}
-
-SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
-
-
-@lru_cache(maxsize=None)
-def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
-    cls = _MAPPING[language]
-    kwargs = dict(frozen_kwargs)
-    return cls(**kwargs)
-
-
-def split_chunks(text: str, language: str = "en", **kwargs) -> list:
-    if language not in _MAPPING:
-        raise ValueError(
-            f"Unsupported language: {language}. "
-            f"Supported languages are: {list(_MAPPING.keys())}"
-        )
-    splitter = _get_splitter(language, frozenset(kwargs.items()))
-    return splitter.split_text(text)
graphgen/models/tokenizer/__init__.py ADDED
@@ -0,0 +1,51 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from graphgen.bases import BaseTokenizer
+
+from .hf_tokenizer import HFTokenizer
+from .tiktoken_tokenizer import TiktokenTokenizer
+
+try:
+    from transformers import AutoTokenizer
+
+    _HF_AVAILABLE = True
+except ImportError:
+    _HF_AVAILABLE = False
+
+
+def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer:
+    import tiktoken
+
+    # 1. tiktoken
+    if tokenizer_name in tiktoken.list_encoding_names():
+        return TiktokenTokenizer(model_name=tokenizer_name)
+
+    # 2. HuggingFace
+    if _HF_AVAILABLE:
+        return HFTokenizer(model_name=tokenizer_name)
+
+    raise ValueError(
+        f"Unknown tokenizer {tokenizer_name} and HuggingFace not available."
+    )
+
+
+@dataclass
+class Tokenizer(BaseTokenizer):
+    """
+    Encapsulates different tokenization implementations based on the specified model name.
+    """
+
+    model_name: str = "cl100k_base"
+    _impl: BaseTokenizer = field(init=False, repr=False)
+
+    def __post_init__(self):
+        self._impl = get_tokenizer_impl(self.model_name)
+
+    def encode(self, text: str) -> List[int]:
+        return self._impl.encode(text)
+
+    def decode(self, token_ids: List[int]) -> str:
+        return self._impl.decode(token_ids)
+
+    def count_tokens(self, text: str) -> int:
+        return self._impl.count_tokens(text)
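
Note: the Tokenizer facade keeps call sites uniform while choosing a tiktoken or HuggingFace backend from the model name. For example:

from graphgen.models import Tokenizer

tok = Tokenizer(model_name="cl100k_base")   # resolves to TiktokenTokenizer
ids = tok.encode("knowledge graph")
assert tok.decode(ids) == "knowledge graph"
assert tok.count_tokens("knowledge graph") == len(ids)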
graphgen/models/tokenizer/hf_tokenizer.py ADDED
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+from typing import List
+
+from transformers import AutoTokenizer
+
+from graphgen.bases import BaseTokenizer
+
+
+@dataclass
+class HFTokenizer(BaseTokenizer):
+    def __post_init__(self):
+        self.enc = AutoTokenizer.from_pretrained(self.model_name)
+
+    def encode(self, text: str) -> List[int]:
+        return self.enc.encode(text, add_special_tokens=False)
+
+    def decode(self, token_ids: List[int]) -> str:
+        return self.enc.decode(token_ids, skip_special_tokens=True)
graphgen/models/tokenizer/tiktoken_tokenizer.py ADDED
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+from typing import List
+
+import tiktoken
+
+from graphgen.bases import BaseTokenizer
+
+
+@dataclass
+class TiktokenTokenizer(BaseTokenizer):
+    def __post_init__(self):
+        self.enc = tiktoken.get_encoding(self.model_name)
+
+    def encode(self, text: str) -> List[int]:
+        return self.enc.encode(text)
+
+    def decode(self, token_ids: List[int]) -> str:
+        return self.enc.decode(token_ids)
graphgen/operators/__init__.py CHANGED
@@ -1,22 +1,13 @@
+from graphgen.operators.build_kg.extract_kg import extract_kg
 from graphgen.operators.generate.generate_cot import generate_cot
-from graphgen.operators.kg.extract_kg import extract_kg
 from graphgen.operators.search.search_all import search_all
 
 from .judge import judge_statement
 from .quiz import quiz
+from .read import read_files
+from .split import chunk_documents
 from .traverse_graph import (
     traverse_graph_for_aggregated,
     traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
-
-__all__ = [
-    "extract_kg",
-    "quiz",
-    "judge_statement",
-    "search_all",
-    "traverse_graph_for_aggregated",
-    "traverse_graph_for_atomic",
-    "traverse_graph_for_multi_hop",
-    "generate_cot",
-]
graphgen/operators/build_kg/__init__.py ADDED
File without changes
graphgen/operators/build_kg/extract_kg.py ADDED
@@ -0,0 +1,127 @@
+import re
+from collections import defaultdict
+from typing import List
+
+import gradio as gr
+
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+from graphgen.models import OpenAIClient, Tokenizer
+from graphgen.operators.build_kg.merge_kg import merge_edges, merge_nodes
+from graphgen.templates import KG_EXTRACTION_PROMPT
+from graphgen.utils import (
+    detect_if_chinese,
+    handle_single_entity_extraction,
+    handle_single_relationship_extraction,
+    logger,
+    pack_history_conversations,
+    run_concurrent,
+    split_string_by_multi_markers,
+)
+
+
+# pylint: disable=too-many-statements
+async def extract_kg(
+    llm_client: OpenAIClient,
+    kg_instance: BaseGraphStorage,
+    tokenizer_instance: Tokenizer,
+    chunks: List[Chunk],
+    progress_bar: gr.Progress = None,
+):
+    """
+    :param llm_client: Synthesizer LLM model to extract entities and relationships
+    :param kg_instance
+    :param tokenizer_instance
+    :param chunks
+    :param progress_bar: Gradio progress bar to show the progress of the extraction
+    :return:
+    """
+
+    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
+        chunk_id = chunk.id
+        content = chunk.content
+        if detect_if_chinese(content):
+            language = "Chinese"
+        else:
+            language = "English"
+        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
+
+        hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
+            **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
+        )
+
+        final_result = await llm_client.generate_answer(hint_prompt)
+        logger.info("First result: %s", final_result)
+
+        history = pack_history_conversations(hint_prompt, final_result)
+        for loop_index in range(max_loop):
+            if_loop_result = await llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
+            )
+            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
+            if if_loop_result != "yes":
+                break
+
+            glean_result = await llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
+            )
+            logger.info("Loop %s glean: %s", loop_index, glean_result)
+
+            history += pack_history_conversations(
+                KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
+            )
+            final_result += glean_result
+            if loop_index == max_loop - 1:
+                break
+
+        records = split_string_by_multi_markers(
+            final_result,
+            [
+                KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
+                KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
+            ],
+        )
+
+        nodes = defaultdict(list)
+        edges = defaultdict(list)
+
+        for record in records:
+            record = re.search(r"\((.*)\)", record)
+            if record is None:
+                continue
+            record = record.group(1)  # extract the content inside the parentheses
+            record_attributes = split_string_by_multi_markers(
+                record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
+            )
+
+            entity = await handle_single_entity_extraction(record_attributes, chunk_id)
+            if entity is not None:
+                nodes[entity["entity_name"]].append(entity)
+                continue
+            relation = await handle_single_relationship_extraction(
+                record_attributes, chunk_id
+            )
+            if relation is not None:
+                edges[(relation["src_id"], relation["tgt_id"])].append(relation)
+        return dict(nodes), dict(edges)
+
+    results = await run_concurrent(
+        _process_single_content,
+        chunks,
+        desc="[2/4]Extracting entities and relationships from chunks",
+        unit="chunk",
+        progress_bar=progress_bar,
+    )
+
+    nodes = defaultdict(list)
+    edges = defaultdict(list)
+    for n, e in results:
+        for k, v in n.items():
+            nodes[k].extend(v)
+        for k, v in e.items():
+            edges[tuple(sorted(k))].extend(v)
+
+    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
+    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
+
+    return kg_instance
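
Note: the semaphore and tqdm plumbing that the deleted version carried inline (see graphgen/operators/kg/extract_kg.py below) is now delegated to run_concurrent (graphgen/utils/run_concurrent.py, +38 lines, not expanded here). A sketch of the pattern it presumably encapsulates, assuming it keeps the bounded-concurrency-plus-progress contract:

import asyncio

from tqdm.asyncio import tqdm as tqdm_async

async def run_concurrent(coro_fn, items, desc="", unit="item",
                         progress_bar=None, max_concurrent=1000):
    # hypothetical reconstruction of the helper's contract
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _bounded(item):
        async with semaphore:
            return await coro_fn(item)

    results = []
    for fut in tqdm_async(asyncio.as_completed([_bounded(i) for i in items]),
                          total=len(items), desc=desc, unit=unit):
        results.append(await fut)
        if progress_bar is not None:
            progress_bar(len(results) / len(items), desc=desc)
    return results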
graphgen/operators/{kg → build_kg}/merge_kg.py RENAMED
@@ -3,8 +3,8 @@ from collections import Counter
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.bases.base_storage import BaseGraphStorage
-from graphgen.models import Tokenizer, TopkTokenModel
+from graphgen.bases import BaseGraphStorage, BaseLLMClient
+from graphgen.models import Tokenizer
 from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
 from graphgen.utils import detect_main_language, logger
 from graphgen.utils.format import split_string_by_multi_markers
@@ -13,7 +13,7 @@ from graphgen.utils.format import split_string_by_multi_markers
 async def _handle_kg_summary(
     entity_or_relation_name: str,
     description: str,
-    llm_client: TopkTokenModel,
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_summary_tokens: int = 200,
 ) -> str:
@@ -34,11 +34,11 @@ async def _handle_kg_summary(
         language = "Chinese"
     KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
 
-    tokens = tokenizer_instance.encode_string(description)
+    tokens = tokenizer_instance.encode(description)
     if len(tokens) < max_summary_tokens:
         return description
 
-    use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
+    use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
     prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
         entity_name=entity_or_relation_name,
         description_list=use_description.split("<SEP>"),
@@ -54,7 +54,7 @@ async def _handle_kg_summary(
 async def merge_nodes(
     nodes_data: dict,
     kg_instance: BaseGraphStorage,
-    llm_client: TopkTokenModel,
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_concurrent: int = 1000,
 ):
@@ -131,7 +131,7 @@ async def merge_nodes(
 async def merge_edges(
     edges_data: dict,
     kg_instance: BaseGraphStorage,
-    llm_client: TopkTokenModel,
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_concurrent: int = 1000,
 ):
graphgen/operators/{kg → build_kg}/split_kg.py RENAMED
File without changes
graphgen/operators/generate/generate_cot.py CHANGED
@@ -3,14 +3,14 @@ from typing import Dict, List, Tuple
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
+from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIClient
 from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
 from graphgen.utils import compute_content_hash, detect_main_language
 
 
 async def generate_cot(
     graph_storage: NetworkXStorage,
-    synthesizer_llm_client: OpenAIModel,
+    synthesizer_llm_client: OpenAIClient,
     method_params: Dict = None,
 ):
     method = method_params.get("method", "leiden")
graphgen/operators/judge.py CHANGED
@@ -3,13 +3,13 @@ import math
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIModel
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
 from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
 from graphgen.utils import logger, yes_no_loss_entropy
 
 
 async def judge_statement(  # pylint: disable=too-many-statements
-    trainee_llm_client: OpenAIModel,
+    trainee_llm_client: OpenAIClient,
     graph_storage: NetworkXStorage,
     rephrase_storage: JsonKVStorage,
     re_judge: bool = False,
graphgen/operators/kg/extract_kg.py DELETED
@@ -1,152 +0,0 @@
1
- import asyncio
2
- import re
3
- from collections import defaultdict
4
- from typing import List
5
-
6
- import gradio as gr
7
- from tqdm.asyncio import tqdm as tqdm_async
8
-
9
- from graphgen.bases.base_storage import BaseGraphStorage
10
- from graphgen.bases.datatypes import Chunk
11
- from graphgen.models import OpenAIModel, Tokenizer
12
- from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
13
- from graphgen.templates import KG_EXTRACTION_PROMPT
14
- from graphgen.utils import (
15
- detect_if_chinese,
16
- handle_single_entity_extraction,
17
- handle_single_relationship_extraction,
18
- logger,
19
- pack_history_conversations,
20
- split_string_by_multi_markers,
21
- )
22
-
23
-
24
- # pylint: disable=too-many-statements
25
- async def extract_kg(
26
- llm_client: OpenAIModel,
27
- kg_instance: BaseGraphStorage,
28
- tokenizer_instance: Tokenizer,
29
- chunks: List[Chunk],
30
- progress_bar: gr.Progress = None,
31
- max_concurrent: int = 1000,
32
- ):
33
- """
34
- :param llm_client: Synthesizer LLM model to extract entities and relationships
35
- :param kg_instance
36
- :param tokenizer_instance
37
- :param chunks
38
- :param progress_bar: Gradio progress bar to show the progress of the extraction
39
- :param max_concurrent
40
- :return:
41
- """
42
-
43
- semaphore = asyncio.Semaphore(max_concurrent)
44
-
45
- async def _process_single_content(chunk: Chunk, max_loop: int = 3):
46
- async with semaphore:
47
- chunk_id = chunk.id
48
- content = chunk.content
49
- if detect_if_chinese(content):
50
- language = "Chinese"
51
- else:
52
- language = "English"
53
- KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
54
-
55
- hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
56
- **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
57
- )
58
-
59
- final_result = await llm_client.generate_answer(hint_prompt)
60
- logger.info("First result: %s", final_result)
61
-
62
- history = pack_history_conversations(hint_prompt, final_result)
63
- for loop_index in range(max_loop):
64
- if_loop_result = await llm_client.generate_answer(
65
- text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
66
- )
67
- if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
68
- if if_loop_result != "yes":
69
- break
70
-
71
- glean_result = await llm_client.generate_answer(
72
- text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
73
- )
74
- logger.info("Loop %s glean: %s", loop_index, glean_result)
75
-
76
- history += pack_history_conversations(
77
- KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
78
- )
79
- final_result += glean_result
80
- if loop_index == max_loop - 1:
81
- break
82
-
83
- records = split_string_by_multi_markers(
84
- final_result,
85
- [
86
- KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
87
- KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
88
- ],
89
- )
90
-
91
- nodes = defaultdict(list)
92
- edges = defaultdict(list)
93
-
94
- for record in records:
95
- record = re.search(r"\((.*)\)", record)
96
- if record is None:
97
- continue
98
- record = record.group(1) # 提取括号内的内容
99
- record_attributes = split_string_by_multi_markers(
100
- record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
101
- )
102
-
103
- entity = await handle_single_entity_extraction(
104
- record_attributes, chunk_id
105
- )
106
- if entity is not None:
107
- nodes[entity["entity_name"]].append(entity)
108
- continue
109
- relation = await handle_single_relationship_extraction(
110
- record_attributes, chunk_id
111
- )
112
- if relation is not None:
113
- edges[(relation["src_id"], relation["tgt_id"])].append(relation)
114
- return dict(nodes), dict(edges)
115
-
116
- results = []
117
- chunk_number = len(chunks)
118
- async for result in tqdm_async(
119
- asyncio.as_completed([_process_single_content(c) for c in chunks]),
120
- total=len(chunks),
121
- desc="[2/4]Extracting entities and relationships from chunks",
122
- unit="chunk",
123
- ):
124
- try:
125
- if progress_bar is not None:
126
- progress_bar(
127
- len(results) / chunk_number,
128
- desc="[3/4]Extracting entities and relationships from chunks",
129
- )
130
- results.append(await result)
131
- if progress_bar is not None and len(results) == chunk_number:
132
- progress_bar(
133
- 1, desc="[3/4]Extracting entities and relationships from chunks"
134
- )
135
- except Exception as e: # pylint: disable=broad-except
136
- logger.error(
137
- "Error occurred while extracting entities and relationships from chunks: %s",
138
- e,
139
- )
140
-
141
- nodes = defaultdict(list)
142
- edges = defaultdict(list)
143
- for n, e in results:
144
- for k, v in n.items():
145
- nodes[k].extend(v)
146
- for k, v in e.items():
147
- edges[tuple(sorted(k))].extend(v)
148
-
149
- await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
150
- await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
151
-
152
- return kg_instance
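
The extractor above bounds concurrency by hand: an asyncio.Semaphore per chunk, asyncio.as_completed for the fan-out, and a manually driven progress callback. The replacement under graphgen/operators/build_kg/ can express the same fan-out with the run_concurrent helper added later in this commit. A minimal sketch of that pattern, with the per-chunk extraction stubbed out as a placeholder:

    import asyncio
    from graphgen.utils import run_concurrent

    async def extract_all(chunks, max_concurrent: int = 1000):
        semaphore = asyncio.Semaphore(max_concurrent)

        async def _process(chunk):
            async with semaphore:
                # stand-in for the entity/relationship extraction of one chunk
                return chunk.id

        # one task per chunk, tqdm progress, failed tasks logged and skipped
        return await run_concurrent(_process, chunks, desc="Extracting", unit="chunk")
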
graphgen/operators/preprocess/resolute_coreference.py CHANGED
@@ -1,13 +1,13 @@
1
  from typing import List
2
 
3
  from graphgen.bases.datatypes import Chunk
4
- from graphgen.models import OpenAIModel
5
  from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
6
  from graphgen.utils import detect_main_language
7
 
8
 
9
  async def resolute_coreference(
10
- llm_client: OpenAIModel, chunks: List[Chunk]
11
  ) -> List[Chunk]:
12
  """
13
  Resolve coreference
 
1
  from typing import List
2
 
3
  from graphgen.bases.datatypes import Chunk
4
+ from graphgen.models import OpenAIClient
5
  from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
6
  from graphgen.utils import detect_main_language
7
 
8
 
9
  async def resolute_coreference(
10
+ llm_client: OpenAIClient, chunks: List[Chunk]
11
  ) -> List[Chunk]:
12
  """
13
  Resolve coreference
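
The only functional change here is the client type: resolute_coreference now takes the renamed OpenAIClient. A usage sketch, assuming hypothetical credentials and a Chunk built from the id/content fields this codebase uses elsewhere:

    import asyncio

    from graphgen.bases.datatypes import Chunk
    from graphgen.models import OpenAIClient
    from graphgen.operators.preprocess.resolute_coreference import resolute_coreference

    client = OpenAIClient(
        model_name="gpt-4o-mini",  # placeholder model and endpoint
        base_url="https://api.example.com/v1",
        api_key="sk-...",
    )
    chunks = [Chunk(id="chunk-0", content="Alice met Bob. She greeted him.")]
    resolved = asyncio.run(resolute_coreference(client, chunks))  # chunks with pronouns resolved
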
graphgen/operators/quiz.py CHANGED
@@ -2,17 +2,19 @@ import asyncio
2
  from collections import defaultdict
3
 
4
  from tqdm.asyncio import tqdm as tqdm_async
5
- from graphgen.models import JsonKVStorage, OpenAIModel, NetworkXStorage
6
- from graphgen.utils import logger, detect_main_language
7
  from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
 
8
 
9
 
10
  async def quiz(
11
- synth_llm_client: OpenAIModel,
12
- graph_storage: NetworkXStorage,
13
- rephrase_storage: JsonKVStorage,
14
- max_samples: int = 1,
15
- max_concurrent: int = 1000) -> JsonKVStorage:
 
16
  """
17
  Get all edges and quiz them
18
 
@@ -26,11 +28,7 @@ async def quiz(
26
 
27
  semaphore = asyncio.Semaphore(max_concurrent)
28
 
29
- async def _process_single_quiz(
30
- des: str,
31
- prompt: str,
32
- gt: str
33
- ):
34
  async with semaphore:
35
  try:
36
  # If the description already exists in rephrase_storage, return the stored result directly
@@ -39,16 +37,14 @@ async def quiz(
39
  return None
40
 
41
  new_description = await synth_llm_client.generate_answer(
42
- prompt,
43
- temperature=1
44
  )
45
- return {des: [(new_description, gt)]}
46
 
47
- except Exception as e: # pylint: disable=broad-except
48
  logger.error("Error when quizzing description %s: %s", des, e)
49
  return None
50
 
51
-
52
  edges = await graph_storage.get_all_edges()
53
  nodes = await graph_storage.get_all_nodes()
54
 
@@ -60,41 +56,59 @@ async def quiz(
60
  description = edge_data["description"]
61
  language = "English" if detect_main_language(description) == "en" else "Chinese"
62
 
63
- results[description] = [(description, 'yes')]
64
 
65
  for i in range(max_samples):
66
  if i > 0:
67
  tasks.append(
68
- _process_single_quiz(description,
69
- DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
70
- input_sentence=description), 'yes')
71
  )
72
- tasks.append(_process_single_quiz(description,
73
- DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
74
- input_sentence=description), 'no'))
75
 
76
  for node in nodes:
77
  node_data = node[1]
78
  description = node_data["description"]
79
  language = "English" if detect_main_language(description) == "en" else "Chinese"
80
 
81
- results[description] = [(description, 'yes')]
82
 
83
  for i in range(max_samples):
84
  if i > 0:
85
  tasks.append(
86
- _process_single_quiz(description,
87
- DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
88
- input_sentence=description), 'yes')
89
  )
90
- tasks.append(_process_single_quiz(description,
91
- DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
92
- input_sentence=description), 'no'))
93
 
94
  for result in tqdm_async(
95
- asyncio.as_completed(tasks),
96
- total=len(tasks),
97
- desc="Quizzing descriptions"
98
  ):
99
  new_result = await result
100
  if new_result:
@@ -105,5 +119,4 @@ async def quiz(
105
  results[key] = list(set(value))
106
  await rephrase_storage.upsert({key: results[key]})
107
 
108
-
109
  return rephrase_storage
 
2
  from collections import defaultdict
3
 
4
  from tqdm.asyncio import tqdm as tqdm_async
5
+
6
+ from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
7
  from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
8
+ from graphgen.utils import detect_main_language, logger
9
 
10
 
11
  async def quiz(
12
+ synth_llm_client: OpenAIClient,
13
+ graph_storage: NetworkXStorage,
14
+ rephrase_storage: JsonKVStorage,
15
+ max_samples: int = 1,
16
+ max_concurrent: int = 1000,
17
+ ) -> JsonKVStorage:
18
  """
19
  Get all edges and quiz them
20
 
 
28
 
29
  semaphore = asyncio.Semaphore(max_concurrent)
30
 
31
+ async def _process_single_quiz(des: str, prompt: str, gt: str):
32
  async with semaphore:
33
  try:
34
  # If the description already exists in rephrase_storage, return the stored result directly
 
37
  return None
38
 
39
  new_description = await synth_llm_client.generate_answer(
40
+ prompt, temperature=1
 
41
  )
42
+ return {des: [(new_description, gt)]}
43
 
44
+ except Exception as e: # pylint: disable=broad-except
45
  logger.error("Error when quizzing description %s: %s", des, e)
46
  return None
47
 
 
48
  edges = await graph_storage.get_all_edges()
49
  nodes = await graph_storage.get_all_nodes()
50
 
 
56
  description = edge_data["description"]
57
  language = "English" if detect_main_language(description) == "en" else "Chinese"
58
 
59
+ results[description] = [(description, "yes")]
60
 
61
  for i in range(max_samples):
62
  if i > 0:
63
  tasks.append(
64
+ _process_single_quiz(
65
+ description,
66
+ DESCRIPTION_REPHRASING_PROMPT[language]["TEMPLATE"].format(
67
+ input_sentence=description
68
+ ),
69
+ "yes",
70
+ )
71
  )
72
+ tasks.append(
73
+ _process_single_quiz(
74
+ description,
75
+ DESCRIPTION_REPHRASING_PROMPT[language]["ANTI_TEMPLATE"].format(
76
+ input_sentence=description
77
+ ),
78
+ "no",
79
+ )
80
+ )
81
 
82
  for node in nodes:
83
  node_data = node[1]
84
  description = node_data["description"]
85
  language = "English" if detect_main_language(description) == "en" else "Chinese"
86
 
87
+ results[description] = [(description, "yes")]
88
 
89
  for i in range(max_samples):
90
  if i > 0:
91
  tasks.append(
92
+ _process_single_quiz(
93
+ description,
94
+ DESCRIPTION_REPHRASING_PROMPT[language]["TEMPLATE"].format(
95
+ input_sentence=description
96
+ ),
97
+ "yes",
98
+ )
99
+ )
100
+ tasks.append(
101
+ _process_single_quiz(
102
+ description,
103
+ DESCRIPTION_REPHRASING_PROMPT[language]["ANTI_TEMPLATE"].format(
104
+ input_sentence=description
105
+ ),
106
+ "no",
107
  )
108
+ )
 
 
109
 
110
  for result in tqdm_async(
111
+ asyncio.as_completed(tasks), total=len(tasks), desc="Quizzing descriptions"
 
 
112
  ):
113
  new_result = await result
114
  if new_result:
 
119
  results[key] = list(set(value))
120
  await rephrase_storage.upsert({key: results[key]})
121
 
 
122
  return rephrase_storage
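
After this pass, rephrase_storage maps each node/edge description to a deduplicated list of (statement, ground_truth) pairs: the original description labelled "yes", TEMPLATE rephrasings labelled "yes", and ANTI_TEMPLATE negations labelled "no". An illustrative entry (values invented for the example):

    {
        "Paris is the capital of France.": [
            ("Paris is the capital of France.", "yes"),
            ("France's capital city is Paris.", "yes"),
            ("Paris is not the capital of France.", "no"),
        ]
    }
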
graphgen/operators/read/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .read_files import read_files
graphgen/operators/read/read_files.py ADDED
@@ -0,0 +1,19 @@
1
+ from graphgen.models import CsvReader, JsonlReader, JsonReader, TxtReader
2
+
3
+ _MAPPING = {
4
+ "jsonl": JsonlReader,
5
+ "json": JsonReader,
6
+ "txt": TxtReader,
7
+ "csv": CsvReader,
8
+ }
9
+
10
+
11
+ def read_files(file_path: str):
12
+ suffix = file_path.split(".")[-1]
13
+ if suffix in _MAPPING:
14
+ reader = _MAPPING[suffix]()
15
+ else:
16
+ raise ValueError(
17
+ f"Unsupported file format: {suffix}. Supported formats are: {list(_MAPPING.keys())}"
18
+ )
19
+ return reader.read(file_path)
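
read_files dispatches on the filename suffix alone, so callers get format routing for free and a loud failure for anything unsupported. A usage sketch with hypothetical paths:

    from graphgen.operators.read import read_files

    docs = read_files("data/corpus.jsonl")  # routed to JsonlReader
    read_files("data/corpus.parquet")       # raises ValueError: Unsupported file format: parquet
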
graphgen/operators/split/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .split_chunks import chunk_documents
graphgen/operators/split/split_chunks.py ADDED
@@ -0,0 +1,76 @@
1
+ from functools import lru_cache
2
+ from typing import Union
3
+
4
+ from tqdm.asyncio import tqdm as tqdm_async
5
+
6
+ from graphgen.models import (
7
+ ChineseRecursiveTextSplitter,
8
+ RecursiveCharacterSplitter,
9
+ Tokenizer,
10
+ )
11
+ from graphgen.utils import compute_content_hash, detect_main_language
12
+
13
+ _MAPPING = {
14
+ "en": RecursiveCharacterSplitter,
15
+ "zh": ChineseRecursiveTextSplitter,
16
+ }
17
+
18
+ SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
19
+
20
+
21
+ @lru_cache(maxsize=None)
22
+ def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
23
+ cls = _MAPPING[language]
24
+ kwargs = dict(frozen_kwargs)
25
+ return cls(**kwargs)
26
+
27
+
28
+ def split_chunks(text: str, language: str = "en", **kwargs) -> list:
29
+ if language not in _MAPPING:
30
+ raise ValueError(
31
+ f"Unsupported language: {language}. "
32
+ f"Supported languages are: {list(_MAPPING.keys())}"
33
+ )
34
+ splitter = _get_splitter(language, frozenset(kwargs.items()))
35
+ return splitter.split_text(text)
36
+
37
+
38
+ async def chunk_documents(
39
+ new_docs: dict,
40
+ chunk_size: int = 1024,
41
+ chunk_overlap: int = 100,
42
+ tokenizer_instance: Tokenizer = None,
43
+ progress_bar=None,
44
+ ) -> dict:
45
+ inserting_chunks = {}
46
+ cur_index = 1
47
+ doc_number = len(new_docs)
48
+ async for doc_key, doc in tqdm_async(
49
+ new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
50
+ ):
51
+ doc_language = detect_main_language(doc["content"])
52
+ text_chunks = split_chunks(
53
+ doc["content"],
54
+ language=doc_language,
55
+ chunk_size=chunk_size,
56
+ chunk_overlap=chunk_overlap,
57
+ )
58
+
59
+ chunks = {
60
+ compute_content_hash(txt, prefix="chunk-"): {
61
+ "content": txt,
62
+ "full_doc_id": doc_key,
63
+ "length": len(tokenizer_instance.encode(txt))
64
+ if tokenizer_instance
65
+ else len(txt),
66
+ "language": doc_language,
67
+ }
68
+ for txt in text_chunks
69
+ }
70
+ inserting_chunks.update(chunks)
71
+
72
+ if progress_bar is not None:
73
+ progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
74
+ cur_index += 1
75
+
76
+ return inserting_chunks
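
_get_splitter caches one splitter instance per (language, kwargs) pair; since a dict is unhashable, the kwargs are frozen into a frozenset of items so lru_cache can key on them, and argument order stops mattering. A quick check of that behaviour:

    kwargs_a = frozenset({"chunk_size": 1024, "chunk_overlap": 100}.items())
    kwargs_b = frozenset({"chunk_overlap": 100, "chunk_size": 1024}.items())
    assert _get_splitter("en", kwargs_a) is _get_splitter("en", kwargs_b)  # same cached instance
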
graphgen/operators/traverse_graph.py CHANGED
@@ -6,11 +6,11 @@ from tqdm.asyncio import tqdm as tqdm_async
6
  from graphgen.models import (
7
  JsonKVStorage,
8
  NetworkXStorage,
9
- OpenAIModel,
10
  Tokenizer,
11
  TraverseStrategy,
12
  )
13
- from graphgen.operators.kg.split_kg import get_batches_with_strategy
14
  from graphgen.templates import (
15
  ANSWER_REPHRASING_PROMPT,
16
  MULTI_HOP_GENERATION_PROMPT,
@@ -30,7 +30,7 @@ async def _pre_tokenize(
30
  if "length" not in edge[2]:
31
  edge[2]["length"] = len(
32
  await asyncio.get_event_loop().run_in_executor(
33
- None, tokenizer.encode_string, edge[2]["description"]
34
  )
35
  )
36
  return edge
@@ -40,7 +40,7 @@ async def _pre_tokenize(
40
  if "length" not in node[1]:
41
  node[1]["length"] = len(
42
  await asyncio.get_event_loop().run_in_executor(
43
- None, tokenizer.encode_string, node[1]["description"]
44
  )
45
  )
46
  return node
@@ -161,7 +161,7 @@ def _post_process_synthetic_data(data):
161
 
162
 
163
  async def traverse_graph_for_aggregated(
164
- llm_client: OpenAIModel,
165
  tokenizer: Tokenizer,
166
  graph_storage: NetworkXStorage,
167
  traverse_strategy: TraverseStrategy,
@@ -310,7 +310,7 @@ async def traverse_graph_for_aggregated(
310
 
311
  # pylint: disable=too-many-branches, too-many-statements
312
  async def traverse_graph_for_atomic(
313
- llm_client: OpenAIModel,
314
  tokenizer: Tokenizer,
315
  graph_storage: NetworkXStorage,
316
  traverse_strategy: TraverseStrategy,
@@ -426,7 +426,7 @@ async def traverse_graph_for_atomic(
426
 
427
 
428
  async def traverse_graph_for_multi_hop(
429
- llm_client: OpenAIModel,
430
  tokenizer: Tokenizer,
431
  graph_storage: NetworkXStorage,
432
  traverse_strategy: TraverseStrategy,
 
6
  from graphgen.models import (
7
  JsonKVStorage,
8
  NetworkXStorage,
9
+ OpenAIClient,
10
  Tokenizer,
11
  TraverseStrategy,
12
  )
13
+ from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
14
  from graphgen.templates import (
15
  ANSWER_REPHRASING_PROMPT,
16
  MULTI_HOP_GENERATION_PROMPT,
 
30
  if "length" not in edge[2]:
31
  edge[2]["length"] = len(
32
  await asyncio.get_event_loop().run_in_executor(
33
+ None, tokenizer.encode, edge[2]["description"]
34
  )
35
  )
36
  return edge
 
40
  if "length" not in node[1]:
41
  node[1]["length"] = len(
42
  await asyncio.get_event_loop().run_in_executor(
43
+ None, tokenizer.encode, node[1]["description"]
44
  )
45
  )
46
  return node
 
161
 
162
 
163
  async def traverse_graph_for_aggregated(
164
+ llm_client: OpenAIClient,
165
  tokenizer: Tokenizer,
166
  graph_storage: NetworkXStorage,
167
  traverse_strategy: TraverseStrategy,
 
310
 
311
  # pylint: disable=too-many-branches, too-many-statements
312
  async def traverse_graph_for_atomic(
313
+ llm_client: OpenAIClient,
314
  tokenizer: Tokenizer,
315
  graph_storage: NetworkXStorage,
316
  traverse_strategy: TraverseStrategy,
 
426
 
427
 
428
  async def traverse_graph_for_multi_hop(
429
+ llm_client: OpenAIClient,
430
  tokenizer: Tokenizer,
431
  graph_storage: NetworkXStorage,
432
  traverse_strategy: TraverseStrategy,
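
Besides the OpenAIClient rename, this file tracks the tokenizer API change from encode_string to encode; the call is still pushed onto the default executor so CPU-bound tokenization does not block the event loop. The pattern in isolation:

    import asyncio

    async def token_length(tokenizer, text: str) -> int:
        # run the synchronous tokenizer off the event-loop thread
        tokens = await asyncio.get_event_loop().run_in_executor(None, tokenizer.encode, text)
        return len(tokens)
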
graphgen/utils/__init__.py CHANGED
@@ -13,3 +13,5 @@ from .hash import compute_args_hash, compute_content_hash
13
  from .help_nltk import NLTKHelper
14
  from .log import logger, parse_log, set_logger
15
  from .loop import create_event_loop
 
 
 
13
  from .help_nltk import NLTKHelper
14
  from .log import logger, parse_log, set_logger
15
  from .loop import create_event_loop
16
+ from .run_concurrent import run_concurrent
17
+ from .wrap import async_to_sync_method
graphgen/utils/calculate_confidence.py CHANGED
@@ -1,34 +1,41 @@
1
  import math
2
  from typing import List
3
- from graphgen.models.llm.topk_token_model import Token
 
 
4
 
5
  def preprocess_tokens(tokens: List[Token]) -> List[Token]:
6
  """Preprocess tokens for calculating confidence."""
7
  tokens = [x for x in tokens if x.prob > 0]
8
  return tokens
9
 
 
10
  def joint_probability(tokens: List[Token]) -> float:
11
  """Calculate joint probability of a list of tokens."""
12
  tokens = preprocess_tokens(tokens)
13
  logprob_sum = sum(x.logprob for x in tokens)
14
  return math.exp(logprob_sum / len(tokens))
15
 
 
16
  def min_prob(tokens: List[Token]) -> float:
17
  """Calculate the minimum probability of a list of tokens."""
18
  tokens = preprocess_tokens(tokens)
19
  return min(x.prob for x in tokens)
20
 
 
21
  def average_prob(tokens: List[Token]) -> float:
22
  """Calculate the average probability of a list of tokens."""
23
  tokens = preprocess_tokens(tokens)
24
  return sum(x.prob for x in tokens) / len(tokens)
25
 
 
26
  def average_confidence(tokens: List[Token]) -> float:
27
  """Calculate the average confidence of a list of tokens."""
28
  tokens = preprocess_tokens(tokens)
29
  confidence = [x.prob / sum(y.prob for y in x.top_candidates[:5]) for x in tokens]
30
  return sum(confidence) / len(tokens)
31
 
 
32
  def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
33
  """Calculate the loss for yes/no question."""
34
  losses = []
@@ -41,7 +48,10 @@ def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> floa
41
  losses.append(token.prob)
42
  return sum(losses) / len(losses)
43
 
44
- def yes_no_loss_entropy(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
 
 
 
45
  """Calculate the loss for yes/no question using entropy."""
46
  losses = []
47
  for i, tokens in enumerate(tokens_list):
 
1
  import math
2
  from typing import List
3
+
4
+ from graphgen.bases.datatypes import Token
5
+
6
 
7
  def preprocess_tokens(tokens: List[Token]) -> List[Token]:
8
  """Preprocess tokens for calculating confidence."""
9
  tokens = [x for x in tokens if x.prob > 0]
10
  return tokens
11
 
12
+
13
  def joint_probability(tokens: List[Token]) -> float:
14
  """Calculate joint probability of a list of tokens."""
15
  tokens = preprocess_tokens(tokens)
16
  logprob_sum = sum(x.logprob for x in tokens)
17
  return math.exp(logprob_sum / len(tokens))
18
 
19
+
20
  def min_prob(tokens: List[Token]) -> float:
21
  """Calculate the minimum probability of a list of tokens."""
22
  tokens = preprocess_tokens(tokens)
23
  return min(x.prob for x in tokens)
24
 
25
+
26
  def average_prob(tokens: List[Token]) -> float:
27
  """Calculate the average probability of a list of tokens."""
28
  tokens = preprocess_tokens(tokens)
29
  return sum(x.prob for x in tokens) / len(tokens)
30
 
31
+
32
  def average_confidence(tokens: List[Token]) -> float:
33
  """Calculate the average confidence of a list of tokens."""
34
  tokens = preprocess_tokens(tokens)
35
  confidence = [x.prob / sum(y.prob for y in x.top_candidates[:5]) for x in tokens]
36
  return sum(confidence) / len(tokens)
37
 
38
+
39
  def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
40
  """Calculate the loss for yes/no question."""
41
  losses = []
 
48
  losses.append(token.prob)
49
  return sum(losses) / len(losses)
50
 
51
+
52
+ def yes_no_loss_entropy(
53
+ tokens_list: List[List[Token]], ground_truth: List[str]
54
+ ) -> float:
55
  """Calculate the loss for yes/no question using entropy."""
56
  losses = []
57
  for i, tokens in enumerate(tokens_list):
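
joint_probability averages the logprobs and exponentiates, which makes it the geometric mean of the token probabilities. A quick numeric check with assumed probabilities of 0.9 and 0.8:

    import math

    probs = [0.9, 0.8]
    geo_mean = math.exp(sum(math.log(p) for p in probs) / len(probs))
    assert abs(geo_mean - math.sqrt(0.9 * 0.8)) < 1e-12  # both ≈ 0.8485
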
graphgen/utils/run_concurrent.py ADDED
@@ -0,0 +1,38 @@
1
+ import asyncio
2
+ from typing import Awaitable, Callable, List, Optional, TypeVar
3
+
4
+ import gradio as gr
5
+ from tqdm.asyncio import tqdm as tqdm_async
6
+
7
+ from graphgen.utils.log import logger
8
+
9
+ T = TypeVar("T")
10
+ R = TypeVar("R")
11
+
12
+
13
+ async def run_concurrent(
14
+ coro_fn: Callable[[T], Awaitable[R]],
15
+ items: List[T],
16
+ *,
17
+ desc: str = "processing",
18
+ unit: str = "item",
19
+ progress_bar: Optional[gr.Progress] = None,
20
+ ) -> List[R]:
21
+ tasks = [asyncio.create_task(coro_fn(it)) for it in items]
22
+
23
+ results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)
24
+
25
+ ok_results = []
26
+ for idx, res in enumerate(results):
27
+ if isinstance(res, Exception):
28
+ logger.exception("Task failed: %s", res)
29
+ if progress_bar:
30
+ progress_bar((idx + 1) / len(items), desc=desc)
31
+ continue
32
+ ok_results.append(res)
33
+ if progress_bar:
34
+ progress_bar((idx + 1) / len(items), desc=desc)
35
+
36
+ if progress_bar:
37
+ progress_bar(1.0, desc=desc)
38
+ return ok_results
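
One caveat: tqdm_async.gather behaves like asyncio.gather without return_exceptions=True, so a failing task raises at the await instead of landing in results, which would leave the isinstance(res, Exception) branch unreachable. If per-item fault tolerance is the intent, one option is to catch inside a wrapper so exceptions come back as values; a sketch against the signature above:

    async def _safe(it):
        try:
            return await coro_fn(it)
        except Exception as e:  # pylint: disable=broad-except
            return e  # returned rather than raised, so the isinstance check can skip it

    tasks = [asyncio.create_task(_safe(it)) for it in items]
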
graphgen/utils/wrap.py ADDED
@@ -0,0 +1,13 @@
1
+ from functools import wraps
2
+ from typing import Any, Callable
3
+
4
+ from .loop import create_event_loop
5
+
6
+
7
+ def async_to_sync_method(func: Callable) -> Callable:
8
+ @wraps(func)
9
+ def wrapper(self, *args, **kwargs) -> Any:
10
+ loop = create_event_loop()
11
+ return loop.run_until_complete(func(self, *args, **kwargs))
12
+
13
+ return wrapper
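
async_to_sync_method turns an async method into a blocking one, with create_event_loop supplying the loop, so synchronous call sites need no asyncio plumbing. A usage sketch with a hypothetical class:

    class KVStore:
        async def _aget(self, key: str) -> str:
            return key.upper()

        # blocking facade over the async implementation
        get = async_to_sync_method(_aget)

    print(KVStore().get("node-1"))  # NODE-1
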
webui/app.py CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
9
  from dotenv import load_dotenv
10
 
11
  from graphgen.graphgen import GraphGen
12
- from graphgen.models import OpenAIModel, Tokenizer
13
  from graphgen.models.llm.limitter import RPM, TPM
14
  from graphgen.utils import set_logger
15
  from webui.base import WebuiParams
@@ -41,7 +41,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
41
 
42
  graph_gen = GraphGen(working_dir=working_dir, config=config)
43
  # Set up LLM clients
44
- graph_gen.synthesizer_llm_client = OpenAIModel(
45
  model_name=env.get("SYNTHESIZER_MODEL", ""),
46
  base_url=env.get("SYNTHESIZER_BASE_URL", ""),
47
  api_key=env.get("SYNTHESIZER_API_KEY", ""),
@@ -50,7 +50,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
50
  tpm=TPM(env.get("TPM", 50000)),
51
  )
52
 
53
- graph_gen.trainee_llm_client = OpenAIModel(
54
  model_name=env.get("TRAINEE_MODEL", ""),
55
  base_url=env.get("TRAINEE_BASE_URL", ""),
56
  api_key=env.get("TRAINEE_API_KEY", ""),
 
9
  from dotenv import load_dotenv
10
 
11
  from graphgen.graphgen import GraphGen
12
+ from graphgen.models import OpenAIClient, Tokenizer
13
  from graphgen.models.llm.limitter import RPM, TPM
14
  from graphgen.utils import set_logger
15
  from webui.base import WebuiParams
 
41
 
42
  graph_gen = GraphGen(working_dir=working_dir, config=config)
43
  # Set up LLM clients
44
+ graph_gen.synthesizer_llm_client = OpenAIClient(
45
  model_name=env.get("SYNTHESIZER_MODEL", ""),
46
  base_url=env.get("SYNTHESIZER_BASE_URL", ""),
47
  api_key=env.get("SYNTHESIZER_API_KEY", ""),
 
50
  tpm=TPM(env.get("TPM", 50000)),
51
  )
52
 
53
+ graph_gen.trainee_llm_client = OpenAIClient(
54
  model_name=env.get("TRAINEE_MODEL", ""),
55
  base_url=env.get("TRAINEE_BASE_URL", ""),
56
  api_key=env.get("TRAINEE_API_KEY", ""),
webui/utils/count_tokens.py CHANGED
@@ -45,7 +45,7 @@ def count_tokens(file, tokenizer_name, data_frame):
45
  content = item.get("content", "")
46
  else:
47
  content = item
48
- token_count += len(tokenizer.encode_string(content))
49
 
50
  _update_data = [[str(token_count), str(token_count * 50), "N/A"]]
51
 
 
45
  content = item.get("content", "")
46
  else:
47
  content = item
48
+ token_count += len(tokenizer.encode(content))
49
 
50
  _update_data = [[str(token_count), str(token_count * 50), "N/A"]]
51