github-actions[bot] committed
Commit 43d27f2 · Parent: d2a63cc
Auto-sync from demo at Wed Sep 24 09:52:41 UTC 2025
- app.py +27 -12
- graphgen/bases/base_splitter.py +135 -0
- graphgen/bases/datatypes.py +18 -0
- graphgen/{models/text → configs}/__init__.py +0 -0
- graphgen/configs/aggregated_config.yaml +5 -1
- graphgen/configs/atomic_config.yaml +5 -1
- graphgen/configs/cot_config.yaml +5 -1
- graphgen/configs/multi_hop_config.yaml +5 -1
- graphgen/evaluate.py +82 -52
- graphgen/graphgen.py +17 -12
- graphgen/models/__init__.py +1 -30
- graphgen/models/evaluate/base_evaluator.py +9 -7
- graphgen/models/evaluate/length_evaluator.py +5 -5
- graphgen/models/evaluate/mtld_evaluator.py +13 -8
- graphgen/models/evaluate/reward_evaluator.py +13 -7
- graphgen/models/evaluate/uni_evaluator.py +46 -22
- graphgen/models/splitter/__init__.py +31 -0
- graphgen/models/splitter/character_splitter.py +26 -0
- graphgen/models/splitter/markdown_splitter.py +33 -0
- graphgen/models/splitter/recursive_character_splitter.py +149 -0
- graphgen/models/text/chunk.py +0 -7
- graphgen/models/text/text_pair.py +0 -9
- graphgen/operators/kg/extract_kg.py +2 -1
- graphgen/operators/preprocess/resolute_coreference.py +2 -1
- webui/app.py +27 -12
- webui/base.py +2 -1
app.py
CHANGED
@@ -12,7 +12,7 @@ from graphgen.graphgen import GraphGen
 from graphgen.models import OpenAIModel, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
 from graphgen.utils import set_logger
-from webui.base import
+from webui.base import WebuiParams
 from webui.cache_utils import cleanup_workspace, setup_workspace
 from webui.count_tokens import count_tokens
 from webui.i18n import Translate

@@ -66,13 +66,19 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:

 # pylint: disable=too-many-statements
-def run_graphgen(params, progress=gr.Progress()):
+def run_graphgen(params: WebuiParams, progress=gr.Progress()):
     def sum_tokens(client):
         return sum(u["total_tokens"] for u in client.token_usage)

     config = {
         "if_trainee_model": params.if_trainee_model,
-        "
+        "read": {
+            "input_file": params.input_file,
+        },
+        "split": {
+            "chunk_size": params.chunk_size,
+            "chunk_overlap": params.chunk_overlap,
+        },
         "output_data_type": params.output_data_type,
         "output_data_format": params.output_data_format,
         "tokenizer": params.tokenizer,

@@ -91,7 +97,6 @@ def run_graphgen(params, progress=gr.Progress()):
             "isolated_node_strategy": params.isolated_node_strategy,
             "loss_strategy": params.loss_strategy,
         },
-        "chunk_size": params.chunk_size,
     }

     env = {

@@ -284,10 +289,18 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                 label="Chunk Size",
                 minimum=256,
                 maximum=4096,
-                value=
+                value=1024,
                 step=256,
                 interactive=True,
             )
+            chunk_overlap = gr.Slider(
+                label="Chunk Overlap",
+                minimum=0,
+                maximum=500,
+                value=100,
+                step=100,
+                interactive=True,
+            )
             tokenizer = gr.Textbox(
                 label="Tokenizer", value="cl100k_base", interactive=True
             )

@@ -499,7 +512,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:

     submit_btn.click(
         lambda *args: run_graphgen(
-
+            WebuiParams(
                 if_trainee_model=args[0],
                 input_file=args[1],
                 tokenizer=args[2],

@@ -518,12 +531,13 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                 trainee_model=args[15],
                 api_key=args[16],
                 chunk_size=args[17],
-
-
-
-
-
-
+                chunk_overlap=args[18],
+                rpm=args[19],
+                tpm=args[20],
+                quiz_samples=args[21],
+                trainee_url=args[22],
+                trainee_api_key=args[23],
+                token_counter=args[24],
             )
         ),
         inputs=[

@@ -545,6 +559,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
             trainee_model,
             api_key,
             chunk_size,
+            chunk_overlap,
             rpm,
             tpm,
             quiz_samples,
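The webui now nests the reader and splitter settings instead of passing a flat `chunk_size`. A minimal sketch of the config dictionary `run_graphgen` builds; the values shown are illustrative defaults taken from the sliders and config files in this commit, not from a real run:

```python
# Hypothetical values; the real ones come from the WebuiParams instance above.
config = {
    "if_trainee_model": False,
    "read": {
        "input_file": "resources/input_examples/jsonl_demo.jsonl",
    },
    "split": {
        "chunk_size": 1024,    # slider range 256-4096, step 256
        "chunk_overlap": 100,  # new slider, range 0-500, step 100
    },
    "output_data_type": "aggregated",
    "output_data_format": "ChatML",
    "tokenizer": "cl100k_base",
}
```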
graphgen/bases/base_splitter.py
ADDED
@@ -0,0 +1,135 @@
import copy
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Literal, Optional, Union

from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger


@dataclass
class BaseSplitter(ABC):
    """
    Abstract base class for splitting text into smaller chunks.
    """

    chunk_size: int = 1024
    chunk_overlap: int = 100
    length_function: Callable[[str], int] = len
    keep_separator: bool = False
    add_start_index: bool = False
    strip_whitespace: bool = True

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """
        Split the input text into smaller chunks.

        :param text: The input text to be split.
        :return: A list of text chunks.
        """

    def create_chunks(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Chunk]:
        """Create chunks from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self.add_start_index:
                    offset = index + previous_chunk_len - self.chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_chunk = Chunk(content=chunk, metadata=metadata)
                chunks.append(new_chunk)
        return chunks

    def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
        text = separator.join(chunks)
        if self.strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size chunks to send to the LLM.
        separator_len = self.length_function(separator)

        chunks = []
        current_chunk: List[str] = []
        total = 0
        for d in splits:
            _len = self.length_function(d)
            if (
                total + _len + (separator_len if len(current_chunk) > 0 else 0)
                > self.chunk_size
            ):
                if total > self.chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s",
                        total,
                        self.chunk_size,
                    )
                if len(current_chunk) > 0:
                    chunk = self._join_chunks(current_chunk, separator)
                    if chunk is not None:
                        chunks.append(chunk)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self.chunk_overlap or (
                        total + _len + (separator_len if len(current_chunk) > 0 else 0)
                        > self.chunk_size
                        and total > 0
                    ):
                        total -= self.length_function(current_chunk[0]) + (
                            separator_len if len(current_chunk) > 1 else 0
                        )
                        current_chunk = current_chunk[1:]
            current_chunk.append(d)
            total += _len + (separator_len if len(current_chunk) > 1 else 0)
        chunk = self._join_chunks(current_chunk, separator)
        if chunk is not None:
            chunks.append(chunk)
        return chunks

    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
    ) -> List[str]:
        # Now that we have the separator, split the text
        if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result.
                _splits = re.split(f"({separator})", text)
                splits = (
                    (
                        [
                            _splits[i] + _splits[i + 1]
                            for i in range(0, len(_splits) - 1, 2)
                        ]
                    )
                    if keep_separator == "end"
                    else (
                        [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
                    )
                )
                if len(_splits) % 2 == 0:
                    splits += _splits[-1:]
                splits = (
                    (splits + [_splits[-1]])
                    if keep_separator == "end"
                    else ([_splits[0]] + splits)
                )
            else:
                splits = re.split(separator, text)
        else:
            splits = list(text)
        return [s for s in splits if s != ""]
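A minimal concrete subclass, to illustrate how `split_text` is expected to feed `_merge_splits`; the class name and sample text are made up for illustration and are not part of the commit:

```python
import re
from typing import List

from graphgen.bases.base_splitter import BaseSplitter


class ParagraphSplitter(BaseSplitter):
    """Toy splitter: break on blank lines, then re-pack up to chunk_size."""

    def split_text(self, text: str) -> List[str]:
        # Naive first pass: one split per paragraph.
        paragraphs = [p for p in re.split(r"\n{2,}", text) if p.strip()]
        # _merge_splits packs the pieces into chunks of at most chunk_size
        # (measured by length_function, default len), warning when a single
        # piece already exceeds chunk_size and keeping roughly chunk_overlap
        # of trailing text between consecutive chunks.
        return self._merge_splits(paragraphs, separator="\n\n")


splitter = ParagraphSplitter(chunk_size=200, chunk_overlap=20)
chunks = splitter.split_text("First paragraph.\n\nSecond paragraph.\n\nThird one.")
```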
graphgen/bases/datatypes.py
ADDED
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field


@dataclass
class Chunk:
    id: str
    content: str
    metadata: dict = field(default_factory=dict)


@dataclass
class QAPair:
    """
    A pair of question and answer.
    """

    question: str
    answer: str
graphgen/{models/text → configs}/__init__.py
RENAMED
File without changes
graphgen/configs/aggregated_config.yaml
CHANGED
@@ -1,4 +1,8 @@
-
+read:
+  input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
+split:
+  chunk_size: 1024 # chunk size for text splitting
+  chunk_overlap: 100 # chunk overlap for text splitting
 output_data_type: aggregated # atomic, aggregated, multi_hop, cot
 output_data_format: ChatML # Alpaca, Sharegpt, ChatML
 tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
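The other three config files below gain the same `read` and `split` sections, and downstream code now reads the nested keys rather than a flat `chunk_size`. A small sketch of how such a config is consumed, assuming PyYAML is used to load it:

```python
import yaml

with open("graphgen/configs/aggregated_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

input_file = config["read"]["input_file"]         # resources/input_examples/jsonl_demo.jsonl
chunk_size = config["split"]["chunk_size"]        # 1024
chunk_overlap = config["split"]["chunk_overlap"]  # 100
```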
graphgen/configs/atomic_config.yaml
CHANGED
@@ -1,4 +1,8 @@
-
+read:
+  input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
+split:
+  chunk_size: 1024 # chunk size for text splitting
+  chunk_overlap: 100 # chunk overlap for text splitting
 output_data_type: atomic # atomic, aggregated, multi_hop, cot
 output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
 tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
graphgen/configs/cot_config.yaml
CHANGED
@@ -1,4 +1,8 @@
-
+read:
+  input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
+split:
+  chunk_size: 1024 # chunk size for text splitting
+  chunk_overlap: 100 # chunk overlap for text splitting
 output_data_type: cot # atomic, aggregated, multi_hop, cot
 output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
 tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
graphgen/configs/multi_hop_config.yaml
CHANGED
@@ -1,4 +1,8 @@
-
+read:
+  input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
+split:
+  chunk_size: 1024 # chunk size for text splitting
+  chunk_overlap: 100 # chunk overlap for text splitting
 output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
 output_data_format: ChatML # Alpaca, Sharegpt, ChatML
 tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
graphgen/evaluate.py
CHANGED
@@ -1,11 +1,15 @@
 """Evaluate the quality of the generated text using various metrics"""

-import os
-import json
 import argparse
+import json
+import os
+
 import pandas as pd
 from dotenv import load_dotenv
-
+
+from graphgen.bases.datatypes import QAPair
+
+from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
 from .utils import logger, set_logger

 sys_path = os.path.abspath(os.path.dirname(__file__))

@@ -13,15 +17,15 @@ set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))

 load_dotenv()

+
 def evaluate_length(corpus, tokenizer_name):
-    length_evaluator = LengthEvaluator(
-        tokenizer_name=tokenizer_name
-    )
+    length_evaluator = LengthEvaluator(tokenizer_name=tokenizer_name)
     logger.info("Length evaluator loaded")
     scores = length_evaluator.get_average_score(corpus)
     logger.info("Length scores: %s", scores)
     return scores

+
 def evaluate_mtld(corpus):
     mtld_evaluator = MTLDEvaluator()
     logger.info("MTLD evaluator loaded")

@@ -31,30 +35,30 @@ def evaluate_mtld(corpus):
     logger.info("MTLD min max scores: %s", min_max_scores)
     return scores, min_max_scores

+
 def evaluate_reward(corpus, reward_model_names):
     scores = []
     for reward_name in reward_model_names:
-        reward_evaluator = RewardEvaluator(
-            reward_name=reward_name
-        )
+        reward_evaluator = RewardEvaluator(reward_name=reward_name)
         logger.info("Loaded reward model: %s", reward_name)
         average_score = reward_evaluator.get_average_score(corpus)
         logger.info("%s scores: %s", reward_name, average_score)
         min_max_scores = reward_evaluator.get_min_max_score(corpus)
         logger.info("%s min max scores: %s", reward_name, min_max_scores)
-        scores.append(
-
-
-
-
+        scores.append(
+            {
+                "reward_name": reward_name.split("/")[-1],
+                "score": average_score,
+                "min_max_scores": min_max_scores,
+            }
+        )
         del reward_evaluator
         clean_gpu_cache()
     return scores

+
 def evaluate_uni(corpus, uni_model_name):
-    uni_evaluator = UniEvaluator(
-        model_name=uni_model_name
-    )
+    uni_evaluator = UniEvaluator(model_name=uni_model_name)
     logger.info("Uni evaluator loaded with model %s", uni_model_name)
     uni_scores = uni_evaluator.get_average_score(corpus)
     for key, value in uni_scores.items():

@@ -64,27 +68,47 @@ def evaluate_uni(corpus, uni_model_name):
     logger.info("Uni %s min max scores: %s", key, value)
     del uni_evaluator
     clean_gpu_cache()
-    return (
-
+    return (
+        uni_scores["naturalness"],
+        uni_scores["coherence"],
+        uni_scores["understandability"],
+        min_max_scores["naturalness"],
+        min_max_scores["coherence"],
+        min_max_scores["understandability"],
+    )


 def clean_gpu_cache():
     import torch
+
     if torch.cuda.is_available():
         torch.cuda.empty_cache()


-if __name__ ==
+if __name__ == "__main__":
     import torch.multiprocessing as mp
+
     parser = argparse.ArgumentParser()

-    parser.add_argument(
-
+    parser.add_argument(
+        "--folder", type=str, default="cache/data", help="folder to load data"
+    )
+    parser.add_argument(
+        "--output", type=str, default="cache/output", help="path to save output"
+    )

-    parser.add_argument(
-
-
-    parser.add_argument(
+    parser.add_argument(
+        "--tokenizer", type=str, default="cl100k_base", help="tokenizer name"
+    )
+    parser.add_argument(
+        "--reward",
+        type=str,
+        default="OpenAssistant/reward-model-deberta-v3-large-v2",
+        help="Comma-separated list of reward models",
+    )
+    parser.add_argument(
+        "--uni", type=str, default="MingZhong/unieval-sum", help="uni model name"
+    )

     args = parser.parse_args()

@@ -94,49 +118,55 @@ if __name__ == '__main__':
     if not os.path.exists(args.output):
         os.makedirs(args.output)

-    reward_models = args.reward.split(
-
+    reward_models = args.reward.split(",")

     results = []

     logger.info("Data loaded from %s", args.folder)
-    mp.set_start_method(
+    mp.set_start_method("spawn")

     for file in os.listdir(args.folder):
-        if file.endswith(
+        if file.endswith(".json"):
             logger.info("Processing %s", file)
-            with open(os.path.join(args.folder, file),
+            with open(os.path.join(args.folder, file), "r", encoding="utf-8") as f:
                 data = json.load(f)
-            data = [
-                question=data[key][
-
-
+            data = [
+                QAPair(question=data[key]["question"], answer=data[key]["answer"])
+                for key in data
+            ]

             length_scores = evaluate_length(data, args.tokenizer)
             mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
             reward_scores = evaluate_reward(data, reward_models)
-
-
-
+            (
+                uni_naturalness_scores,
+                uni_coherence_scores,
+                uni_understandability_scores,
+                min_max_uni_naturalness_scores,
+                min_max_uni_coherence_scores,
+                min_max_uni_understandability_scores,
+            ) = evaluate_uni(data, args.uni)

             result = {
-
-
-
-
-
-
-
-
-
-
-
+                "file": file,
+                "number": len(data),
+                "length": length_scores,
+                "mtld": mtld_scores,
+                "mtld_min_max": min_max_mtld_scores,
+                "uni_naturalness": uni_naturalness_scores,
+                "uni_coherence": uni_coherence_scores,
+                "uni_understandability": uni_understandability_scores,
+                "uni_naturalness_min_max": min_max_uni_naturalness_scores,
+                "uni_coherence_min_max": min_max_uni_coherence_scores,
+                "uni_understandability_min_max": min_max_uni_understandability_scores,
             }
             for reward_score in reward_scores:
-                result[reward_score[
-                result[f"{reward_score['reward_name']}_min_max"] = reward_score[
+                result[reward_score["reward_name"]] = reward_score["score"]
+                result[f"{reward_score['reward_name']}_min_max"] = reward_score[
+                    "min_max_scores"
+                ]

             results.append(result)

     results = pd.DataFrame(results)
-    results.to_csv(os.path.join(args.output,
+    results.to_csv(os.path.join(args.output, "evaluation.csv"), index=False)
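The evaluation entry point now wraps each record in a `QAPair` before handing it to the evaluators. A condensed sketch of that flow (the file path and data are placeholders, and the MTLD evaluator assumes the NLTK stopword lists it uses are available):

```python
import json

from graphgen.bases.datatypes import QAPair
from graphgen.models import LengthEvaluator, MTLDEvaluator

with open("cache/data/example.json", encoding="utf-8") as f:
    raw = json.load(f)

# Same construction as in evaluate.py: one QAPair per top-level key.
corpus = [QAPair(question=raw[k]["question"], answer=raw[k]["answer"]) for k in raw]

length_evaluator = LengthEvaluator(tokenizer_name="cl100k_base")
print("avg answer length:", length_evaluator.get_average_score(corpus))

mtld_evaluator = MTLDEvaluator()
print("avg MTLD:", mtld_evaluator.get_average_score(corpus))
```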
graphgen/graphgen.py
CHANGED
@@ -8,8 +8,8 @@ import gradio as gr
 from tqdm.asyncio import tqdm as tqdm_async

 from graphgen.bases.base_storage import StorageNameSpace
+from graphgen.bases.datatypes import Chunk
 from graphgen.models import (
-    Chunk,
     JsonKVStorage,
     JsonListStorage,
     NetworkXStorage,

@@ -17,6 +17,7 @@ from graphgen.models import (
     Tokenizer,
     TraverseStrategy,
     read_file,
+    split_chunks,
 )

 from .operators import (

@@ -32,6 +33,7 @@ from .operators import (
 from .utils import (
     compute_content_hash,
     create_event_loop,
+    detect_main_language,
     format_generation_results,
     logger,
 )

@@ -50,11 +52,6 @@ class GraphGen:
     synthesizer_llm_client: OpenAIModel = None
     trainee_llm_client: OpenAIModel = None

-    # text chunking
-    # TODO: make it configurable
-    chunk_size: int = 1024
-    chunk_overlap_size: int = 100
-
     # search
     search_config: dict = field(
         default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}

@@ -136,14 +133,22 @@ class GraphGen:
         async for doc_key, doc in tqdm_async(
             new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
         ):
+            doc_language = detect_main_language(doc["content"])
+            text_chunks = split_chunks(
+                doc["content"],
+                language=doc_language,
+                chunk_size=self.config["split"]["chunk_size"],
+                chunk_overlap=self.config["split"]["chunk_overlap"],
+            )
+
             chunks = {
-                compute_content_hash(
-
+                compute_content_hash(txt, prefix="chunk-"): {
+                    "content": txt,
                     "full_doc_id": doc_key,
+                    "length": len(self.tokenizer_instance.encode_string(txt)),
+                    "language": doc_language,
                 }
-                for
-                    doc["content"], self.chunk_overlap_size, self.chunk_size
-                )
+                for txt in text_chunks
             }
             inserting_chunks.update(chunks)

@@ -171,7 +176,7 @@ class GraphGen:
         insert chunks into the graph
         """

-        input_file = self.config["input_file"]
+        input_file = self.config["read"]["input_file"]
         data = read_file(input_file)
         inserting_chunks = await self.async_split_chunks(data)
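Chunking now goes through the language-aware `split_chunks` helper and the nested `split` config. Outside the async pipeline, the equivalent call looks roughly like this (the sample text and sizes are illustrative; the real code also records token length and language per chunk, as shown above):

```python
from graphgen.models import split_chunks
from graphgen.utils import compute_content_hash, detect_main_language

doc = {"content": "GraphGen builds a knowledge graph from source text and samples QA pairs from it. " * 40}

language = detect_main_language(doc["content"])  # "en" or "zh"
text_chunks = split_chunks(
    doc["content"],
    language=language,
    chunk_size=1024,
    chunk_overlap=100,
)
chunks = {
    compute_content_hash(txt, prefix="chunk-"): {"content": txt, "full_doc_id": "doc-0"}
    for txt in text_chunks
}
```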
graphgen/models/__init__.py
CHANGED
@@ -11,36 +11,7 @@ from .search.db.uniprot_search import UniProtSearch
 from .search.kg.wiki_search import WikiSearch
 from .search.web.bing_search import BingSearch
 from .search.web.google_search import GoogleSearch
+from .splitter import split_chunks
 from .storage.json_storage import JsonKVStorage, JsonListStorage
 from .storage.networkx_storage import NetworkXStorage
 from .strategy.travserse_strategy import TraverseStrategy
-from .text.chunk import Chunk
-from .text.text_pair import TextPair
-
-__all__ = [
-    # llm models
-    "OpenAIModel",
-    "TopkTokenModel",
-    "Token",
-    "Tokenizer",
-    # storage models
-    "Chunk",
-    "NetworkXStorage",
-    "JsonKVStorage",
-    "JsonListStorage",
-    # search models
-    "WikiSearch",
-    "GoogleSearch",
-    "BingSearch",
-    "UniProtSearch",
-    # evaluate models
-    "TextPair",
-    "LengthEvaluator",
-    "MTLDEvaluator",
-    "RewardEvaluator",
-    "UniEvaluator",
-    # strategy models
-    "TraverseStrategy",
-    # community models
-    "CommunityDetector",
-]
graphgen/models/evaluate/base_evaluator.py
CHANGED
@@ -1,22 +1,24 @@
 import asyncio
-
 from dataclasses import dataclass
+
 from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.bases.datatypes import QAPair
 from graphgen.utils import create_event_loop
-
+

 @dataclass
 class BaseEvaluator:
     max_concurrent: int = 100
     results: list[float] = None

-    def evaluate(self, pairs: list[
+    def evaluate(self, pairs: list[QAPair]) -> list[float]:
        """
        Evaluate the text and return a score.
        """
        return create_event_loop().run_until_complete(self.async_evaluate(pairs))

-    async def async_evaluate(self, pairs: list[
+    async def async_evaluate(self, pairs: list[QAPair]) -> list[float]:
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def evaluate_with_semaphore(pair):

@@ -31,10 +33,10 @@ class BaseEvaluator:
            results.append(await result)
        return results

-    async def evaluate_single(self, pair:
+    async def evaluate_single(self, pair: QAPair) -> float:
        raise NotImplementedError()

-    def get_average_score(self, pairs: list[
+    def get_average_score(self, pairs: list[QAPair]) -> float:
        """
        Get the average score of a batch of texts.
        """

@@ -42,7 +44,7 @@ class BaseEvaluator:
        self.results = results
        return sum(self.results) / len(pairs)

-    def get_min_max_score(self, pairs: list[
+    def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
        """
        Get the min and max score of a batch of texts.
        """
graphgen/models/evaluate/length_evaluator.py
CHANGED
@@ -1,19 +1,19 @@
 from dataclasses import dataclass
+
+from graphgen.bases.datatypes import QAPair
 from graphgen.models.evaluate.base_evaluator import BaseEvaluator
 from graphgen.models.llm.tokenizer import Tokenizer
-from graphgen.models.text.text_pair import TextPair
 from graphgen.utils import create_event_loop


 @dataclass
 class LengthEvaluator(BaseEvaluator):
     tokenizer_name: str = "cl100k_base"
+
     def __post_init__(self):
-        self.tokenizer = Tokenizer(
-            model_name=self.tokenizer_name
-        )
+        self.tokenizer = Tokenizer(model_name=self.tokenizer_name)

-    async def evaluate_single(self, pair:
+    async def evaluate_single(self, pair: QAPair) -> float:
         loop = create_event_loop()
         return await loop.run_in_executor(None, self._calculate_length, pair.answer)
graphgen/models/evaluate/mtld_evaluator.py
CHANGED
@@ -1,22 +1,27 @@
-from dataclasses import
+from dataclasses import dataclass, field
 from typing import Set

+from graphgen.bases.datatypes import QAPair
 from graphgen.models.evaluate.base_evaluator import BaseEvaluator
-from graphgen.
-from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop
-
+from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language

 nltk_helper = NLTKHelper()

+
 @dataclass
 class MTLDEvaluator(BaseEvaluator):
     """
     衡量文本词汇多样性的指标
     """
-    stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english")))
-    stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese")))

-
+    stopwords_en: Set[str] = field(
+        default_factory=lambda: set(nltk_helper.get_stopwords("english"))
+    )
+    stopwords_zh: Set[str] = field(
+        default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))
+    )
+
+    async def evaluate_single(self, pair: QAPair) -> float:
         loop = create_event_loop()
         return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)

@@ -71,6 +76,6 @@ class MTLDEvaluator(BaseEvaluator):
             if ttr <= threshold:
                 factors += 1
             else:
-                factors +=
+                factors += 1 - (ttr - threshold) / (1 - threshold)

         return len(tokens) / factors if factors > 0 else len(tokens)
graphgen/models/evaluate/reward_evaluator.py
CHANGED
@@ -1,6 +1,8 @@
 from dataclasses import dataclass
+
 from tqdm import tqdm
-
+
+from graphgen.bases.datatypes import QAPair


 @dataclass

@@ -9,19 +11,22 @@ class RewardEvaluator:
     Reward Model Evaluator.
     OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好
     """
+
     reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
     max_length: int = 2560
     results: list[float] = None

     def __post_init__(self):
         import torch
+
         self.num_gpus = torch.cuda.device_count()

     @staticmethod
     def process_chunk(rank, pairs, reward_name, max_length, return_dict):
         import torch
         from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
+
+        device = f"cuda:{rank}"
         torch.cuda.set_device(rank)

         rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)

@@ -37,7 +42,7 @@ class RewardEvaluator:
                 pair.answer,
                 return_tensors="pt",
                 max_length=max_length,
-                truncation=True
+                truncation=True,
             )
             inputs = {k: v.to(device) for k, v in inputs.items()}
             score = rank_model(**inputs).logits[0].item()

@@ -45,8 +50,9 @@ class RewardEvaluator:

         return_dict[rank] = results

-    def evaluate(self, pairs: list[
+    def evaluate(self, pairs: list[QAPair]) -> list[float]:
         import torch.multiprocessing as mp
+
         chunk_size = len(pairs) // self.num_gpus
         chunks = []
         for i in range(self.num_gpus):

@@ -64,7 +70,7 @@ class RewardEvaluator:
         for rank, chunk in enumerate(chunks):
             p = mp.Process(
                 target=self.process_chunk,
-                args=(rank, chunk, self.reward_name, self.max_length, return_dict)
+                args=(rank, chunk, self.reward_name, self.max_length, return_dict),
             )
             p.start()
             processes.append(p)

@@ -84,7 +90,7 @@ class RewardEvaluator:

         return results

-    def get_average_score(self, pairs: list[
+    def get_average_score(self, pairs: list[QAPair]) -> float:
         """
         Get the average score of a batch of texts.
         """

@@ -92,7 +98,7 @@ class RewardEvaluator:
         self.results = results
         return sum(self.results) / len(pairs)

-    def get_min_max_score(self, pairs: list[
+    def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
         """
         Get the min and max score of a batch of texts.
         """
graphgen/models/evaluate/uni_evaluator.py
CHANGED
@@ -1,40 +1,58 @@
 # https://github.com/maszhongming/UniEval/tree/main

 from dataclasses import dataclass, field
+
 from tqdm import tqdm
-
+
+from graphgen.bases.datatypes import QAPair


 def _add_questions(dimension: str, question: str, answer: str):
     if dimension == "naturalness":
-        cur_input =
+        cur_input = (
+            "question: Is this a natural response in the dialogue? </s> response: "
+            + answer
+        )
     elif dimension == "coherence":
-        cur_input =
-
+        cur_input = (
+            "question: Is this a coherent response given the dialogue history? </s> response: "
+            + answer
+            + " </s> dialogue history: "
+            + question
+        )
     elif dimension == "understandability":
-        cur_input =
+        cur_input = (
+            "question: Is this an understandable response in the dialogue? </s> response: "
+            + answer
+        )
     else:
         raise NotImplementedError(
-
+            "The input format for this dimension is still undefined. Please customize it first."
+        )
     return cur_input

+
 @dataclass
 class UniEvaluator:
     model_name: str = "MingZhong/unieval-sum"
-    dimensions: list = field(
+    dimensions: list = field(
+        default_factory=lambda: ["naturalness", "coherence", "understandability"]
+    )
     max_length: int = 2560
     results: dict = None

     def __post_init__(self):
         import torch
+
         self.num_gpus = torch.cuda.device_count()
         self.results = {}

     @staticmethod
     def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
         import torch
-        from transformers import
-
+        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+        device = f"cuda:{rank}"
         torch.cuda.set_device(rank)

         rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

@@ -59,26 +77,26 @@ class UniEvaluator:
             max_length=max_length,
             truncation=True,
             padding=True,
-            return_tensors=
+            return_tensors="pt",
         )
         encoded_tgt = tokenizer(
             tgt,
             max_length=max_length,
             truncation=True,
             padding=True,
-            return_tensors=
+            return_tensors="pt",
         )

-        src_tokens = encoded_src[
-        src_mask = encoded_src[
+        src_tokens = encoded_src["input_ids"].to(device)
+        src_mask = encoded_src["attention_mask"].to(device)

-        tgt_tokens = encoded_tgt[
+        tgt_tokens = encoded_tgt["input_ids"].to(device)[:, 0].unsqueeze(-1)

         output = rank_model(
             input_ids=src_tokens,
             attention_mask=src_mask,
             labels=tgt_tokens,
-            use_cache
+            use_cache=False,
         )

         logits = output.logits.view(-1, rank_model.config.vocab_size)

@@ -91,8 +109,9 @@ class UniEvaluator:

         return_dict[rank] = results

-    def evaluate(self, pairs: list[
+    def evaluate(self, pairs: list[QAPair]) -> list[dict]:
         import torch.multiprocessing as mp
+
         final_results = []
         for dimension in self.dimensions:
             chunk_size = len(pairs) // self.num_gpus

@@ -112,7 +131,14 @@ class UniEvaluator:
             for rank, chunk in enumerate(chunks):
                 p = mp.Process(
                     target=self.process_chunk,
-                    args=(
+                    args=(
+                        rank,
+                        chunk,
+                        self.model_name,
+                        self.max_length,
+                        dimension,
+                        return_dict,
+                    ),
                 )
                 p.start()
                 processes.append(p)

@@ -130,12 +156,10 @@ class UniEvaluator:
                 p.terminate()
                 p.join()

-            final_results.append({
-                dimension: results
-            })
+            final_results.append({dimension: results})
         return final_results

-    def get_average_score(self, pairs: list[
+    def get_average_score(self, pairs: list[QAPair]) -> dict:
         """
         Get the average score of a batch of texts.
         """

@@ -147,7 +171,7 @@ class UniEvaluator:
             self.results[key] = value
         return final_results

-    def get_min_max_score(self, pairs: list[
+    def get_min_max_score(self, pairs: list[QAPair]) -> dict:
         """
         Get the min and max score of a batch of texts.
         """
graphgen/models/splitter/__init__.py
ADDED
@@ -0,0 +1,31 @@
from functools import lru_cache
from typing import Union

from .recursive_character_splitter import (
    ChineseRecursiveTextSplitter,
    RecursiveCharacterSplitter,
)

_MAPPING = {
    "en": RecursiveCharacterSplitter,
    "zh": ChineseRecursiveTextSplitter,
}

SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]


@lru_cache(maxsize=None)
def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
    cls = _MAPPING[language]
    kwargs = dict(frozen_kwargs)
    return cls(**kwargs)


def split_chunks(text: str, language: str = "en", **kwargs) -> list:
    if language not in _MAPPING:
        raise ValueError(
            f"Unsupported language: {language}. "
            f"Supported languages are: {list(_MAPPING.keys())}"
        )
    splitter = _get_splitter(language, frozenset(kwargs.items()))
    return splitter.split_text(text)
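Typical use of the new entry point exported via `graphgen.models` (the sample strings are made up). Note that keyword arguments are frozen into a `frozenset` for the `lru_cache` key, so only hashable values can be passed through `split_chunks`:

```python
from graphgen.models import split_chunks

english_chunks = split_chunks(
    "GraphGen synthesizes QA data from a knowledge graph. " * 50,
    language="en",
    chunk_size=256,
    chunk_overlap=32,
)

# "zh" routes to ChineseRecursiveTextSplitter, which splits on Chinese punctuation.
chinese_chunks = split_chunks("GraphGen 是一个数据合成框架。" * 100, language="zh")
```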
graphgen/models/splitter/character_splitter.py
ADDED
@@ -0,0 +1,26 @@
import re
from typing import Any, List

from graphgen.bases.base_splitter import BaseSplitter


class CharacterSplitter(BaseSplitter):
    """Splitting text that looks at characters."""

    def __init__(
        self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._is_separator_regex = is_separator_regex

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        separator = (
            self._separator if self._is_separator_regex else re.escape(self._separator)
        )
        splits = self._split_text_with_regex(text, separator, self.keep_separator)
        _separator = "" if self.keep_separator else self._separator
        return self._merge_splits(splits, _separator)
graphgen/models/splitter/markdown_splitter.py
ADDED
@@ -0,0 +1,33 @@
from typing import Any

from graphgen.models.splitter.recursive_character_splitter import (
    RecursiveCharacterSplitter,
)


class MarkdownTextRefSplitter(RecursiveCharacterSplitter):
    """Attempts to split the text along Markdown-formatted headings."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a MarkdownTextRefSplitter."""
        separators = [
            # First, try to split along Markdown headings (starting with level 2)
            "\n#{1,6} ",
            # Note the alternative syntax for headings (below) is not handled here
            # Heading level 2
            # ---------------
            # End of code block
            "```\n",
            # Horizontal lines
            "\n\\*\\*\\*+\n",
            "\n---+\n",
            "\n___+\n",
            # Note: horizontal lines defined by three or more of ***, ---, or ___
            # are handled by the regexes above, but alternative syntaxes (e.g., with spaces)
            # are not handled.
            "\n\n",
            "\n",
            " ",
            "",
        ]
        super().__init__(separators=separators, **kwargs)
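A usage sketch (the markdown string and chunk sizes are made up; like the other splitters, keyword arguments are forwarded down to the `BaseSplitter` dataclass fields):

```python
from graphgen.models.splitter.markdown_splitter import MarkdownTextRefSplitter

md = "## Install\n\npip install graphgen\n\n## Usage\n\nRun the webui and pick a config.\n"

splitter = MarkdownTextRefSplitter(chunk_size=64, chunk_overlap=0)
for chunk in splitter.split_text(md):
    print(repr(chunk))
```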
graphgen/models/splitter/recursive_character_splitter.py
ADDED
|
@@ -0,0 +1,149 @@
+import re
+from typing import Any, List, Optional
+
+from graphgen.bases.base_splitter import BaseSplitter
+
+
+class RecursiveCharacterSplitter(BaseSplitter):
+    """Splitting text by recursively look at characters.
+
+    Recursively tries to split by different characters to find one that works.
+    """
+
+    def __init__(
+        self,
+        separators: Optional[List[str]] = None,
+        keep_separator: bool = True,
+        is_separator_regex: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Create a new TextSplitter."""
+        super().__init__(keep_separator=keep_separator, **kwargs)
+        self._separators = separators or ["\n\n", "\n", " ", ""]
+        self._is_separator_regex = is_separator_regex
+
+    def _split_text(self, text: str, separators: List[str]) -> List[str]:
+        """Split incoming text and return chunks."""
+        final_chunks = []
+        # Get appropriate separator to use
+        separator = separators[-1]
+        new_separators = []
+        for i, _s in enumerate(separators):
+            _separator = _s if self._is_separator_regex else re.escape(_s)
+            if _s == "":
+                separator = _s
+                break
+            if re.search(_separator, text):
+                separator = _s
+                new_separators = separators[i + 1 :]
+                break
+
+        _separator = separator if self._is_separator_regex else re.escape(separator)
+        splits = self._split_text_with_regex(text, _separator, self.keep_separator)
+
+        # Now go merging things, recursively splitting longer texts.
+        _good_splits = []
+        _separator = "" if self.keep_separator else separator
+        for s in splits:
+            if self.length_function(s) < self.chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, _separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                if not new_separators:
+                    final_chunks.append(s)
+                else:
+                    other_info = self._split_text(s, new_separators)
+                    final_chunks.extend(other_info)
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, _separator)
+            final_chunks.extend(merged_text)
+        return final_chunks
+
+    def split_text(self, text: str) -> List[str]:
+        return self._split_text(text, self._separators)
+
+
+class ChineseRecursiveTextSplitter(RecursiveCharacterSplitter):
+    def __init__(
+        self,
+        separators: Optional[List[str]] = None,
+        keep_separator: bool = True,
+        is_separator_regex: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(keep_separator=keep_separator, **kwargs)
+        self._separators = separators or [
+            "\n\n",
+            "\n",
+            "。|!|?",
+            r"\.\s|\!\s|\?\s",
+            r";|;\s",
+            r",|,\s",
+        ]
+        self._is_separator_regex = is_separator_regex
+
+    def _split_text_with_regex_from_end(
+        self, text: str, separator: str, keep_separator: bool
+    ) -> List[str]:
+        # Now that we have the separator, split the text
+        if separator:
+            if keep_separator:
+                # The parentheses in the pattern keep the delimiters in the result.
+                _splits = re.split(f"({separator})", text)
+                splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
+                if len(_splits) % 2 == 1:
+                    splits += _splits[-1:]
+            else:
+                splits = re.split(separator, text)
+        else:
+            splits = list(text)
+        return [s for s in splits if s != ""]
+
+    def _split_text(self, text: str, separators: List[str]) -> List[str]:
+        """Split incoming text and return chunks."""
+        final_chunks = []
+        # Get appropriate separator to use
+        separator = separators[-1]
+        new_separators = []
+        for i, _s in enumerate(separators):
+            _separator = _s if self._is_separator_regex else re.escape(_s)
+            if _s == "":
+                separator = _s
+                break
+            if re.search(_separator, text):
+                separator = _s
+                new_separators = separators[i + 1 :]
+                break
+
+        _separator = separator if self._is_separator_regex else re.escape(separator)
+        splits = self._split_text_with_regex_from_end(
+            text, _separator, self.keep_separator
+        )
+
+        # Now go merging things, recursively splitting longer texts.
+        _good_splits = []
+        _separator = "" if self.keep_separator else separator
+        for s in splits:
+            if self.length_function(s) < self.chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, _separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                if not new_separators:
+                    final_chunks.append(s)
+                else:
+                    other_info = self._split_text(s, new_separators)
+                    final_chunks.extend(other_info)
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, _separator)
+            final_chunks.extend(merged_text)
+        return [
+            re.sub(r"\n{2,}", "\n", chunk.strip())
+            for chunk in final_chunks
+            if chunk.strip() != ""
+        ]
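Both classes pick the coarsest separator that actually occurs in the text, then recursively re-split any piece still longer than chunk_size using the remaining, finer separators. The Chinese variant adds sentence-level punctuation and splits from the end of each match so delimiters stay attached to the preceding sentence. A brief sketch (chunk_size and chunk_overlap are assumed BaseSplitter kwargs, not shown in this diff):

# Illustrative only; falls back from paragraphs to Chinese sentence punctuation.
from graphgen.models.splitter.recursive_character_splitter import (
    ChineseRecursiveTextSplitter,
)

text = "知识图谱需要高质量语料。先按段落切分,再按标点切分!最后按长度合并成块?"
splitter = ChineseRecursiveTextSplitter(chunk_size=20, chunk_overlap=0)
print(splitter.split_text(text))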
graphgen/models/text/chunk.py
DELETED
@@ -1,7 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass
-class Chunk:
-    id : str
-    content: str
graphgen/models/text/text_pair.py
DELETED
@@ -1,9 +0,0 @@
-from dataclasses import dataclass
-
-@dataclass
-class TextPair:
-    """
-    A pair of input data.
-    """
-    question: str
-    answer: str
graphgen/operators/kg/extract_kg.py
CHANGED
@@ -7,7 +7,8 @@ import gradio as gr
 from tqdm.asyncio import tqdm as tqdm_async
 
 from graphgen.bases.base_storage import BaseGraphStorage
-from graphgen.
+from graphgen.bases.datatypes import Chunk
+from graphgen.models import OpenAIModel, Tokenizer
 from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
 from graphgen.templates import KG_EXTRACTION_PROMPT
 from graphgen.utils import (
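Together with the deletion of graphgen/models/text/chunk.py above, the new import here indicates that Chunk now lives in the new graphgen/bases/datatypes.py module. A presumed minimal shape of that dataclass, reconstructed from the deleted file (the actual module may define more than this):

# Presumed contents of graphgen/bases/datatypes.py; fields copied from the
# deleted graphgen/models/text/chunk.py, anything beyond this is not shown here.
from dataclasses import dataclass


@dataclass
class Chunk:
    id: str
    content: str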
graphgen/operators/preprocess/resolute_coreference.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import List
 
-from graphgen.
+from graphgen.bases.datatypes import Chunk
+from graphgen.models import OpenAIModel
 from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
 from graphgen.utils import detect_main_language
 
webui/app.py
CHANGED
@@ -12,7 +12,7 @@ from graphgen.graphgen import GraphGen
 from graphgen.models import OpenAIModel, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
 from graphgen.utils import set_logger
-from webui.base import
+from webui.base import WebuiParams
 from webui.cache_utils import cleanup_workspace, setup_workspace
 from webui.count_tokens import count_tokens
 from webui.i18n import Translate
@@ -66,13 +66,19 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
 
 
 # pylint: disable=too-many-statements
-def run_graphgen(params, progress=gr.Progress()):
+def run_graphgen(params: WebuiParams, progress=gr.Progress()):
     def sum_tokens(client):
         return sum(u["total_tokens"] for u in client.token_usage)
 
     config = {
         "if_trainee_model": params.if_trainee_model,
-        "
+        "read": {
+            "input_file": params.input_file,
+        },
+        "split": {
+            "chunk_size": params.chunk_size,
+            "chunk_overlap": params.chunk_overlap,
+        },
         "output_data_type": params.output_data_type,
         "output_data_format": params.output_data_format,
         "tokenizer": params.tokenizer,
@@ -91,7 +97,6 @@ def run_graphgen(params, progress=gr.Progress()):
             "isolated_node_strategy": params.isolated_node_strategy,
             "loss_strategy": params.loss_strategy,
         },
-        "chunk_size": params.chunk_size,
     }
 
     env = {
@@ -284,10 +289,18 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                 label="Chunk Size",
                 minimum=256,
                 maximum=4096,
-                value=
+                value=1024,
                 step=256,
                 interactive=True,
             )
+            chunk_overlap = gr.Slider(
+                label="Chunk Overlap",
+                minimum=0,
+                maximum=500,
+                value=100,
+                step=100,
+                interactive=True,
+            )
             tokenizer = gr.Textbox(
                 label="Tokenizer", value="cl100k_base", interactive=True
             )
@@ -499,7 +512,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
 
     submit_btn.click(
         lambda *args: run_graphgen(
-
+            WebuiParams(
            if_trainee_model=args[0],
            input_file=args[1],
            tokenizer=args[2],
@@ -518,12 +531,13 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
            trainee_model=args[15],
            api_key=args[16],
            chunk_size=args[17],
-
-
-
-
-
-
+            chunk_overlap=args[18],
+            rpm=args[19],
+            tpm=args[20],
+            quiz_samples=args[21],
+            trainee_url=args[22],
+            trainee_api_key=args[23],
+            token_counter=args[24],
            )
        ),
        inputs=[
@@ -545,6 +559,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
            trainee_model,
            api_key,
            chunk_size,
+           chunk_overlap,
            rpm,
            tpm,
            quiz_samples,
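The run_graphgen changes above move the input file and the chunking options into nested read and split sections of the config dict. A sketch of the resulting shape, using the new slider defaults (1024 and 100) as illustrative values:

# Illustrative config shape after this change; "input.jsonl" is a placeholder
# and the remaining keys (output_data_type, tokenizer, ...) are unchanged.
config = {
    "if_trainee_model": False,
    "read": {
        "input_file": "input.jsonl",
    },
    "split": {
        "chunk_size": 1024,
        "chunk_overlap": 100,
    },
}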
webui/base.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Any
 
 
 @dataclass
-class GraphGenParams:
+class WebuiParams:
     """
     GraphGen parameters
     """
@@ -26,6 +26,7 @@ class GraphGenParams:
     trainee_model: str
     api_key: str
     chunk_size: int
+    chunk_overlap: int
     rpm: int
     tpm: int
     quiz_samples: int
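A quick way to confirm that the renamed dataclass now carries the new overlap field (names taken directly from this diff):

# Prints the chunk-related field names declared on WebuiParams.
from dataclasses import fields

from webui.base import WebuiParams

print([f.name for f in fields(WebuiParams) if f.name.startswith("chunk")])
# expected: ['chunk_size', 'chunk_overlap']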