github-actions[bot] committed
Commit 3a3b216 · 1 Parent(s): bda6eda
Auto-sync from demo at Tue Sep 30 03:30:14 UTC 2025
- app.py +3 -3
- graphgen/bases/__init__.py +12 -0
- graphgen/bases/base_kg_builder.py +41 -0
- graphgen/bases/base_llm_client.py +74 -0
- graphgen/bases/base_tokenizer.py +44 -0
- graphgen/bases/datatypes.py +14 -0
- graphgen/graphgen.py +53 -98
- graphgen/models/__init__.py +5 -5
- graphgen/models/evaluate/length_evaluator.py +2 -2
- graphgen/models/kg_builder/NetworkXKGBuilder.py +18 -0
- graphgen/{operators/kg → models/kg_builder}/__init__.py +0 -0
- graphgen/models/llm/ollama_client.py +21 -0
- graphgen/models/llm/{openai_model.py → openai_client.py} +43 -40
- graphgen/models/llm/tokenizer.py +0 -73
- graphgen/models/llm/topk_token_model.py +9 -16
- graphgen/models/reader/__init__.py +0 -18
- graphgen/models/splitter/__init__.py +0 -27
- graphgen/models/tokenizer/__init__.py +51 -0
- graphgen/models/tokenizer/hf_tokenizer.py +18 -0
- graphgen/models/tokenizer/tiktoken_tokenizer.py +18 -0
- graphgen/operators/__init__.py +3 -12
- graphgen/operators/build_kg/__init__.py +0 -0
- graphgen/operators/build_kg/extract_kg.py +127 -0
- graphgen/operators/{kg → build_kg}/merge_kg.py +7 -7
- graphgen/operators/{kg → build_kg}/split_kg.py +0 -0
- graphgen/operators/generate/generate_cot.py +2 -2
- graphgen/operators/judge.py +2 -2
- graphgen/operators/kg/extract_kg.py +0 -152
- graphgen/operators/preprocess/resolute_coreference.py +2 -2
- graphgen/operators/quiz.py +48 -35
- graphgen/operators/read/__init__.py +1 -0
- graphgen/operators/read/read_files.py +19 -0
- graphgen/operators/split/__init__.py +1 -0
- graphgen/operators/split/split_chunks.py +76 -0
- graphgen/operators/traverse_graph.py +7 -7
- graphgen/utils/__init__.py +2 -0
- graphgen/utils/calculate_confidence.py +12 -2
- graphgen/utils/run_concurrent.py +38 -0
- graphgen/utils/wrap.py +13 -0
- webui/app.py +3 -3
- webui/utils/count_tokens.py +1 -1
app.py
CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
 from dotenv import load_dotenv

 from graphgen.graphgen import GraphGen
-from graphgen.models import
+from graphgen.models import OpenAIClient, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
 from graphgen.utils import set_logger
 from webui.base import WebuiParams
@@ -41,7 +41,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:

     graph_gen = GraphGen(working_dir=working_dir, config=config)
     # Set up LLM clients
-    graph_gen.synthesizer_llm_client =
+    graph_gen.synthesizer_llm_client = OpenAIClient(
         model_name=env.get("SYNTHESIZER_MODEL", ""),
         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
         api_key=env.get("SYNTHESIZER_API_KEY", ""),
@@ -50,7 +50,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
         tpm=TPM(env.get("TPM", 50000)),
     )

-    graph_gen.trainee_llm_client =
+    graph_gen.trainee_llm_client = OpenAIClient(
         model_name=env.get("TRAINEE_MODEL", ""),
         base_url=env.get("TRAINEE_BASE_URL", ""),
         api_key=env.get("TRAINEE_API_KEY", ""),
graphgen/bases/__init__.py
CHANGED
@@ -0,0 +1,12 @@
+from .base_kg_builder import BaseKGBuilder
+from .base_llm_client import BaseLLMClient
+from .base_reader import BaseReader
+from .base_splitter import BaseSplitter
+from .base_storage import (
+    BaseGraphStorage,
+    BaseKVStorage,
+    BaseListStorage,
+    StorageNameSpace,
+)
+from .base_tokenizer import BaseTokenizer
+from .datatypes import Chunk, QAPair, Token
graphgen/bases/base_kg_builder.py
ADDED
@@ -0,0 +1,41 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple
+
+from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+
+
+@dataclass
+class BaseKGBuilder(ABC):
+    kg_instance: BaseGraphStorage
+    llm_client: BaseLLMClient
+
+    _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
+    _edges: Dict[Tuple[str, str], List[dict]] = field(
+        default_factory=lambda: defaultdict(list)
+    )
+
+    def build(self, chunks: List[Chunk]) -> None:
+        pass
+
+    @abstractmethod
+    async def extract_all(self, chunks: List[Chunk]) -> None:
+        """Extract nodes and edges from all chunks."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def extract(
+        self, chunk: Chunk
+    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
+        """Extract nodes and edges from a single chunk."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def merge_nodes(
+        self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
+    ) -> None:
+        """Merge extracted nodes into the knowledge graph."""
+        raise NotImplementedError
graphgen/bases/base_llm_client.py
ADDED
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import abc
+import re
+from typing import Any, List, Optional
+
+from graphgen.bases.base_tokenizer import BaseTokenizer
+from graphgen.bases.datatypes import Token
+
+
+class BaseLLMClient(abc.ABC):
+    """
+    LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
+    """
+
+    def __init__(
+        self,
+        *,
+        system_prompt: str = "",
+        temperature: float = 0.0,
+        max_tokens: int = 4096,
+        repetition_penalty: float = 1.05,
+        top_p: float = 0.95,
+        top_k: int = 50,
+        tokenizer: Optional[BaseTokenizer] = None,
+        **kwargs: Any,
+    ):
+        self.system_prompt = system_prompt
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.repetition_penalty = repetition_penalty
+        self.top_p = top_p
+        self.top_k = top_k
+        self.tokenizer = tokenizer
+
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    @abc.abstractmethod
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        """Generate answer from the model."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate top-k tokens for the next token prediction."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate probabilities for each token in the input."""
+        raise NotImplementedError
+
+    def count_tokens(self, text: str) -> int:
+        """Count the number of tokens in the text."""
+        if self.tokenizer is None:
+            raise ValueError("Tokenizer is not set. Please provide a tokenizer to use count_tokens.")
+        return len(self.tokenizer.encode(text))
+
+    @staticmethod
+    def filter_think_tags(text: str, think_tag: str = "think") -> str:
+        """
+        Remove <think> tags from the text.
+        If the text contains <think> and </think>, it removes everything between them and the tags themselves.
+        """
+        think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
+        filtered_text = think_pattern.sub("", text).strip()
+        return filtered_text if filtered_text else text.strip()
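As a quick orientation for the new abstraction, a minimal sketch of a backend built on BaseLLMClient might look like the following. EchoClient and its behaviour are purely illustrative and not part of this commit; only the three abstract methods and the inherited helpers come from the diff above.

# Hypothetical example only: a trivial backend that satisfies the BaseLLMClient interface.
from typing import Any, List, Optional

from graphgen.bases.base_llm_client import BaseLLMClient
from graphgen.bases.datatypes import Token


class EchoClient(BaseLLMClient):
    """Toy backend that echoes the prompt; useful for wiring tests."""

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        # filter_think_tags is inherited and strips <think>...</think> spans
        return self.filter_think_tags(f"<think>planning</think>{text}")

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return [Token(text=t, prob=1.0) for t in text.split()]

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return [Token(text=t, prob=1.0) for t in text.split()]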
graphgen/bases/base_tokenizer.py
ADDED
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class BaseTokenizer(ABC):
+    model_name: str = "cl100k_base"
+
+    @abstractmethod
+    def encode(self, text: str) -> List[int]:
+        """Encode text -> token ids."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def decode(self, token_ids: List[int]) -> str:
+        """Decode token ids -> text."""
+        raise NotImplementedError
+
+    def count_tokens(self, text: str) -> int:
+        return len(self.encode(text))
+
+    def chunk_by_token_size(
+        self,
+        content: str,
+        *,
+        overlap_token_size: int = 128,
+        max_token_size: int = 1024,
+    ) -> List[dict]:
+        tokens = self.encode(content)
+        results = []
+        step = max_token_size - overlap_token_size
+        for index, start in enumerate(range(0, len(tokens), step)):
+            chunk_ids = tokens[start : start + max_token_size]
+            results.append(
+                {
+                    "tokens": len(chunk_ids),
+                    "content": self.decode(chunk_ids).strip(),
+                    "chunk_order_index": index,
+                }
+            )
+        return results
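chunk_by_token_size slides a window of max_token_size tokens with overlap_token_size tokens of overlap, so consecutive chunks advance by max_token_size - overlap_token_size tokens. A short usage sketch with the tiktoken-backed tokenizer added later in this commit (the sample text is arbitrary):

# Sketch only: chunking a document with TiktokenTokenizer from this commit.
from graphgen.models.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

tok = TiktokenTokenizer(model_name="cl100k_base")
chunks = tok.chunk_by_token_size(
    "some long document text ...",
    overlap_token_size=128,
    max_token_size=1024,
)
for c in chunks:
    # each chunk is a dict with "tokens", "content" and "chunk_order_index"
    print(c["chunk_order_index"], c["tokens"])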
graphgen/bases/datatypes.py
CHANGED
@@ -1,4 +1,6 @@
+import math
 from dataclasses import dataclass, field
+from typing import List, Union


 @dataclass
@@ -16,3 +18,15 @@ class QAPair:

     question: str
     answer: str
+
+
+@dataclass
+class Token:
+    text: str
+    prob: float
+    top_candidates: List = field(default_factory=list)
+    ppl: Union[float, None] = field(default=None)
+
+    @property
+    def logprob(self) -> float:
+        return math.log(self.prob)
graphgen/graphgen.py
CHANGED
@@ -2,10 +2,9 @@ import asyncio
 import os
 import time
 from dataclasses import dataclass, field
-from typing import Dict,
+from typing import Dict, cast

 import gradio as gr
-from tqdm.asyncio import tqdm as tqdm_async

 from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
@@ -13,27 +12,25 @@ from graphgen.models (
     JsonKVStorage,
     JsonListStorage,
     NetworkXStorage,
+    OpenAIClient,
     Tokenizer,
     TraverseStrategy,
-    read_file,
-    split_chunks,
 )
+from graphgen.operators import (
+    chunk_documents,
     extract_kg,
     generate_cot,
     judge_statement,
     quiz,
+    read_files,
     search_all,
     traverse_graph_for_aggregated,
     traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
-from .utils import (
+from graphgen.utils import (
+    async_to_sync_method,
     compute_content_hash,
-    create_event_loop,
-    detect_main_language,
     format_generation_results,
     logger,
 )
@@ -49,8 +46,8 @@ class GraphGen:

     # llm
     tokenizer_instance: Tokenizer = None
-    synthesizer_llm_client:
-    trainee_llm_client:
+    synthesizer_llm_client: OpenAIClient = None
+    trainee_llm_client: OpenAIClient = None

     # search
     search_config: dict = field(
@@ -67,17 +64,17 @@ class GraphGen:
         self.tokenizer_instance: Tokenizer = Tokenizer(
            model_name=self.config["tokenizer"]
         )
-        self.synthesizer_llm_client:
+        self.synthesizer_llm_client: OpenAIClient = OpenAIClient(
             model_name=os.getenv("SYNTHESIZER_MODEL"),
             api_key=os.getenv("SYNTHESIZER_API_KEY"),
             base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+            tokenizer=self.tokenizer_instance,
         )
-        self.trainee_llm_client:
+        self.trainee_llm_client: OpenAIClient = OpenAIClient(
             model_name=os.getenv("TRAINEE_MODEL"),
             api_key=os.getenv("TRAINEE_API_KEY"),
             base_url=os.getenv("TRAINEE_BASE_URL"),
+            tokenizer=self.tokenizer_instance,
         )
         self.search_config = self.config["search"]
@@ -111,15 +108,23 @@ class GraphGen:
             namespace="qa",
         )

+    @async_to_sync_method
+    async def insert(self):
+        """
+        insert chunks into the graph
+        """
+        input_file = self.config["read"]["input_file"]
+
+        # Step 1: Read files
+        data = read_files(input_file)
         if len(data) == 0:
+            logger.warning("No data to process")
+            return

+        # TODO: configurable whether to use coreference resolution

-        assert isinstance(data, list) and isinstance(data[0], dict)
-
-        #
+        # Step 2: Split chunks and filter existing ones
+        assert isinstance(data, list) and isinstance(data[0], dict)
         new_docs = {
             compute_content_hash(doc["content"], prefix="doc-"): {
                 "content": doc["content"]
@@ -128,38 +133,19 @@ class GraphGen:
         }
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+
         if len(new_docs) == 0:
             logger.warning("All docs are already in the storage")
             return
         logger.info("[New Docs] inserting %d docs", len(new_docs))

-                    doc["content"],
-                    language=doc_language,
-                    chunk_size=self.config["split"]["chunk_size"],
-                    chunk_overlap=self.config["split"]["chunk_overlap"],
-                )
-
-                chunks = {
-                    compute_content_hash(txt, prefix="chunk-"): {
-                        "content": txt,
-                        "full_doc_id": doc_key,
-                        "length": len(self.tokenizer_instance.encode_string(txt)),
-                        "language": doc_language,
-                    }
-                    for txt in text_chunks
-                }
-                inserting_chunks.update(chunks)
-
-                if self.progress_bar is not None:
-                    self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
-                cur_index += 1
+        inserting_chunks = await chunk_documents(
+            new_docs,
+            self.config["split"]["chunk_size"],
+            self.config["split"]["chunk_overlap"],
+            self.tokenizer_instance,
+            self.progress_bar,
+        )

         _add_chunk_keys = await self.text_chunks_storage.filter_keys(
             list(inserting_chunks.keys())
@@ -167,29 +153,16 @@ class GraphGen:
         inserting_chunks = {
             k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
         }
-        await self.full_docs_storage.upsert(new_docs)
-        await self.text_chunks_storage.upsert(inserting_chunks)
-
-        return inserting_chunks
-
-    def insert(self):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_insert())
-
-    async def async_insert(self):
-        """
-        insert chunks into the graph
-        """
-
-        input_file = self.config["read"]["input_file"]
-        data = read_file(input_file)
-        inserting_chunks = await self.async_split_chunks(data)

         if len(inserting_chunks) == 0:
             logger.warning("All chunks are already in the storage")
             return
+
         logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        await self.full_docs_storage.upsert(new_docs)
+        await self.text_chunks_storage.upsert(inserting_chunks)

+        # Step 3: Extract entities and relations from chunks
         logger.info("[Entity and Relation Extraction]...")
         _add_entities_and_relations = await extract_kg(
             llm_client=self.synthesizer_llm_client,
@@ -219,11 +192,8 @@ class GraphGen:
         tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
         await asyncio.gather(*tasks)

-        loop.run_until_complete(self.async_search())
-
-    async def async_search(self):
+    @async_to_sync_method
+    async def search(self):
         logger.info(
             "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
         )
@@ -257,13 +227,10 @@ class GraphGen:
             ]
         )
         # TODO: fix insert after search
-        await self.
-
-    def quiz(self):
-        loop = create_event_loop()
-        loop.run_until_complete(self.async_quiz())
+        await self.insert()

+    @async_to_sync_method
+    async def quiz(self):
         max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
@@ -273,11 +240,8 @@ class GraphGen:
         )
         await self.rephrase_storage.index_done_callback()

-        loop.run_until_complete(self.async_judge())
-
-    async def async_judge(self):
+    @async_to_sync_method
+    async def judge(self):
         re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
         _update_relations = await judge_statement(
             self.trainee_llm_client,
@@ -287,11 +251,8 @@ class GraphGen:
         )
         await _update_relations.index_done_callback()

-        loop.run_until_complete(self.async_traverse())
-
-    async def async_traverse(self):
+    @async_to_sync_method
+    async def traverse(self):
         output_data_type = self.config["output_data_type"]

         if output_data_type == "atomic":
@@ -331,11 +292,8 @@ class GraphGen:
         await self.qa_storage.upsert(results)
         await self.qa_storage.index_done_callback()

-        loop.run_until_complete(self.async_generate_reasoning(method_params))
-
-    async def async_generate_reasoning(self, method_params):
+    @async_to_sync_method
+    async def generate_reasoning(self, method_params):
         results = await generate_cot(
             self.graph_storage,
             self.synthesizer_llm_client,
@@ -349,11 +307,8 @@ class GraphGen:
         await self.qa_storage.upsert(results)
         await self.qa_storage.index_done_callback()

-        loop.run_until_complete(self.async_clear())
-
-    async def async_clear(self):
+    @async_to_sync_method
+    async def clear(self):
         await self.full_docs_storage.drop()
         await self.text_chunks_storage.drop()
         await self.search_storage.drop()
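The public methods above are now plain async functions wrapped by @async_to_sync_method from graphgen.utils; the decorator itself lives in graphgen/utils/wrap.py (+13), which is not displayed on this page. A plausible sketch of such a wrapper, offered only as an assumption about what it might look like:

# Assumption: the real implementation in graphgen/utils/wrap.py is not shown in this view.
import asyncio
from functools import wraps


def async_to_sync_method(func):
    """Run an async method to completion on the current (or a fresh) event loop."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(func(self, *args, **kwargs))

    return wrapper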
graphgen/models/__init__.py
CHANGED
@@ -3,15 +3,15 @@ from .evaluate.length_evaluator import LengthEvaluator
 from .evaluate.mtld_evaluator import MTLDEvaluator
 from .evaluate.reward_evaluator import RewardEvaluator
 from .evaluate.uni_evaluator import UniEvaluator
-from .llm.
-from .llm.
-from .
-from .reader import read_file
+from .llm.openai_client import OpenAIClient
+from .llm.topk_token_model import TopkTokenModel
+from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
 from .search.db.uniprot_search import UniProtSearch
 from .search.kg.wiki_search import WikiSearch
 from .search.web.bing_search import BingSearch
 from .search.web.google_search import GoogleSearch
-from .splitter import
+from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
 from .storage.json_storage import JsonKVStorage, JsonListStorage
 from .storage.networkx_storage import NetworkXStorage
 from .strategy.travserse_strategy import TraverseStrategy
+from .tokenizer import Tokenizer
graphgen/models/evaluate/length_evaluator.py
CHANGED
@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from graphgen.bases.datatypes import QAPair
 from graphgen.models.evaluate.base_evaluator import BaseEvaluator
-from graphgen.models.
+from graphgen.models.tokenizer import Tokenizer
 from graphgen.utils import create_event_loop


@@ -18,5 +18,5 @@ class LengthEvaluator(BaseEvaluator):
         return await loop.run_in_executor(None, self._calculate_length, pair.answer)

     def _calculate_length(self, text: str) -> float:
-        tokens = self.tokenizer.
+        tokens = self.tokenizer.encode(text)
         return len(tokens)
graphgen/models/kg_builder/NetworkXKGBuilder.py
ADDED
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+
+from graphgen.bases import BaseKGBuilder
+
+
+@dataclass
+class NetworkXKGBuilder(BaseKGBuilder):
+    def build(self, chunks):
+        pass
+
+    async def extract_all(self, chunks):
+        pass
+
+    async def extract(self, chunk):
+        pass
+
+    async def merge_nodes(self, nodes_data, kg_instance, llm):
+        pass
graphgen/{operators/kg → models/kg_builder}/__init__.py
RENAMED
File without changes
graphgen/models/llm/ollama_client.py
ADDED
@@ -0,0 +1,21 @@
+# TODO: implement ollama client
+from typing import Any, List, Optional
+
+from graphgen.bases import BaseLLMClient, Token
+
+
+class OllamaClient(BaseLLMClient):
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        pass
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        pass
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        pass
graphgen/models/llm/{openai_model.py → openai_client.py}
RENAMED
@@ -1,7 +1,5 @@
 import math
-import
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import openai
 from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
@@ -12,9 +10,9 @@ from tenacity import (
     wait_exponential,
 )

+from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.datatypes import Token
 from graphgen.models.llm.limitter import RPM, TPM
-from graphgen.models.llm.tokenizer import Tokenizer
-from graphgen.models.llm.topk_token_model import Token, TopkTokenModel


 def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
@@ -30,32 +28,33 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
     return tokens


+class OpenAIClient(BaseLLMClient):
+    def __init__(
+        self,
+        *,
+        model_name: str = "gpt-4o-mini",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        json_mode: bool = False,
+        seed: Optional[int] = None,
+        topk_per_token: int = 5,  # number of topk tokens to generate for each token
+        request_limit: bool = False,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.model_name = model_name
+        self.api_key = api_key
+        self.base_url = base_url
+        self.json_mode = json_mode
+        self.seed = seed
+        self.topk_per_token = topk_per_token
+
+        self.token_usage: list = []
+        self.request_limit = request_limit
+        self.rpm = RPM(rpm=1000)
+        self.tpm = TPM(tpm=50000)
+
+        self.__post_init__()

     def __post_init__(self):
         assert self.api_key is not None, "Please provide api key to access openai api."
@@ -66,7 +65,7 @@ class OpenAIModel(TopkTokenModel):
     def _pre_generate(self, text: str, history: List[str]) -> Dict:
         kwargs = {
             "temperature": self.temperature,
-            "top_p": self.
+            "top_p": self.top_p,
             "max_tokens": self.max_tokens,
         }
         if self.seed:
@@ -94,7 +93,10 @@ class OpenAIModel(TopkTokenModel):
         ),
     )
     async def generate_topk_per_token(
-        self,
+        self,
+        text: str,
+        history: Optional[List[str]] = None,
+        **extra: Any,
     ) -> List[Token]:
         kwargs = self._pre_generate(text, history)
         if self.topk_per_token > 0:
@@ -120,16 +122,16 @@ class OpenAIModel(TopkTokenModel):
         ),
     )
     async def generate_answer(
-        self,
+        self,
+        text: str,
+        history: Optional[List[str]] = None,
+        **extra: Any,
     ) -> str:
         kwargs = self._pre_generate(text, history)
-        kwargs["temperature"] = temperature

         prompt_tokens = 0
         for message in kwargs["messages"]:
-            prompt_tokens += len(
-                self.tokenizer_instance.encode_string(message["content"])
-            )
+            prompt_tokens += len(self.tokenizer.encode(message["content"]))
         estimated_tokens = prompt_tokens + kwargs["max_tokens"]

         if self.request_limit:
@@ -147,9 +149,10 @@ class OpenAIModel(TopkTokenModel):
                 "total_tokens": completion.usage.total_tokens,
             }
         )
-        return filter_think_tags(completion.choices[0].message.content)
+        return self.filter_think_tags(completion.choices[0].message.content)

     async def generate_inputs_prob(
-        self, text: str, history: Optional[List[str]] = None
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
+        """Generate probabilities for each token in the input."""
         raise NotImplementedError
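Construction of the renamed client mirrors what graphgen.py does earlier in this commit: model name, API key and base URL plus an optional tokenizer passed through to the base class. A short usage sketch (environment variable names follow the diff; the prompt is arbitrary):

# Sketch based on the usage shown in graphgen.py above.
import asyncio
import os

from graphgen.models import OpenAIClient, Tokenizer

client = OpenAIClient(
    model_name=os.getenv("SYNTHESIZER_MODEL"),
    api_key=os.getenv("SYNTHESIZER_API_KEY"),
    base_url=os.getenv("SYNTHESIZER_BASE_URL"),
    tokenizer=Tokenizer(model_name="cl100k_base"),
)

answer = asyncio.run(client.generate_answer("What does GraphGen do?"))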
graphgen/models/llm/tokenizer.py
DELETED
@@ -1,73 +0,0 @@
-from dataclasses import dataclass
-from typing import List
-import tiktoken
-
-try:
-    from transformers import AutoTokenizer
-    TRANSFORMERS_AVAILABLE = True
-except ImportError:
-    AutoTokenizer = None
-    TRANSFORMERS_AVAILABLE = False
-
-
-def get_tokenizer(tokenizer_name: str = "cl100k_base"):
-    """
-    Get a tokenizer instance by name.
-
-    :param tokenizer_name: tokenizer name, tiktoken encoding name or Hugging Face model name
-    :return: tokenizer instance
-    """
-    if tokenizer_name in tiktoken.list_encoding_names():
-        return tiktoken.get_encoding(tokenizer_name)
-    if TRANSFORMERS_AVAILABLE:
-        try:
-            return AutoTokenizer.from_pretrained(tokenizer_name)
-        except Exception as e:
-            raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}") from e
-    else:
-        raise ValueError("Hugging Face Transformers is not available, please install it first.")
-
-@dataclass
-class Tokenizer:
-    model_name: str = "cl100k_base"
-
-    def __post_init__(self):
-        self.tokenizer = get_tokenizer(self.model_name)
-
-    def encode_string(self, text: str) -> List[int]:
-        """
-        Encode text to tokens
-
-        :param text
-        :return: tokens
-        """
-        return self.tokenizer.encode(text)
-
-    def decode_tokens(self, tokens: List[int]) -> str:
-        """
-        Decode tokens to text
-
-        :param tokens
-        :return: text
-        """
-        return self.tokenizer.decode(tokens)
-
-    def chunk_by_token_size(
-        self, content: str, overlap_token_size=128, max_token_size=1024
-    ):
-        tokens = self.encode_string(content)
-        results = []
-        for index, start in enumerate(
-            range(0, len(tokens), max_token_size - overlap_token_size)
-        ):
-            chunk_content = self.decode_tokens(
-                tokens[start : start + max_token_size]
-            )
-            results.append(
-                {
-                    "tokens": min(max_token_size, len(tokens) - start),
-                    "content": chunk_content.strip(),
-                    "chunk_order_index": index,
-                }
-            )
-        return results
graphgen/models/llm/topk_token_model.py
CHANGED
@@ -1,18 +1,7 @@
-import
-from
-from typing import List, Union, Optional
+from dataclasses import dataclass
+from typing import List, Optional

-
-@dataclass
-class Token:
-    text: str
-    prob: float
-    top_candidates: List = field(default_factory=list)
-    ppl: Union[float, None] = field(default=None)
-
-    @property
-    def logprob(self) -> float:
-        return math.log(self.prob)
+from graphgen.bases import Token


 @dataclass
@@ -34,14 +23,18 @@ class TopkTokenModel:
         """
         raise NotImplementedError

-    async def generate_inputs_prob(
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None
+    ) -> List[Token]:
         """
         Generate prob and text for each token of the input text.
         This function is used to visualize the ppl.
         """
         raise NotImplementedError

-    async def generate_answer(
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None
+    ) -> str:
         """
         Generate answer from the model.
         """
graphgen/models/reader/__init__.py
CHANGED
@@ -2,21 +2,3 @@ from .csv_reader import CsvReader
 from .json_reader import JsonReader
 from .jsonl_reader import JsonlReader
 from .txt_reader import TxtReader
-
-_MAPPING = {
-    "jsonl": JsonlReader,
-    "json": JsonReader,
-    "txt": TxtReader,
-    "csv": CsvReader,
-}
-
-
-def read_file(file_path: str):
-    suffix = file_path.split(".")[-1]
-    if suffix in _MAPPING:
-        reader = _MAPPING[suffix]()
-    else:
-        raise ValueError(
-            f"Unsupported file format: {suffix}. Supported formats are: {list(_MAPPING.keys())}"
-        )
-    return reader.read(file_path)
graphgen/models/splitter/__init__.py
CHANGED
@@ -1,31 +1,4 @@
-from functools import lru_cache
-from typing import Union
-
 from .recursive_character_splitter import (
     ChineseRecursiveTextSplitter,
     RecursiveCharacterSplitter,
 )
-
-_MAPPING = {
-    "en": RecursiveCharacterSplitter,
-    "zh": ChineseRecursiveTextSplitter,
-}
-
-SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
-
-
-@lru_cache(maxsize=None)
-def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
-    cls = _MAPPING[language]
-    kwargs = dict(frozen_kwargs)
-    return cls(**kwargs)
-
-
-def split_chunks(text: str, language: str = "en", **kwargs) -> list:
-    if language not in _MAPPING:
-        raise ValueError(
-            f"Unsupported language: {language}. "
-            f"Supported languages are: {list(_MAPPING.keys())}"
-        )
-    splitter = _get_splitter(language, frozenset(kwargs.items()))
-    return splitter.split_text(text)
graphgen/models/tokenizer/__init__.py
ADDED
@@ -0,0 +1,51 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from graphgen.bases import BaseTokenizer
+
+from .hf_tokenizer import HFTokenizer
+from .tiktoken_tokenizer import TiktokenTokenizer
+
+try:
+    from transformers import AutoTokenizer
+
+    _HF_AVAILABLE = True
+except ImportError:
+    _HF_AVAILABLE = False
+
+
+def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer:
+    import tiktoken
+
+    if tokenizer_name in tiktoken.list_encoding_names():
+        return TiktokenTokenizer(model_name=tokenizer_name)
+
+    # 2. HuggingFace
+    if _HF_AVAILABLE:
+        return HFTokenizer(model_name=tokenizer_name)
+
+    raise ValueError(
+        f"Unknown tokenizer {tokenizer_name} and HuggingFace not available."
+    )
+
+
+@dataclass
+class Tokenizer(BaseTokenizer):
+    """
+    Encapsulates different tokenization implementations based on the specified model name.
+    """
+
+    model_name: str = "cl100k_base"
+    _impl: BaseTokenizer = field(init=False, repr=False)
+
+    def __post_init__(self):
+        self._impl = get_tokenizer_impl(self.model_name)
+
+    def encode(self, text: str) -> List[int]:
+        return self._impl.encode(text)
+
+    def decode(self, token_ids: List[int]) -> str:
+        return self._impl.decode(token_ids)
+
+    def count_tokens(self, text: str) -> int:
+        return self._impl.count_tokens(text)
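The Tokenizer facade picks an implementation by name: known tiktoken encodings resolve to TiktokenTokenizer, anything else falls back to HFTokenizer when transformers is installed. A short sketch; the Hugging Face model name below is only an example, not something this commit uses:

# Sketch: both objects expose the same encode/decode/count_tokens interface.
from graphgen.models import Tokenizer

tiktoken_tok = Tokenizer(model_name="cl100k_base")    # tiktoken encoding
hf_tok = Tokenizer(model_name="bert-base-uncased")    # any HF tokenizer name (example)

print(tiktoken_tok.count_tokens("knowledge graph generation"))
print(hf_tok.count_tokens("knowledge graph generation"))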
graphgen/models/tokenizer/hf_tokenizer.py
ADDED
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+from typing import List
+
+from transformers import AutoTokenizer
+
+from graphgen.bases import BaseTokenizer
+
+
+@dataclass
+class HFTokenizer(BaseTokenizer):
+    def __post_init__(self):
+        self.enc = AutoTokenizer.from_pretrained(self.model_name)
+
+    def encode(self, text: str) -> List[int]:
+        return self.enc.encode(text, add_special_tokens=False)
+
+    def decode(self, token_ids: List[int]) -> str:
+        return self.enc.decode(token_ids, skip_special_tokens=True)
graphgen/models/tokenizer/tiktoken_tokenizer.py
ADDED
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+from typing import List
+
+import tiktoken
+
+from graphgen.bases import BaseTokenizer
+
+
+@dataclass
+class TiktokenTokenizer(BaseTokenizer):
+    def __post_init__(self):
+        self.enc = tiktoken.get_encoding(self.model_name)
+
+    def encode(self, text: str) -> List[int]:
+        return self.enc.encode(text)
+
+    def decode(self, token_ids: List[int]) -> str:
+        return self.enc.decode(token_ids)
graphgen/operators/__init__.py
CHANGED
@@ -1,22 +1,13 @@
+from graphgen.operators.build_kg.extract_kg import extract_kg
 from graphgen.operators.generate.generate_cot import generate_cot
-from graphgen.operators.kg.extract_kg import extract_kg
 from graphgen.operators.search.search_all import search_all

 from .judge import judge_statement
 from .quiz import quiz
+from .read import read_files
+from .split import chunk_documents
 from .traverse_graph import (
     traverse_graph_for_aggregated,
     traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
-
-__all__ = [
-    "extract_kg",
-    "quiz",
-    "judge_statement",
-    "search_all",
-    "traverse_graph_for_aggregated",
-    "traverse_graph_for_atomic",
-    "traverse_graph_for_multi_hop",
-    "generate_cot",
-]
graphgen/operators/build_kg/__init__.py
ADDED
File without changes
graphgen/operators/build_kg/extract_kg.py
ADDED
@@ -0,0 +1,127 @@
+import re
+from collections import defaultdict
+from typing import List
+
+import gradio as gr
+
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+from graphgen.models import OpenAIClient, Tokenizer
+from graphgen.operators.build_kg.merge_kg import merge_edges, merge_nodes
+from graphgen.templates import KG_EXTRACTION_PROMPT
+from graphgen.utils import (
+    detect_if_chinese,
+    handle_single_entity_extraction,
+    handle_single_relationship_extraction,
+    logger,
+    pack_history_conversations,
+    run_concurrent,
+    split_string_by_multi_markers,
+)
+
+
+# pylint: disable=too-many-statements
+async def extract_kg(
+    llm_client: OpenAIClient,
+    kg_instance: BaseGraphStorage,
+    tokenizer_instance: Tokenizer,
+    chunks: List[Chunk],
+    progress_bar: gr.Progress = None,
+):
+    """
+    :param llm_client: Synthesizer LLM model to extract entities and relationships
+    :param kg_instance
+    :param tokenizer_instance
+    :param chunks
+    :param progress_bar: Gradio progress bar to show the progress of the extraction
+    :return:
+    """
+
+    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
+        chunk_id = chunk.id
+        content = chunk.content
+        if detect_if_chinese(content):
+            language = "Chinese"
+        else:
+            language = "English"
+        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
+
+        hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
+            **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
+        )
+
+        final_result = await llm_client.generate_answer(hint_prompt)
+        logger.info("First result: %s", final_result)
+
+        history = pack_history_conversations(hint_prompt, final_result)
+        for loop_index in range(max_loop):
+            if_loop_result = await llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
+            )
+            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
+            if if_loop_result != "yes":
+                break
+
+            glean_result = await llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
+            )
+            logger.info("Loop %s glean: %s", loop_index, glean_result)
+
+            history += pack_history_conversations(
+                KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
+            )
+            final_result += glean_result
+            if loop_index == max_loop - 1:
+                break
+
+        records = split_string_by_multi_markers(
+            final_result,
+            [
+                KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
+                KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
+            ],
+        )
+
+        nodes = defaultdict(list)
+        edges = defaultdict(list)
+
+        for record in records:
+            record = re.search(r"\((.*)\)", record)
+            if record is None:
+                continue
+            record = record.group(1)  # extract the content inside the parentheses
+            record_attributes = split_string_by_multi_markers(
+                record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
+            )
+
+            entity = await handle_single_entity_extraction(record_attributes, chunk_id)
+            if entity is not None:
+                nodes[entity["entity_name"]].append(entity)
+                continue
+            relation = await handle_single_relationship_extraction(
+                record_attributes, chunk_id
+            )
+            if relation is not None:
+                edges[(relation["src_id"], relation["tgt_id"])].append(relation)
+        return dict(nodes), dict(edges)
+
+    results = await run_concurrent(
+        _process_single_content,
+        chunks,
+        desc="[2/4]Extracting entities and relationships from chunks",
+        unit="chunk",
+        progress_bar=progress_bar,
+    )
+
+    nodes = defaultdict(list)
+    edges = defaultdict(list)
+    for n, e in results:
+        for k, v in n.items():
+            nodes[k].extend(v)
+        for k, v in e.items():
+            edges[tuple(sorted(k))].extend(v)
+
+    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
+    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
+
+    return kg_instance
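extract_kg now delegates per-chunk concurrency and the progress readout to run_concurrent from graphgen.utils, which this commit adds in graphgen/utils/run_concurrent.py (+38) but which is not displayed on this page. A plausible sketch of such a helper, assuming it bounds concurrency with a semaphore and reports progress the way the old extract_kg did:

# Assumption: the real run_concurrent in graphgen/utils/run_concurrent.py is not shown here.
import asyncio
from typing import Any, Awaitable, Callable, Iterable, List

from tqdm.asyncio import tqdm as tqdm_async


async def run_concurrent(
    func: Callable[[Any], Awaitable[Any]],
    items: Iterable[Any],
    desc: str = "",
    unit: str = "item",
    progress_bar=None,
    max_concurrent: int = 100,
) -> List[Any]:
    """Apply an async function to every item with bounded concurrency."""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _run(item):
        async with semaphore:
            return await func(item)

    tasks = [asyncio.create_task(_run(item)) for item in items]
    results = []
    # as_completed yields tasks as they finish, driving the tqdm progress bar
    for i, task in enumerate(tqdm_async.as_completed(tasks, desc=desc, unit=unit)):
        results.append(await task)
        if progress_bar is not None:
            progress_bar(i / max(len(tasks), 1), desc)
    return results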
graphgen/operators/{kg → build_kg}/merge_kg.py
RENAMED
@@ -3,8 +3,8 @@ from collections import Counter
 from tqdm.asyncio import tqdm as tqdm_async

-from graphgen.bases
-from graphgen.models import Tokenizer
+from graphgen.bases import BaseGraphStorage, BaseLLMClient
+from graphgen.models import Tokenizer
 from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
 from graphgen.utils import detect_main_language, logger
 from graphgen.utils.format import split_string_by_multi_markers
@@ -13,7 +13,7 @@ from graphgen.utils.format import split_string_by_multi_markers
 async def _handle_kg_summary(
     entity_or_relation_name: str,
     description: str,
-    llm_client:
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_summary_tokens: int = 200,
 ) -> str:
@@ -34,11 +34,11 @@ async def _handle_kg_summary(
         language = "Chinese"
     KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

-    tokens = tokenizer_instance.
+    tokens = tokenizer_instance.encode(description)
     if len(tokens) < max_summary_tokens:
         return description

-    use_description = tokenizer_instance.
+    use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
     prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
         entity_name=entity_or_relation_name,
         description_list=use_description.split("<SEP>"),
@@ -54,7 +54,7 @@ async def _handle_kg_summary(
 async def merge_nodes(
     nodes_data: dict,
     kg_instance: BaseGraphStorage,
-    llm_client:
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_concurrent: int = 1000,
 ):
@@ -131,7 +131,7 @@ async def merge_nodes(
 async def merge_edges(
     edges_data: dict,
     kg_instance: BaseGraphStorage,
-    llm_client:
+    llm_client: BaseLLMClient,
     tokenizer_instance: Tokenizer,
     max_concurrent: int = 1000,
 ):
graphgen/operators/{kg → build_kg}/split_kg.py
RENAMED
File without changes
graphgen/operators/generate/generate_cot.py
CHANGED
@@ -3,14 +3,14 @@ from typing import Dict, List, Tuple
 from tqdm.asyncio import tqdm as tqdm_async

-from graphgen.models import CommunityDetector, NetworkXStorage,
+from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIClient
 from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
 from graphgen.utils import compute_content_hash, detect_main_language


 async def generate_cot(
     graph_storage: NetworkXStorage,
-    synthesizer_llm_client:
+    synthesizer_llm_client: OpenAIClient,
     method_params: Dict = None,
 ):
     method = method_params.get("method", "leiden")
graphgen/operators/judge.py
CHANGED
@@ -3,13 +3,13 @@ import math
 from tqdm.asyncio import tqdm as tqdm_async

-from graphgen.models import JsonKVStorage, NetworkXStorage,
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
 from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
 from graphgen.utils import logger, yes_no_loss_entropy


 async def judge_statement(  # pylint: disable=too-many-statements
-    trainee_llm_client:
+    trainee_llm_client: OpenAIClient,
     graph_storage: NetworkXStorage,
     rephrase_storage: JsonKVStorage,
     re_judge: bool = False,
graphgen/operators/kg/extract_kg.py
DELETED
@@ -1,152 +0,0 @@
-import asyncio
-import re
-from collections import defaultdict
-from typing import List
-
-import gradio as gr
-from tqdm.asyncio import tqdm as tqdm_async
-
-from graphgen.bases.base_storage import BaseGraphStorage
-from graphgen.bases.datatypes import Chunk
-from graphgen.models import OpenAIModel, Tokenizer
-from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
-from graphgen.templates import KG_EXTRACTION_PROMPT
-from graphgen.utils import (
-    detect_if_chinese,
-    handle_single_entity_extraction,
-    handle_single_relationship_extraction,
-    logger,
-    pack_history_conversations,
-    split_string_by_multi_markers,
-)
-
-
-# pylint: disable=too-many-statements
-async def extract_kg(
-    llm_client: OpenAIModel,
-    kg_instance: BaseGraphStorage,
-    tokenizer_instance: Tokenizer,
-    chunks: List[Chunk],
-    progress_bar: gr.Progress = None,
-    max_concurrent: int = 1000,
-):
-    """
-    :param llm_client: Synthesizer LLM model to extract entities and relationships
-    :param kg_instance
-    :param tokenizer_instance
-    :param chunks
-    :param progress_bar: Gradio progress bar to show the progress of the extraction
-    :param max_concurrent
-    :return:
-    """
-
-    semaphore = asyncio.Semaphore(max_concurrent)
-
-    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
-        async with semaphore:
-            chunk_id = chunk.id
-            content = chunk.content
-            if detect_if_chinese(content):
-                language = "Chinese"
-            else:
-                language = "English"
-            KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
-
-            hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
-                **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
-            )
-
-            final_result = await llm_client.generate_answer(hint_prompt)
-            logger.info("First result: %s", final_result)
-
-            history = pack_history_conversations(hint_prompt, final_result)
-            for loop_index in range(max_loop):
-                if_loop_result = await llm_client.generate_answer(
-                    text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
-                )
-                if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
-                if if_loop_result != "yes":
-                    break
-
-                glean_result = await llm_client.generate_answer(
-                    text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
-                )
-                logger.info("Loop %s glean: %s", loop_index, glean_result)
-
-                history += pack_history_conversations(
-                    KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
-                )
-                final_result += glean_result
-                if loop_index == max_loop - 1:
-                    break
-
-            records = split_string_by_multi_markers(
-                final_result,
-                [
-                    KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
-                    KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
-                ],
-            )
-
-            nodes = defaultdict(list)
-            edges = defaultdict(list)
-
-            for record in records:
-                record = re.search(r"\((.*)\)", record)
-                if record is None:
-                    continue
-                record = record.group(1)  # extract the content inside the parentheses
-                record_attributes = split_string_by_multi_markers(
-                    record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
-                )
-
-                entity = await handle_single_entity_extraction(
-                    record_attributes, chunk_id
-                )
-                if entity is not None:
-                    nodes[entity["entity_name"]].append(entity)
-                    continue
-                relation = await handle_single_relationship_extraction(
-                    record_attributes, chunk_id
-                )
-                if relation is not None:
-                    edges[(relation["src_id"], relation["tgt_id"])].append(relation)
-            return dict(nodes), dict(edges)
-
-    results = []
-    chunk_number = len(chunks)
-    async for result in tqdm_async(
-        asyncio.as_completed([_process_single_content(c) for c in chunks]),
-        total=len(chunks),
-        desc="[2/4]Extracting entities and relationships from chunks",
-        unit="chunk",
-    ):
-        try:
-            if progress_bar is not None:
-                progress_bar(
-                    len(results) / chunk_number,
-                    desc="[3/4]Extracting entities and relationships from chunks",
-                )
-            results.append(await result)
-            if progress_bar is not None and len(results) == chunk_number:
-                progress_bar(
-                    1, desc="[3/4]Extracting entities and relationships from chunks"
-                )
-        except Exception as e:  # pylint: disable=broad-except
-            logger.error(
-                "Error occurred while extracting entities and relationships from chunks: %s",
-                e,
-            )
-
-    nodes = defaultdict(list)
-    edges = defaultdict(list)
-    for n, e in results:
-        for k, v in n.items():
-            nodes[k].extend(v)
-        for k, v in e.items():
-            edges[tuple(sorted(k))].extend(v)
-
-    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
-    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
-
-    return kg_instance
graphgen/operators/preprocess/resolute_coreference.py
CHANGED
@@ -1,13 +1,13 @@
 from typing import List

 from graphgen.bases.datatypes import Chunk
-from graphgen.models import OpenAIModel
+from graphgen.models import OpenAIClient
 from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
 from graphgen.utils import detect_main_language


 async def resolute_coreference(
-    llm_client: OpenAIModel, chunks: List[Chunk]
+    llm_client: OpenAIClient, chunks: List[Chunk]
 ) -> List[Chunk]:
     """
     Resolute conference
graphgen/operators/quiz.py
CHANGED
@@ -2,17 +2,19 @@ import asyncio
 from collections import defaultdict

 from tqdm.asyncio import tqdm as tqdm_async
-
-from graphgen.
+
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
 from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
+from graphgen.utils import detect_main_language, logger


 async def quiz(
-
-
-
-
-
+    synth_llm_client: OpenAIClient,
+    graph_storage: NetworkXStorage,
+    rephrase_storage: JsonKVStorage,
+    max_samples: int = 1,
+    max_concurrent: int = 1000,
+) -> JsonKVStorage:
     """
     Get all edges and quiz them

@@ -26,11 +28,7 @@ async def quiz(

     semaphore = asyncio.Semaphore(max_concurrent)

-    async def _process_single_quiz(
-        des: str,
-        prompt: str,
-        gt: str
-    ):
+    async def _process_single_quiz(des: str, prompt: str, gt: str):
         async with semaphore:
             try:
                 # if the description already exists in rephrase_storage, return it directly

@@ -39,16 +37,14 @@ async def quiz(
                 return None

                 new_description = await synth_llm_client.generate_answer(
-                    prompt,
-                    temperature=1
+                    prompt, temperature=1
                 )
-                return
+                return {des: [(new_description, gt)]}

-            except Exception as e:
+            except Exception as e:  # pylint: disable=broad-except
                 logger.error("Error when quizzing description %s: %s", des, e)
                 return None

-
     edges = await graph_storage.get_all_edges()
     nodes = await graph_storage.get_all_nodes()

@@ -60,41 +56,59 @@ async def quiz(
         description = edge_data["description"]
         language = "English" if detect_main_language(description) == "en" else "Chinese"

-        results[description] = [(description,
+        results[description] = [(description, "yes")]

         for i in range(max_samples):
             if i > 0:
                 tasks.append(
-                    _process_single_quiz(
-
-
+                    _process_single_quiz(
+                        description,
+                        DESCRIPTION_REPHRASING_PROMPT[language]["TEMPLATE"].format(
+                            input_sentence=description
+                        ),
+                        "yes",
+                    )
                 )
-            tasks.append(
-
-
+            tasks.append(
+                _process_single_quiz(
+                    description,
+                    DESCRIPTION_REPHRASING_PROMPT[language]["ANTI_TEMPLATE"].format(
+                        input_sentence=description
+                    ),
+                    "no",
+                )
+            )

     for node in nodes:
         node_data = node[1]
         description = node_data["description"]
         language = "English" if detect_main_language(description) == "en" else "Chinese"

-        results[description] = [(description,
+        results[description] = [(description, "yes")]

         for i in range(max_samples):
             if i > 0:
                 tasks.append(
-                    _process_single_quiz(
-
-
+                    _process_single_quiz(
+                        description,
+                        DESCRIPTION_REPHRASING_PROMPT[language]["TEMPLATE"].format(
+                            input_sentence=description
+                        ),
+                        "yes",
+                    )
+                )
+            tasks.append(
+                _process_single_quiz(
+                    description,
+                    DESCRIPTION_REPHRASING_PROMPT[language]["ANTI_TEMPLATE"].format(
+                        input_sentence=description
+                    ),
+                    "no",
                 )
-
-                DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
-                    input_sentence=description), 'no'))
+            )

     for result in tqdm_async(
-
-        total=len(tasks),
-        desc="Quizzing descriptions"
+        asyncio.as_completed(tasks), total=len(tasks), desc="Quizzing descriptions"
     ):
         new_result = await result
         if new_result:

@@ -105,5 +119,4 @@ async def quiz(
         results[key] = list(set(value))
         await rephrase_storage.upsert({key: results[key]})

-
     return rephrase_storage
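After this rewrite, _process_single_quiz returns a {description: [(rephrased_text, ground_truth)]} mapping, and the rephrase storage ends up holding each description with a de-duplicated list of such pairs, seeded by the original description as a "yes" entry. A sketch of one stored entry (sentences invented for illustration):

# Hypothetical rephrase_storage entry after quiz() with max_samples >= 1:
entry = {
    "Entity A supplies power to Entity B.": [
        ("Entity A supplies power to Entity B.", "yes"),        # original description
        ("Entity B is powered by Entity A.", "yes"),            # TEMPLATE rephrasing
        ("Entity A does not supply power to Entity B.", "no"),  # ANTI_TEMPLATE rephrasing
    ]
}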
graphgen/operators/read/__init__.py
ADDED
@@ -0,0 +1 @@
+from .read_files import read_files
graphgen/operators/read/read_files.py
ADDED
@@ -0,0 +1,19 @@
+from graphgen.models import CsvReader, JsonlReader, JsonReader, TxtReader
+
+_MAPPING = {
+    "jsonl": JsonlReader,
+    "json": JsonReader,
+    "txt": TxtReader,
+    "csv": CsvReader,
+}
+
+
+def read_files(file_path: str):
+    suffix = file_path.split(".")[-1]
+    if suffix in _MAPPING:
+        reader = _MAPPING[suffix]()
+    else:
+        raise ValueError(
+            f"Unsupported file format: {suffix}. Supported formats are: {list(_MAPPING.keys())}"
+        )
+    return reader.read(file_path)
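A short usage sketch for the new dispatcher; the paths below are invented, and the shape of the returned data depends on the individual reader classes added in this commit:

from graphgen.operators.read import read_files

docs = read_files("data/corpus.jsonl")   # suffix "jsonl" selects JsonlReader (hypothetical path)
try:
    read_files("data/slides.pdf")        # unsupported suffix
except ValueError as err:
    print(err)  # names the supported formats: jsonl, json, txt, csv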
graphgen/operators/split/__init__.py
ADDED
@@ -0,0 +1 @@
+from .split_chunks import chunk_documents
graphgen/operators/split/split_chunks.py
ADDED
@@ -0,0 +1,76 @@
+from functools import lru_cache
+from typing import Union
+
+from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.models import (
+    ChineseRecursiveTextSplitter,
+    RecursiveCharacterSplitter,
+    Tokenizer,
+)
+from graphgen.utils import compute_content_hash, detect_main_language
+
+_MAPPING = {
+    "en": RecursiveCharacterSplitter,
+    "zh": ChineseRecursiveTextSplitter,
+}
+
+SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
+
+
+@lru_cache(maxsize=None)
+def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
+    cls = _MAPPING[language]
+    kwargs = dict(frozen_kwargs)
+    return cls(**kwargs)
+
+
+def split_chunks(text: str, language: str = "en", **kwargs) -> list:
+    if language not in _MAPPING:
+        raise ValueError(
+            f"Unsupported language: {language}. "
+            f"Supported languages are: {list(_MAPPING.keys())}"
+        )
+    splitter = _get_splitter(language, frozenset(kwargs.items()))
+    return splitter.split_text(text)
+
+
+async def chunk_documents(
+    new_docs: dict,
+    chunk_size: int = 1024,
+    chunk_overlap: int = 100,
+    tokenizer_instance: Tokenizer = None,
+    progress_bar=None,
+) -> dict:
+    inserting_chunks = {}
+    cur_index = 1
+    doc_number = len(new_docs)
+    async for doc_key, doc in tqdm_async(
+        new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
+    ):
+        doc_language = detect_main_language(doc["content"])
+        text_chunks = split_chunks(
+            doc["content"],
+            language=doc_language,
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+
+        chunks = {
+            compute_content_hash(txt, prefix="chunk-"): {
+                "content": txt,
+                "full_doc_id": doc_key,
+                "length": len(tokenizer_instance.encode(txt))
+                if tokenizer_instance
+                else len(txt),
+                "language": doc_language,
+            }
+            for txt in text_chunks
+        }
+        inserting_chunks.update(chunks)
+
+        if progress_bar is not None:
+            progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
+        cur_index += 1
+
+    return inserting_chunks
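Because _get_splitter is cached on (language, frozenset(kwargs.items())), every document that shares the same language and chunking configuration reuses a single splitter instance. A minimal sketch of calling the module-level helper directly (the sample text is made up, and the keyword arguments are assumed to be accepted by the underlying splitter classes):

from graphgen.operators.split.split_chunks import split_chunks

text = "GraphGen turns raw documents into knowledge-graph-ready chunks."  # made-up sample
pieces = split_chunks(text, language="en", chunk_size=1024, chunk_overlap=100)
print(len(pieces))  # a short text like this should stay in one chunk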
graphgen/operators/traverse_graph.py
CHANGED
@@ -6,11 +6,11 @@ from tqdm.asyncio import tqdm as tqdm_async
 from graphgen.models import (
     JsonKVStorage,
     NetworkXStorage,
-    OpenAIModel,
+    OpenAIClient,
     Tokenizer,
     TraverseStrategy,
 )
-from graphgen.operators.kg.split_kg import get_batches_with_strategy
+from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
 from graphgen.templates import (
     ANSWER_REPHRASING_PROMPT,
     MULTI_HOP_GENERATION_PROMPT,
@@ -30,7 +30,7 @@ async def _pre_tokenize(
     if "length" not in edge[2]:
         edge[2]["length"] = len(
             await asyncio.get_event_loop().run_in_executor(
-                None, tokenizer.
+                None, tokenizer.encode, edge[2]["description"]
             )
         )
     return edge
@@ -40,7 +40,7 @@ async def _pre_tokenize(
     if "length" not in node[1]:
         node[1]["length"] = len(
             await asyncio.get_event_loop().run_in_executor(
-                None, tokenizer.
+                None, tokenizer.encode, node[1]["description"]
             )
         )
     return node
@@ -161,7 +161,7 @@ def _post_process_synthetic_data(data):


 async def traverse_graph_for_aggregated(
-    llm_client: OpenAIModel,
+    llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
     traverse_strategy: TraverseStrategy,
@@ -310,7 +310,7 @@ async def traverse_graph_for_aggregated(

 # pylint: disable=too-many-branches, too-many-statements
 async def traverse_graph_for_atomic(
-    llm_client: OpenAIModel,
+    llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
     traverse_strategy: TraverseStrategy,
@@ -426,7 +426,7 @@ async def traverse_graph_for_atomic(


 async def traverse_graph_for_multi_hop(
-    llm_client: OpenAIModel,
+    llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
     traverse_strategy: TraverseStrategy,
graphgen/utils/__init__.py
CHANGED
@@ -13,3 +13,5 @@ from .hash import compute_args_hash, compute_content_hash
 from .help_nltk import NLTKHelper
 from .log import logger, parse_log, set_logger
 from .loop import create_event_loop
+from .run_concurrent import run_concurrent
+from .wrap import async_to_sync_method
graphgen/utils/calculate_confidence.py
CHANGED
@@ -1,34 +1,41 @@
 import math
 from typing import List
-
+
+from graphgen.bases.datatypes import Token
+

 def preprocess_tokens(tokens: List[Token]) -> List[Token]:
     """Preprocess tokens for calculating confidence."""
     tokens = [x for x in tokens if x.prob > 0]
     return tokens

+
 def joint_probability(tokens: List[Token]) -> float:
     """Calculate joint probability of a list of tokens."""
     tokens = preprocess_tokens(tokens)
     logprob_sum = sum(x.logprob for x in tokens)
     return math.exp(logprob_sum / len(tokens))

+
 def min_prob(tokens: List[Token]) -> float:
     """Calculate the minimum probability of a list of tokens."""
     tokens = preprocess_tokens(tokens)
     return min(x.prob for x in tokens)

+
 def average_prob(tokens: List[Token]) -> float:
     """Calculate the average probability of a list of tokens."""
     tokens = preprocess_tokens(tokens)
     return sum(x.prob for x in tokens) / len(tokens)

+
 def average_confidence(tokens: List[Token]) -> float:
     """Calculate the average confidence of a list of tokens."""
     tokens = preprocess_tokens(tokens)
     confidence = [x.prob / sum(y.prob for y in x.top_candidates[:5]) for x in tokens]
     return sum(confidence) / len(tokens)

+
 def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
     """Calculate the loss for yes/no question."""
     losses = []
@@ -41,7 +48,10 @@ def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
             losses.append(token.prob)
     return sum(losses) / len(losses)

-
+
+def yes_no_loss_entropy(
+    tokens_list: List[List[Token]], ground_truth: List[str]
+) -> float:
     """Calculate the loss for yes/no question using entropy."""
     losses = []
     for i, tokens in enumerate(tokens_list):
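For intuition, joint_probability is the geometric mean of the surviving token probabilities (the exponential of the mean log-probability). A tiny worked check using a stand-in token type with the same attributes, since constructing the real Token from graphgen.bases.datatypes is not shown here:

import math
from collections import namedtuple

FakeToken = namedtuple("FakeToken", ["prob", "logprob"])  # stand-in, not the real Token
tokens = [FakeToken(0.9, math.log(0.9)), FakeToken(0.5, math.log(0.5))]

# exp((ln 0.9 + ln 0.5) / 2) == sqrt(0.9 * 0.5) ≈ 0.67
print(math.exp(sum(t.logprob for t in tokens) / len(tokens)))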
graphgen/utils/run_concurrent.py
ADDED
@@ -0,0 +1,38 @@
+import asyncio
+from typing import Awaitable, Callable, List, Optional, TypeVar
+
+import gradio as gr
+from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.utils.log import logger
+
+T = TypeVar("T")
+R = TypeVar("R")
+
+
+async def run_concurrent(
+    coro_fn: Callable[[T], Awaitable[R]],
+    items: List[T],
+    *,
+    desc: str = "processing",
+    unit: str = "item",
+    progress_bar: Optional[gr.Progress] = None,
+) -> List[R]:
+    tasks = [asyncio.create_task(coro_fn(it)) for it in items]
+
+    results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)
+
+    ok_results = []
+    for idx, res in enumerate(results):
+        if isinstance(res, Exception):
+            logger.exception("Task failed: %s", res)
+            if progress_bar:
+                progress_bar((idx + 1) / len(items), desc=desc)
+            continue
+        ok_results.append(res)
+        if progress_bar:
+            progress_bar((idx + 1) / len(items), desc=desc)
+
+    if progress_bar:
+        progress_bar(1.0, desc=desc)
+    return ok_results
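A minimal sketch of calling the new helper; the worker coroutine below is a stand-in for the per-item LLM calls GraphGen actually issues:

import asyncio

from graphgen.utils import run_concurrent


async def _double(x: int) -> int:
    await asyncio.sleep(0)  # stand-in for real async work
    return 2 * x


print(asyncio.run(run_concurrent(_double, [1, 2, 3], desc="doubling", unit="int")))
# -> [2, 4, 6]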
graphgen/utils/wrap.py
ADDED
@@ -0,0 +1,13 @@
+from functools import wraps
+from typing import Any, Callable
+
+from .loop import create_event_loop
+
+
+def async_to_sync_method(func: Callable) -> Callable:
+    @wraps(func)
+    def wrapper(self, *args, **kwargs) -> Any:
+        loop = create_event_loop()
+        return loop.run_until_complete(func(self, *args, **kwargs))
+
+    return wrapper
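The decorator targets instance methods: it runs the coroutine to completion on the loop from create_event_loop, so callers get an ordinary blocking call. A toy consumer for illustration (GraphGen's own classes are the real users):

from graphgen.utils import async_to_sync_method


class Greeter:
    @async_to_sync_method
    async def hello(self, name: str) -> str:
        return f"hello, {name}"


print(Greeter().hello("graphgen"))  # no await needed at the call site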
webui/app.py
CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
 from dotenv import load_dotenv

 from graphgen.graphgen import GraphGen
-from graphgen.models import OpenAIModel, Tokenizer
+from graphgen.models import OpenAIClient, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
 from graphgen.utils import set_logger
 from webui.base import WebuiParams
@@ -41,7 +41,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:

     graph_gen = GraphGen(working_dir=working_dir, config=config)
     # Set up LLM clients
-    graph_gen.synthesizer_llm_client = OpenAIModel(
+    graph_gen.synthesizer_llm_client = OpenAIClient(
         model_name=env.get("SYNTHESIZER_MODEL", ""),
         base_url=env.get("SYNTHESIZER_BASE_URL", ""),
         api_key=env.get("SYNTHESIZER_API_KEY", ""),
@@ -50,7 +50,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
         tpm=TPM(env.get("TPM", 50000)),
     )

-    graph_gen.trainee_llm_client = OpenAIModel(
+    graph_gen.trainee_llm_client = OpenAIClient(
         model_name=env.get("TRAINEE_MODEL", ""),
         base_url=env.get("TRAINEE_BASE_URL", ""),
         api_key=env.get("TRAINEE_API_KEY", ""),
webui/utils/count_tokens.py
CHANGED
@@ -45,7 +45,7 @@ def count_tokens(file, tokenizer_name, data_frame):
                 content = item.get("content", "")
             else:
                 content = item
-            token_count += len(tokenizer.
+            token_count += len(tokenizer.encode(content))

     _update_data = [[str(token_count), str(token_count * 50), "N/A"]]