github-actions[bot] committed on
Commit 43d27f2 · 1 Parent(s): d2a63cc

Auto-sync from demo at Wed Sep 24 09:52:41 UTC 2025

app.py CHANGED
@@ -12,7 +12,7 @@ from graphgen.graphgen import GraphGen
12
  from graphgen.models import OpenAIModel, Tokenizer
13
  from graphgen.models.llm.limitter import RPM, TPM
14
  from graphgen.utils import set_logger
15
- from webui.base import GraphGenParams
16
  from webui.cache_utils import cleanup_workspace, setup_workspace
17
  from webui.count_tokens import count_tokens
18
  from webui.i18n import Translate
@@ -66,13 +66,19 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
66
 
67
 
68
  # pylint: disable=too-many-statements
69
- def run_graphgen(params, progress=gr.Progress()):
70
  def sum_tokens(client):
71
  return sum(u["total_tokens"] for u in client.token_usage)
72
 
73
  config = {
74
  "if_trainee_model": params.if_trainee_model,
75
- "input_file": params.input_file,
76
  "output_data_type": params.output_data_type,
77
  "output_data_format": params.output_data_format,
78
  "tokenizer": params.tokenizer,
@@ -91,7 +97,6 @@ def run_graphgen(params, progress=gr.Progress()):
91
  "isolated_node_strategy": params.isolated_node_strategy,
92
  "loss_strategy": params.loss_strategy,
93
  },
94
- "chunk_size": params.chunk_size,
95
  }
96
 
97
  env = {
@@ -284,10 +289,18 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
284
  label="Chunk Size",
285
  minimum=256,
286
  maximum=4096,
287
- value=512,
288
  step=256,
289
  interactive=True,
290
  )
291
  tokenizer = gr.Textbox(
292
  label="Tokenizer", value="cl100k_base", interactive=True
293
  )
@@ -499,7 +512,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
499
 
500
  submit_btn.click(
501
  lambda *args: run_graphgen(
502
- GraphGenParams(
503
  if_trainee_model=args[0],
504
  input_file=args[1],
505
  tokenizer=args[2],
@@ -518,12 +531,13 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
518
  trainee_model=args[15],
519
  api_key=args[16],
520
  chunk_size=args[17],
521
- rpm=args[18],
522
- tpm=args[19],
523
- quiz_samples=args[20],
524
- trainee_url=args[21],
525
- trainee_api_key=args[22],
526
- token_counter=args[23],
 
527
  )
528
  ),
529
  inputs=[
@@ -545,6 +559,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
545
  trainee_model,
546
  api_key,
547
  chunk_size,
 
548
  rpm,
549
  tpm,
550
  quiz_samples,
 
12
  from graphgen.models import OpenAIModel, Tokenizer
13
  from graphgen.models.llm.limitter import RPM, TPM
14
  from graphgen.utils import set_logger
15
+ from webui.base import WebuiParams
16
  from webui.cache_utils import cleanup_workspace, setup_workspace
17
  from webui.count_tokens import count_tokens
18
  from webui.i18n import Translate
 
66
 
67
 
68
  # pylint: disable=too-many-statements
69
+ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
70
  def sum_tokens(client):
71
  return sum(u["total_tokens"] for u in client.token_usage)
72
 
73
  config = {
74
  "if_trainee_model": params.if_trainee_model,
75
+ "read": {
76
+ "input_file": params.input_file,
77
+ },
78
+ "split": {
79
+ "chunk_size": params.chunk_size,
80
+ "chunk_overlap": params.chunk_overlap,
81
+ },
82
  "output_data_type": params.output_data_type,
83
  "output_data_format": params.output_data_format,
84
  "tokenizer": params.tokenizer,
 
97
  "isolated_node_strategy": params.isolated_node_strategy,
98
  "loss_strategy": params.loss_strategy,
99
  },
 
100
  }
101
 
102
  env = {
 
289
  label="Chunk Size",
290
  minimum=256,
291
  maximum=4096,
292
+ value=1024,
293
  step=256,
294
  interactive=True,
295
  )
296
+ chunk_overlap = gr.Slider(
297
+ label="Chunk Overlap",
298
+ minimum=0,
299
+ maximum=500,
300
+ value=100,
301
+ step=100,
302
+ interactive=True,
303
+ )
304
  tokenizer = gr.Textbox(
305
  label="Tokenizer", value="cl100k_base", interactive=True
306
  )
 
512
 
513
  submit_btn.click(
514
  lambda *args: run_graphgen(
515
+ WebuiParams(
516
  if_trainee_model=args[0],
517
  input_file=args[1],
518
  tokenizer=args[2],
 
531
  trainee_model=args[15],
532
  api_key=args[16],
533
  chunk_size=args[17],
534
+ chunk_overlap=args[18],
535
+ rpm=args[19],
536
+ tpm=args[20],
537
+ quiz_samples=args[21],
538
+ trainee_url=args[22],
539
+ trainee_api_key=args[23],
540
+ token_counter=args[24],
541
  )
542
  ),
543
  inputs=[
 
559
  trainee_model,
560
  api_key,
561
  chunk_size,
562
+ chunk_overlap,
563
  rpm,
564
  tpm,
565
  quiz_samples,
graphgen/bases/base_splitter.py ADDED
@@ -0,0 +1,135 @@
1
+ import copy
2
+ import re
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import Callable, Iterable, List, Literal, Optional, Union
6
+
7
+ from graphgen.bases.datatypes import Chunk
8
+ from graphgen.utils import logger
9
+
10
+
11
+ @dataclass
12
+ class BaseSplitter(ABC):
13
+ """
14
+ Abstract base class for splitting text into smaller chunks.
15
+ """
16
+
17
+ chunk_size: int = 1024
18
+ chunk_overlap: int = 100
19
+ length_function: Callable[[str], int] = len
20
+ keep_separator: bool = False
21
+ add_start_index: bool = False
22
+ strip_whitespace: bool = True
23
+
24
+ @abstractmethod
25
+ def split_text(self, text: str) -> List[str]:
26
+ """
27
+ Split the input text into smaller chunks.
28
+
29
+ :param text: The input text to be split.
30
+ :return: A list of text chunks.
31
+ """
32
+
33
+ def create_chunks(
34
+ self, texts: List[str], metadatas: Optional[List[dict]] = None
35
+ ) -> List[Chunk]:
36
+ """Create chunks from a list of texts."""
37
+ _metadatas = metadatas or [{}] * len(texts)
38
+ chunks = []
39
+ for i, text in enumerate(texts):
40
+ index = 0
41
+ previous_chunk_len = 0
42
+ for chunk in self.split_text(text):
43
+ metadata = copy.deepcopy(_metadatas[i])
44
+ if self.add_start_index:
45
+ offset = index + previous_chunk_len - self.chunk_overlap
46
+ index = text.find(chunk, max(0, offset))
47
+ metadata["start_index"] = index
48
+ previous_chunk_len = len(chunk)
49
+ new_chunk = Chunk(content=chunk, metadata=metadata)
50
+ chunks.append(new_chunk)
51
+ return chunks
52
+
53
+ def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
54
+ text = separator.join(chunks)
55
+ if self.strip_whitespace:
56
+ text = text.strip()
57
+ if text == "":
58
+ return None
59
+ return text
60
+
61
+ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
62
+ # We now want to combine these smaller pieces into medium size chunks to send to the LLM.
63
+ separator_len = self.length_function(separator)
64
+
65
+ chunks = []
66
+ current_chunk: List[str] = []
67
+ total = 0
68
+ for d in splits:
69
+ _len = self.length_function(d)
70
+ if (
71
+ total + _len + (separator_len if len(current_chunk) > 0 else 0)
72
+ > self.chunk_size
73
+ ):
74
+ if total > self.chunk_size:
75
+ logger.warning(
76
+ "Created a chunk of size %s, which is longer than the specified %s",
77
+ total,
78
+ self.chunk_size,
79
+ )
80
+ if len(current_chunk) > 0:
81
+ chunk = self._join_chunks(current_chunk, separator)
82
+ if chunk is not None:
83
+ chunks.append(chunk)
84
+ # Keep on popping if:
85
+ # - we have a larger chunk than in the chunk overlap
86
+ # - or if we still have any chunks and the length is long
87
+ while total > self.chunk_overlap or (
88
+ total + _len + (separator_len if len(current_chunk) > 0 else 0)
89
+ > self.chunk_size
90
+ and total > 0
91
+ ):
92
+ total -= self.length_function(current_chunk[0]) + (
93
+ separator_len if len(current_chunk) > 1 else 0
94
+ )
95
+ current_chunk = current_chunk[1:]
96
+ current_chunk.append(d)
97
+ total += _len + (separator_len if len(current_chunk) > 1 else 0)
98
+ chunk = self._join_chunks(current_chunk, separator)
99
+ if chunk is not None:
100
+ chunks.append(chunk)
101
+ return chunks
102
+
103
+ @staticmethod
104
+ def _split_text_with_regex(
105
+ text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
106
+ ) -> List[str]:
107
+ # Now that we have the separator, split the text
108
+ if separator:
109
+ if keep_separator:
110
+ # The parentheses in the pattern keep the delimiters in the result.
111
+ _splits = re.split(f"({separator})", text)
112
+ splits = (
113
+ (
114
+ [
115
+ _splits[i] + _splits[i + 1]
116
+ for i in range(0, len(_splits) - 1, 2)
117
+ ]
118
+ )
119
+ if keep_separator == "end"
120
+ else (
121
+ [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
122
+ )
123
+ )
124
+ if len(_splits) % 2 == 0:
125
+ splits += _splits[-1:]
126
+ splits = (
127
+ (splits + [_splits[-1]])
128
+ if keep_separator == "end"
129
+ else ([_splits[0]] + splits)
130
+ )
131
+ else:
132
+ splits = re.split(separator, text)
133
+ else:
134
+ splits = list(text)
135
+ return [s for s in splits if s != ""]
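Not part of this commit, but for orientation: a minimal usage sketch that drives the new BaseSplitter through one of the concrete subclasses added below (CharacterSplitter); the sample text and sizes are illustrative.

```python
# Illustrative sketch only: exercising BaseSplitter via the CharacterSplitter
# added in graphgen/models/splitter/character_splitter.py by this commit.
from graphgen.models.splitter.character_splitter import CharacterSplitter

text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
splitter = CharacterSplitter(separator="\n\n", chunk_size=30, chunk_overlap=5)
for piece in splitter.split_text(text):
    # Pieces are merged back up toward chunk_size by _merge_splits.
    print(repr(piece))
```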
graphgen/bases/datatypes.py ADDED
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class Chunk:
6
+ id: str
7
+ content: str
8
+ metadata: dict = field(default_factory=dict)
9
+
10
+
11
+ @dataclass
12
+ class QAPair:
13
+ """
14
+ A pair of question and answer.
15
+ """
16
+
17
+ question: str
18
+ answer: str
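These two dataclasses replace the old graphgen.models.text types (deleted later in this commit); a trivial construction sketch with illustrative values:

```python
from graphgen.bases.datatypes import Chunk, QAPair

chunk = Chunk(id="chunk-abc123", content="GraphGen synthesizes QA data from knowledge graphs.")
pair = QAPair(question="What does GraphGen synthesize?", answer="QA data from knowledge graphs.")
print(chunk.metadata, pair.question)  # metadata defaults to an empty dict
```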
graphgen/{models/text → configs}/__init__.py RENAMED
File without changes
graphgen/configs/aggregated_config.yaml CHANGED
@@ -1,4 +1,8 @@
1
- input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
 
 
 
 
2
  output_data_type: aggregated # atomic, aggregated, multi_hop, cot
3
  output_data_format: ChatML # Alpaca, Sharegpt, ChatML
4
  tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
 
1
+ read:
2
+ input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ split:
4
+ chunk_size: 1024 # chunk size for text splitting
5
+ chunk_overlap: 100 # chunk overlap for text splitting
6
  output_data_type: aggregated # atomic, aggregated, multi_hop, cot
7
  output_data_format: ChatML # Alpaca, Sharegpt, ChatML
8
  tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
graphgen/configs/atomic_config.yaml CHANGED
@@ -1,4 +1,8 @@
1
- input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
 
 
 
 
2
  output_data_type: atomic # atomic, aggregated, multi_hop, cot
3
  output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
4
  tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
 
1
+ read:
2
+ input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
3
+ split:
4
+ chunk_size: 1024 # chunk size for text splitting
5
+ chunk_overlap: 100 # chunk overlap for text splitting
6
  output_data_type: atomic # atomic, aggregated, multi_hop, cot
7
  output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
8
  tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
graphgen/configs/cot_config.yaml CHANGED
@@ -1,4 +1,8 @@
1
- input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
 
 
 
 
2
  output_data_type: cot # atomic, aggregated, multi_hop, cot
3
  output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
4
  tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
 
1
+ read:
2
+ input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ split:
4
+ chunk_size: 1024 # chunk size for text splitting
5
+ chunk_overlap: 100 # chunk overlap for text splitting
6
  output_data_type: cot # atomic, aggregated, multi_hop, cot
7
  output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
8
  tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
graphgen/configs/multi_hop_config.yaml CHANGED
@@ -1,4 +1,8 @@
1
- input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
 
 
 
 
2
  output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
3
  output_data_format: ChatML # Alpaca, Sharegpt, ChatML
4
  tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
 
1
+ read:
2
+ input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ split:
4
+ chunk_size: 1024 # chunk size for text splitting
5
+ chunk_overlap: 100 # chunk overlap for text splitting
6
  output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
7
  output_data_format: ChatML # Alpaca, Sharegpt, ChatML
8
  tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
graphgen/evaluate.py CHANGED
@@ -1,11 +1,15 @@
1
  """Evaluate the quality of the generated text using various metrics"""
2
 
3
- import os
4
- import json
5
  import argparse
 
 
 
6
  import pandas as pd
7
  from dotenv import load_dotenv
8
- from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator
 
 
 
9
  from .utils import logger, set_logger
10
 
11
  sys_path = os.path.abspath(os.path.dirname(__file__))
@@ -13,15 +17,15 @@ set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
13
 
14
  load_dotenv()
15
 
 
16
  def evaluate_length(corpus, tokenizer_name):
17
- length_evaluator = LengthEvaluator(
18
- tokenizer_name=tokenizer_name
19
- )
20
  logger.info("Length evaluator loaded")
21
  scores = length_evaluator.get_average_score(corpus)
22
  logger.info("Length scores: %s", scores)
23
  return scores
24
 
 
25
  def evaluate_mtld(corpus):
26
  mtld_evaluator = MTLDEvaluator()
27
  logger.info("MTLD evaluator loaded")
@@ -31,30 +35,30 @@ def evaluate_mtld(corpus):
31
  logger.info("MTLD min max scores: %s", min_max_scores)
32
  return scores, min_max_scores
33
 
 
34
  def evaluate_reward(corpus, reward_model_names):
35
  scores = []
36
  for reward_name in reward_model_names:
37
- reward_evaluator = RewardEvaluator(
38
- reward_name=reward_name
39
- )
40
  logger.info("Loaded reward model: %s", reward_name)
41
  average_score = reward_evaluator.get_average_score(corpus)
42
  logger.info("%s scores: %s", reward_name, average_score)
43
  min_max_scores = reward_evaluator.get_min_max_score(corpus)
44
  logger.info("%s min max scores: %s", reward_name, min_max_scores)
45
- scores.append({
46
- 'reward_name': reward_name.split('/')[-1],
47
- 'score': average_score,
48
- 'min_max_scores': min_max_scores
49
- })
 
 
50
  del reward_evaluator
51
  clean_gpu_cache()
52
  return scores
53
 
 
54
  def evaluate_uni(corpus, uni_model_name):
55
- uni_evaluator = UniEvaluator(
56
- model_name=uni_model_name
57
- )
58
  logger.info("Uni evaluator loaded with model %s", uni_model_name)
59
  uni_scores = uni_evaluator.get_average_score(corpus)
60
  for key, value in uni_scores.items():
@@ -64,27 +68,47 @@ def evaluate_uni(corpus, uni_model_name):
64
  logger.info("Uni %s min max scores: %s", key, value)
65
  del uni_evaluator
66
  clean_gpu_cache()
67
- return (uni_scores['naturalness'], uni_scores['coherence'], uni_scores['understandability'],
68
- min_max_scores['naturalness'], min_max_scores['coherence'], min_max_scores['understandability'])
69
 
70
 
71
  def clean_gpu_cache():
72
  import torch
 
73
  if torch.cuda.is_available():
74
  torch.cuda.empty_cache()
75
 
76
 
77
- if __name__ == '__main__':
78
  import torch.multiprocessing as mp
 
79
  parser = argparse.ArgumentParser()
80
 
81
- parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data')
82
- parser.add_argument('--output', type=str, default='cache/output', help='path to save output')
 
 
 
 
83
 
84
- parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
85
- parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
86
- help='Comma-separated list of reward models')
87
- parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')
88
 
89
  args = parser.parse_args()
90
 
@@ -94,49 +118,55 @@ if __name__ == '__main__':
94
  if not os.path.exists(args.output):
95
  os.makedirs(args.output)
96
 
97
- reward_models = args.reward.split(',')
98
-
99
 
100
  results = []
101
 
102
  logger.info("Data loaded from %s", args.folder)
103
- mp.set_start_method('spawn')
104
 
105
  for file in os.listdir(args.folder):
106
- if file.endswith('.json'):
107
  logger.info("Processing %s", file)
108
- with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f:
109
  data = json.load(f)
110
- data = [TextPair(
111
- question=data[key]['question'],
112
- answer=data[key]['answer']
113
- ) for key in data]
114
 
115
  length_scores = evaluate_length(data, args.tokenizer)
116
  mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
117
  reward_scores = evaluate_reward(data, reward_models)
118
- uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \
119
- min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \
120
- = evaluate_uni(data, args.uni)
121
 
122
  result = {
123
- 'file': file,
124
- 'number': len(data),
125
- 'length': length_scores,
126
- 'mtld': mtld_scores,
127
- 'mtld_min_max': min_max_mtld_scores,
128
- 'uni_naturalness': uni_naturalness_scores,
129
- 'uni_coherence': uni_coherence_scores,
130
- 'uni_understandability': uni_understandability_scores,
131
- 'uni_naturalness_min_max': min_max_uni_naturalness_scores,
132
- 'uni_coherence_min_max': min_max_uni_coherence_scores,
133
- 'uni_understandability_min_max': min_max_uni_understandability_scores
134
  }
135
  for reward_score in reward_scores:
136
- result[reward_score['reward_name']] = reward_score['score']
137
- result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores']
 
 
138
 
139
  results.append(result)
140
 
141
  results = pd.DataFrame(results)
142
- results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)
 
1
  """Evaluate the quality of the generated text using various metrics"""
2
 
 
 
3
  import argparse
4
+ import json
5
+ import os
6
+
7
  import pandas as pd
8
  from dotenv import load_dotenv
9
+
10
+ from graphgen.bases.datatypes import QAPair
11
+
12
+ from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
13
  from .utils import logger, set_logger
14
 
15
  sys_path = os.path.abspath(os.path.dirname(__file__))
 
17
 
18
  load_dotenv()
19
 
20
+
21
  def evaluate_length(corpus, tokenizer_name):
22
+ length_evaluator = LengthEvaluator(tokenizer_name=tokenizer_name)
 
 
23
  logger.info("Length evaluator loaded")
24
  scores = length_evaluator.get_average_score(corpus)
25
  logger.info("Length scores: %s", scores)
26
  return scores
27
 
28
+
29
  def evaluate_mtld(corpus):
30
  mtld_evaluator = MTLDEvaluator()
31
  logger.info("MTLD evaluator loaded")
 
35
  logger.info("MTLD min max scores: %s", min_max_scores)
36
  return scores, min_max_scores
37
 
38
+
39
  def evaluate_reward(corpus, reward_model_names):
40
  scores = []
41
  for reward_name in reward_model_names:
42
+ reward_evaluator = RewardEvaluator(reward_name=reward_name)
 
 
43
  logger.info("Loaded reward model: %s", reward_name)
44
  average_score = reward_evaluator.get_average_score(corpus)
45
  logger.info("%s scores: %s", reward_name, average_score)
46
  min_max_scores = reward_evaluator.get_min_max_score(corpus)
47
  logger.info("%s min max scores: %s", reward_name, min_max_scores)
48
+ scores.append(
49
+ {
50
+ "reward_name": reward_name.split("/")[-1],
51
+ "score": average_score,
52
+ "min_max_scores": min_max_scores,
53
+ }
54
+ )
55
  del reward_evaluator
56
  clean_gpu_cache()
57
  return scores
58
 
59
+
60
  def evaluate_uni(corpus, uni_model_name):
61
+ uni_evaluator = UniEvaluator(model_name=uni_model_name)
 
 
62
  logger.info("Uni evaluator loaded with model %s", uni_model_name)
63
  uni_scores = uni_evaluator.get_average_score(corpus)
64
  for key, value in uni_scores.items():
 
68
  logger.info("Uni %s min max scores: %s", key, value)
69
  del uni_evaluator
70
  clean_gpu_cache()
71
+ return (
72
+ uni_scores["naturalness"],
73
+ uni_scores["coherence"],
74
+ uni_scores["understandability"],
75
+ min_max_scores["naturalness"],
76
+ min_max_scores["coherence"],
77
+ min_max_scores["understandability"],
78
+ )
79
 
80
 
81
  def clean_gpu_cache():
82
  import torch
83
+
84
  if torch.cuda.is_available():
85
  torch.cuda.empty_cache()
86
 
87
 
88
+ if __name__ == "__main__":
89
  import torch.multiprocessing as mp
90
+
91
  parser = argparse.ArgumentParser()
92
 
93
+ parser.add_argument(
94
+ "--folder", type=str, default="cache/data", help="folder to load data"
95
+ )
96
+ parser.add_argument(
97
+ "--output", type=str, default="cache/output", help="path to save output"
98
+ )
99
 
100
+ parser.add_argument(
101
+ "--tokenizer", type=str, default="cl100k_base", help="tokenizer name"
102
+ )
103
+ parser.add_argument(
104
+ "--reward",
105
+ type=str,
106
+ default="OpenAssistant/reward-model-deberta-v3-large-v2",
107
+ help="Comma-separated list of reward models",
108
+ )
109
+ parser.add_argument(
110
+ "--uni", type=str, default="MingZhong/unieval-sum", help="uni model name"
111
+ )
112
 
113
  args = parser.parse_args()
114
 
 
118
  if not os.path.exists(args.output):
119
  os.makedirs(args.output)
120
 
121
+ reward_models = args.reward.split(",")
 
122
 
123
  results = []
124
 
125
  logger.info("Data loaded from %s", args.folder)
126
+ mp.set_start_method("spawn")
127
 
128
  for file in os.listdir(args.folder):
129
+ if file.endswith(".json"):
130
  logger.info("Processing %s", file)
131
+ with open(os.path.join(args.folder, file), "r", encoding="utf-8") as f:
132
  data = json.load(f)
133
+ data = [
134
+ QAPair(question=data[key]["question"], answer=data[key]["answer"])
135
+ for key in data
136
+ ]
137
 
138
  length_scores = evaluate_length(data, args.tokenizer)
139
  mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
140
  reward_scores = evaluate_reward(data, reward_models)
141
+ (
142
+ uni_naturalness_scores,
143
+ uni_coherence_scores,
144
+ uni_understandability_scores,
145
+ min_max_uni_naturalness_scores,
146
+ min_max_uni_coherence_scores,
147
+ min_max_uni_understandability_scores,
148
+ ) = evaluate_uni(data, args.uni)
149
 
150
  result = {
151
+ "file": file,
152
+ "number": len(data),
153
+ "length": length_scores,
154
+ "mtld": mtld_scores,
155
+ "mtld_min_max": min_max_mtld_scores,
156
+ "uni_naturalness": uni_naturalness_scores,
157
+ "uni_coherence": uni_coherence_scores,
158
+ "uni_understandability": uni_understandability_scores,
159
+ "uni_naturalness_min_max": min_max_uni_naturalness_scores,
160
+ "uni_coherence_min_max": min_max_uni_coherence_scores,
161
+ "uni_understandability_min_max": min_max_uni_understandability_scores,
162
  }
163
  for reward_score in reward_scores:
164
+ result[reward_score["reward_name"]] = reward_score["score"]
165
+ result[f"{reward_score['reward_name']}_min_max"] = reward_score[
166
+ "min_max_scores"
167
+ ]
168
 
169
  results.append(result)
170
 
171
  results = pd.DataFrame(results)
172
+ results.to_csv(os.path.join(args.output, "evaluation.csv"), index=False)
graphgen/graphgen.py CHANGED
@@ -8,8 +8,8 @@ import gradio as gr
8
  from tqdm.asyncio import tqdm as tqdm_async
9
 
10
  from graphgen.bases.base_storage import StorageNameSpace
 
11
  from graphgen.models import (
12
- Chunk,
13
  JsonKVStorage,
14
  JsonListStorage,
15
  NetworkXStorage,
@@ -17,6 +17,7 @@ from graphgen.models import (
17
  Tokenizer,
18
  TraverseStrategy,
19
  read_file,
 
20
  )
21
 
22
  from .operators import (
@@ -32,6 +33,7 @@ from .operators import (
32
  from .utils import (
33
  compute_content_hash,
34
  create_event_loop,
 
35
  format_generation_results,
36
  logger,
37
  )
@@ -50,11 +52,6 @@ class GraphGen:
50
  synthesizer_llm_client: OpenAIModel = None
51
  trainee_llm_client: OpenAIModel = None
52
 
53
- # text chunking
54
- # TODO: make it configurable
55
- chunk_size: int = 1024
56
- chunk_overlap_size: int = 100
57
-
58
  # search
59
  search_config: dict = field(
60
  default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
@@ -136,14 +133,22 @@ class GraphGen:
136
  async for doc_key, doc in tqdm_async(
137
  new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
138
  ):
139
  chunks = {
140
- compute_content_hash(dp["content"], prefix="chunk-"): {
141
- **dp,
142
  "full_doc_id": doc_key,
 
 
143
  }
144
- for dp in self.tokenizer_instance.chunk_by_token_size(
145
- doc["content"], self.chunk_overlap_size, self.chunk_size
146
- )
147
  }
148
  inserting_chunks.update(chunks)
149
 
@@ -171,7 +176,7 @@ class GraphGen:
171
  insert chunks into the graph
172
  """
173
 
174
- input_file = self.config["input_file"]
175
  data = read_file(input_file)
176
  inserting_chunks = await self.async_split_chunks(data)
177
 
 
8
  from tqdm.asyncio import tqdm as tqdm_async
9
 
10
  from graphgen.bases.base_storage import StorageNameSpace
11
+ from graphgen.bases.datatypes import Chunk
12
  from graphgen.models import (
 
13
  JsonKVStorage,
14
  JsonListStorage,
15
  NetworkXStorage,
 
17
  Tokenizer,
18
  TraverseStrategy,
19
  read_file,
20
+ split_chunks,
21
  )
22
 
23
  from .operators import (
 
33
  from .utils import (
34
  compute_content_hash,
35
  create_event_loop,
36
+ detect_main_language,
37
  format_generation_results,
38
  logger,
39
  )
 
52
  synthesizer_llm_client: OpenAIModel = None
53
  trainee_llm_client: OpenAIModel = None
54
 
55
  # search
56
  search_config: dict = field(
57
  default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
 
133
  async for doc_key, doc in tqdm_async(
134
  new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
135
  ):
136
+ doc_language = detect_main_language(doc["content"])
137
+ text_chunks = split_chunks(
138
+ doc["content"],
139
+ language=doc_language,
140
+ chunk_size=self.config["split"]["chunk_size"],
141
+ chunk_overlap=self.config["split"]["chunk_overlap"],
142
+ )
143
+
144
  chunks = {
145
+ compute_content_hash(txt, prefix="chunk-"): {
146
+ "content": txt,
147
  "full_doc_id": doc_key,
148
+ "length": len(self.tokenizer_instance.encode_string(txt)),
149
+ "language": doc_language,
150
  }
151
+ for txt in text_chunks
 
 
152
  }
153
  inserting_chunks.update(chunks)
154
 
 
176
  insert chunks into the graph
177
  """
178
 
179
+ input_file = self.config["read"]["input_file"]
180
  data = read_file(input_file)
181
  inserting_chunks = await self.async_split_chunks(data)
182
 
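With this change GraphGen reads the input path and chunking settings from nested config sections instead of top-level keys and dataclass fields. A sketch of the minimal shape now expected (values copied from the updated YAML configs; remaining keys omitted):

```python
config = {
    "read": {"input_file": "resources/input_examples/jsonl_demo.jsonl"},
    "split": {"chunk_size": 1024, "chunk_overlap": 100},
    "output_data_type": "aggregated",
    "output_data_format": "ChatML",
    "tokenizer": "cl100k_base",
    # ... other generation/traversal keys are unchanged by this commit
}
```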
graphgen/models/__init__.py CHANGED
@@ -11,36 +11,7 @@ from .search.db.uniprot_search import UniProtSearch
11
  from .search.kg.wiki_search import WikiSearch
12
  from .search.web.bing_search import BingSearch
13
  from .search.web.google_search import GoogleSearch
 
14
  from .storage.json_storage import JsonKVStorage, JsonListStorage
15
  from .storage.networkx_storage import NetworkXStorage
16
  from .strategy.travserse_strategy import TraverseStrategy
17
- from .text.chunk import Chunk
18
- from .text.text_pair import TextPair
19
-
20
- __all__ = [
21
- # llm models
22
- "OpenAIModel",
23
- "TopkTokenModel",
24
- "Token",
25
- "Tokenizer",
26
- # storage models
27
- "Chunk",
28
- "NetworkXStorage",
29
- "JsonKVStorage",
30
- "JsonListStorage",
31
- # search models
32
- "WikiSearch",
33
- "GoogleSearch",
34
- "BingSearch",
35
- "UniProtSearch",
36
- # evaluate models
37
- "TextPair",
38
- "LengthEvaluator",
39
- "MTLDEvaluator",
40
- "RewardEvaluator",
41
- "UniEvaluator",
42
- # strategy models
43
- "TraverseStrategy",
44
- # community models
45
- "CommunityDetector",
46
- ]
 
11
  from .search.kg.wiki_search import WikiSearch
12
  from .search.web.bing_search import BingSearch
13
  from .search.web.google_search import GoogleSearch
14
+ from .splitter import split_chunks
15
  from .storage.json_storage import JsonKVStorage, JsonListStorage
16
  from .storage.networkx_storage import NetworkXStorage
17
  from .strategy.travserse_strategy import TraverseStrategy
graphgen/models/evaluate/base_evaluator.py CHANGED
@@ -1,22 +1,24 @@
1
  import asyncio
2
-
3
  from dataclasses import dataclass
 
4
  from tqdm.asyncio import tqdm as tqdm_async
 
 
5
  from graphgen.utils import create_event_loop
6
- from graphgen.models.text.text_pair import TextPair
7
 
8
  @dataclass
9
  class BaseEvaluator:
10
  max_concurrent: int = 100
11
  results: list[float] = None
12
 
13
- def evaluate(self, pairs: list[TextPair]) -> list[float]:
14
  """
15
  Evaluate the text and return a score.
16
  """
17
  return create_event_loop().run_until_complete(self.async_evaluate(pairs))
18
 
19
- async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
20
  semaphore = asyncio.Semaphore(self.max_concurrent)
21
 
22
  async def evaluate_with_semaphore(pair):
@@ -31,10 +33,10 @@ class BaseEvaluator:
31
  results.append(await result)
32
  return results
33
 
34
- async def evaluate_single(self, pair: TextPair) -> float:
35
  raise NotImplementedError()
36
 
37
- def get_average_score(self, pairs: list[TextPair]) -> float:
38
  """
39
  Get the average score of a batch of texts.
40
  """
@@ -42,7 +44,7 @@ class BaseEvaluator:
42
  self.results = results
43
  return sum(self.results) / len(pairs)
44
 
45
- def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
46
  """
47
  Get the min and max score of a batch of texts.
48
  """
 
1
  import asyncio
 
2
  from dataclasses import dataclass
3
+
4
  from tqdm.asyncio import tqdm as tqdm_async
5
+
6
+ from graphgen.bases.datatypes import QAPair
7
  from graphgen.utils import create_event_loop
8
+
9
 
10
  @dataclass
11
  class BaseEvaluator:
12
  max_concurrent: int = 100
13
  results: list[float] = None
14
 
15
+ def evaluate(self, pairs: list[QAPair]) -> list[float]:
16
  """
17
  Evaluate the text and return a score.
18
  """
19
  return create_event_loop().run_until_complete(self.async_evaluate(pairs))
20
 
21
+ async def async_evaluate(self, pairs: list[QAPair]) -> list[float]:
22
  semaphore = asyncio.Semaphore(self.max_concurrent)
23
 
24
  async def evaluate_with_semaphore(pair):
 
33
  results.append(await result)
34
  return results
35
 
36
+ async def evaluate_single(self, pair: QAPair) -> float:
37
  raise NotImplementedError()
38
 
39
+ def get_average_score(self, pairs: list[QAPair]) -> float:
40
  """
41
  Get the average score of a batch of texts.
42
  """
 
44
  self.results = results
45
  return sum(self.results) / len(pairs)
46
 
47
+ def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
48
  """
49
  Get the min and max score of a batch of texts.
50
  """
graphgen/models/evaluate/length_evaluator.py CHANGED
@@ -1,19 +1,19 @@
1
  from dataclasses import dataclass
 
 
2
  from graphgen.models.evaluate.base_evaluator import BaseEvaluator
3
  from graphgen.models.llm.tokenizer import Tokenizer
4
- from graphgen.models.text.text_pair import TextPair
5
  from graphgen.utils import create_event_loop
6
 
7
 
8
  @dataclass
9
  class LengthEvaluator(BaseEvaluator):
10
  tokenizer_name: str = "cl100k_base"
 
11
  def __post_init__(self):
12
- self.tokenizer = Tokenizer(
13
- model_name=self.tokenizer_name
14
- )
15
 
16
- async def evaluate_single(self, pair: TextPair) -> float:
17
  loop = create_event_loop()
18
  return await loop.run_in_executor(None, self._calculate_length, pair.answer)
19
 
 
1
  from dataclasses import dataclass
2
+
3
+ from graphgen.bases.datatypes import QAPair
4
  from graphgen.models.evaluate.base_evaluator import BaseEvaluator
5
  from graphgen.models.llm.tokenizer import Tokenizer
 
6
  from graphgen.utils import create_event_loop
7
 
8
 
9
  @dataclass
10
  class LengthEvaluator(BaseEvaluator):
11
  tokenizer_name: str = "cl100k_base"
12
+
13
  def __post_init__(self):
14
+ self.tokenizer = Tokenizer(model_name=self.tokenizer_name)
 
 
15
 
16
+ async def evaluate_single(self, pair: QAPair) -> float:
17
  loop = create_event_loop()
18
  return await loop.run_in_executor(None, self._calculate_length, pair.answer)
19
 
graphgen/models/evaluate/mtld_evaluator.py CHANGED
@@ -1,22 +1,27 @@
1
- from dataclasses import dataclass, field
2
  from typing import Set
3
 
 
4
  from graphgen.models.evaluate.base_evaluator import BaseEvaluator
5
- from graphgen.models.text.text_pair import TextPair
6
- from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop
7
-
8
 
9
  nltk_helper = NLTKHelper()
10
 
 
11
  @dataclass
12
  class MTLDEvaluator(BaseEvaluator):
13
  """
14
  衡量文本词汇多样性的指标
15
  """
16
- stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english")))
17
- stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese")))
18
 
19
- async def evaluate_single(self, pair: TextPair) -> float:
20
  loop = create_event_loop()
21
  return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
22
 
@@ -71,6 +76,6 @@ class MTLDEvaluator(BaseEvaluator):
71
  if ttr <= threshold:
72
  factors += 1
73
  else:
74
- factors += (1 - (ttr - threshold) / (1 - threshold))
75
 
76
  return len(tokens) / factors if factors > 0 else len(tokens)
 
1
+ from dataclasses import dataclass, field
2
  from typing import Set
3
 
4
+ from graphgen.bases.datatypes import QAPair
5
  from graphgen.models.evaluate.base_evaluator import BaseEvaluator
6
+ from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language
 
 
7
 
8
  nltk_helper = NLTKHelper()
9
 
10
+
11
  @dataclass
12
  class MTLDEvaluator(BaseEvaluator):
13
  """
14
  衡量文本词汇多样性的指标
15
  """
 
 
16
 
17
+ stopwords_en: Set[str] = field(
18
+ default_factory=lambda: set(nltk_helper.get_stopwords("english"))
19
+ )
20
+ stopwords_zh: Set[str] = field(
21
+ default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))
22
+ )
23
+
24
+ async def evaluate_single(self, pair: QAPair) -> float:
25
  loop = create_event_loop()
26
  return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
27
 
 
76
  if ttr <= threshold:
77
  factors += 1
78
  else:
79
+ factors += 1 - (ttr - threshold) / (1 - threshold)
80
 
81
  return len(tokens) / factors if factors > 0 else len(tokens)
graphgen/models/evaluate/reward_evaluator.py CHANGED
@@ -1,6 +1,8 @@
1
  from dataclasses import dataclass
 
2
  from tqdm import tqdm
3
- from graphgen.models.text.text_pair import TextPair
 
4
 
5
 
6
  @dataclass
@@ -9,19 +11,22 @@ class RewardEvaluator:
9
  Reward Model Evaluator.
10
  OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好
11
  """
 
12
  reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
13
  max_length: int = 2560
14
  results: list[float] = None
15
 
16
  def __post_init__(self):
17
  import torch
 
18
  self.num_gpus = torch.cuda.device_count()
19
 
20
  @staticmethod
21
  def process_chunk(rank, pairs, reward_name, max_length, return_dict):
22
  import torch
23
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
24
- device = f'cuda:{rank}'
 
25
  torch.cuda.set_device(rank)
26
 
27
  rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
@@ -37,7 +42,7 @@ class RewardEvaluator:
37
  pair.answer,
38
  return_tensors="pt",
39
  max_length=max_length,
40
- truncation=True
41
  )
42
  inputs = {k: v.to(device) for k, v in inputs.items()}
43
  score = rank_model(**inputs).logits[0].item()
@@ -45,8 +50,9 @@ class RewardEvaluator:
45
 
46
  return_dict[rank] = results
47
 
48
- def evaluate(self, pairs: list[TextPair]) -> list[float]:
49
  import torch.multiprocessing as mp
 
50
  chunk_size = len(pairs) // self.num_gpus
51
  chunks = []
52
  for i in range(self.num_gpus):
@@ -64,7 +70,7 @@ class RewardEvaluator:
64
  for rank, chunk in enumerate(chunks):
65
  p = mp.Process(
66
  target=self.process_chunk,
67
- args=(rank, chunk, self.reward_name, self.max_length, return_dict)
68
  )
69
  p.start()
70
  processes.append(p)
@@ -84,7 +90,7 @@ class RewardEvaluator:
84
 
85
  return results
86
 
87
- def get_average_score(self, pairs: list[TextPair]) -> float:
88
  """
89
  Get the average score of a batch of texts.
90
  """
@@ -92,7 +98,7 @@ class RewardEvaluator:
92
  self.results = results
93
  return sum(self.results) / len(pairs)
94
 
95
- def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
96
  """
97
  Get the min and max score of a batch of texts.
98
  """
 
1
  from dataclasses import dataclass
2
+
3
  from tqdm import tqdm
4
+
5
+ from graphgen.bases.datatypes import QAPair
6
 
7
 
8
  @dataclass
 
11
  Reward Model Evaluator.
12
  OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好
13
  """
14
+
15
  reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
16
  max_length: int = 2560
17
  results: list[float] = None
18
 
19
  def __post_init__(self):
20
  import torch
21
+
22
  self.num_gpus = torch.cuda.device_count()
23
 
24
  @staticmethod
25
  def process_chunk(rank, pairs, reward_name, max_length, return_dict):
26
  import torch
27
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
28
+
29
+ device = f"cuda:{rank}"
30
  torch.cuda.set_device(rank)
31
 
32
  rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
 
42
  pair.answer,
43
  return_tensors="pt",
44
  max_length=max_length,
45
+ truncation=True,
46
  )
47
  inputs = {k: v.to(device) for k, v in inputs.items()}
48
  score = rank_model(**inputs).logits[0].item()
 
50
 
51
  return_dict[rank] = results
52
 
53
+ def evaluate(self, pairs: list[QAPair]) -> list[float]:
54
  import torch.multiprocessing as mp
55
+
56
  chunk_size = len(pairs) // self.num_gpus
57
  chunks = []
58
  for i in range(self.num_gpus):
 
70
  for rank, chunk in enumerate(chunks):
71
  p = mp.Process(
72
  target=self.process_chunk,
73
+ args=(rank, chunk, self.reward_name, self.max_length, return_dict),
74
  )
75
  p.start()
76
  processes.append(p)
 
90
 
91
  return results
92
 
93
+ def get_average_score(self, pairs: list[QAPair]) -> float:
94
  """
95
  Get the average score of a batch of texts.
96
  """
 
98
  self.results = results
99
  return sum(self.results) / len(pairs)
100
 
101
+ def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]:
102
  """
103
  Get the min and max score of a batch of texts.
104
  """
graphgen/models/evaluate/uni_evaluator.py CHANGED
@@ -1,40 +1,58 @@
1
  # https://github.com/maszhongming/UniEval/tree/main
2
 
3
  from dataclasses import dataclass, field
 
4
  from tqdm import tqdm
5
- from graphgen.models.text.text_pair import TextPair
 
6
 
7
 
8
  def _add_questions(dimension: str, question: str, answer: str):
9
  if dimension == "naturalness":
10
- cur_input = 'question: Is this a natural response in the dialogue? </s> response: ' + answer
 
 
 
11
  elif dimension == "coherence":
12
- cur_input = 'question: Is this a coherent response given the dialogue history? </s> response: ' \
13
- + answer + ' </s> dialogue history: ' + question
 
 
 
 
14
  elif dimension == "understandability":
15
- cur_input = 'question: Is this an understandable response in the dialogue? </s> response: ' + answer
 
 
 
16
  else:
17
  raise NotImplementedError(
18
- 'The input format for this dimension is still undefined. Please customize it first.')
 
19
  return cur_input
20
 
 
21
  @dataclass
22
  class UniEvaluator:
23
  model_name: str = "MingZhong/unieval-sum"
24
- dimensions: list = field(default_factory=lambda: ['naturalness', 'coherence', 'understandability'])
 
 
25
  max_length: int = 2560
26
  results: dict = None
27
 
28
  def __post_init__(self):
29
  import torch
 
30
  self.num_gpus = torch.cuda.device_count()
31
  self.results = {}
32
 
33
  @staticmethod
34
  def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
35
  import torch
36
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
37
- device = f'cuda:{rank}'
 
38
  torch.cuda.set_device(rank)
39
 
40
  rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
@@ -59,26 +77,26 @@ class UniEvaluator:
59
  max_length=max_length,
60
  truncation=True,
61
  padding=True,
62
- return_tensors='pt'
63
  )
64
  encoded_tgt = tokenizer(
65
  tgt,
66
  max_length=max_length,
67
  truncation=True,
68
  padding=True,
69
- return_tensors='pt'
70
  )
71
 
72
- src_tokens = encoded_src['input_ids'].to(device)
73
- src_mask = encoded_src['attention_mask'].to(device)
74
 
75
- tgt_tokens = encoded_tgt['input_ids'].to(device)[:, 0].unsqueeze(-1)
76
 
77
  output = rank_model(
78
  input_ids=src_tokens,
79
  attention_mask=src_mask,
80
  labels=tgt_tokens,
81
- use_cache = False
82
  )
83
 
84
  logits = output.logits.view(-1, rank_model.config.vocab_size)
@@ -91,8 +109,9 @@ class UniEvaluator:
91
 
92
  return_dict[rank] = results
93
 
94
- def evaluate(self, pairs: list[TextPair]) -> list[dict]:
95
  import torch.multiprocessing as mp
 
96
  final_results = []
97
  for dimension in self.dimensions:
98
  chunk_size = len(pairs) // self.num_gpus
@@ -112,7 +131,14 @@ class UniEvaluator:
112
  for rank, chunk in enumerate(chunks):
113
  p = mp.Process(
114
  target=self.process_chunk,
115
- args=(rank, chunk, self.model_name, self.max_length, dimension, return_dict)
116
  )
117
  p.start()
118
  processes.append(p)
@@ -130,12 +156,10 @@ class UniEvaluator:
130
  p.terminate()
131
  p.join()
132
 
133
- final_results.append({
134
- dimension: results
135
- })
136
  return final_results
137
 
138
- def get_average_score(self, pairs: list[TextPair]) -> dict:
139
  """
140
  Get the average score of a batch of texts.
141
  """
@@ -147,7 +171,7 @@ class UniEvaluator:
147
  self.results[key] = value
148
  return final_results
149
 
150
- def get_min_max_score(self, pairs: list[TextPair]) -> dict:
151
  """
152
  Get the min and max score of a batch of texts.
153
  """
 
1
  # https://github.com/maszhongming/UniEval/tree/main
2
 
3
  from dataclasses import dataclass, field
4
+
5
  from tqdm import tqdm
6
+
7
+ from graphgen.bases.datatypes import QAPair
8
 
9
 
10
  def _add_questions(dimension: str, question: str, answer: str):
11
  if dimension == "naturalness":
12
+ cur_input = (
13
+ "question: Is this a natural response in the dialogue? </s> response: "
14
+ + answer
15
+ )
16
  elif dimension == "coherence":
17
+ cur_input = (
18
+ "question: Is this a coherent response given the dialogue history? </s> response: "
19
+ + answer
20
+ + " </s> dialogue history: "
21
+ + question
22
+ )
23
  elif dimension == "understandability":
24
+ cur_input = (
25
+ "question: Is this an understandable response in the dialogue? </s> response: "
26
+ + answer
27
+ )
28
  else:
29
  raise NotImplementedError(
30
+ "The input format for this dimension is still undefined. Please customize it first."
31
+ )
32
  return cur_input
33
 
34
+
35
  @dataclass
36
  class UniEvaluator:
37
  model_name: str = "MingZhong/unieval-sum"
38
+ dimensions: list = field(
39
+ default_factory=lambda: ["naturalness", "coherence", "understandability"]
40
+ )
41
  max_length: int = 2560
42
  results: dict = None
43
 
44
  def __post_init__(self):
45
  import torch
46
+
47
  self.num_gpus = torch.cuda.device_count()
48
  self.results = {}
49
 
50
  @staticmethod
51
  def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict):
52
  import torch
53
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
54
+
55
+ device = f"cuda:{rank}"
56
  torch.cuda.set_device(rank)
57
 
58
  rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
77
  max_length=max_length,
78
  truncation=True,
79
  padding=True,
80
+ return_tensors="pt",
81
  )
82
  encoded_tgt = tokenizer(
83
  tgt,
84
  max_length=max_length,
85
  truncation=True,
86
  padding=True,
87
+ return_tensors="pt",
88
  )
89
 
90
+ src_tokens = encoded_src["input_ids"].to(device)
91
+ src_mask = encoded_src["attention_mask"].to(device)
92
 
93
+ tgt_tokens = encoded_tgt["input_ids"].to(device)[:, 0].unsqueeze(-1)
94
 
95
  output = rank_model(
96
  input_ids=src_tokens,
97
  attention_mask=src_mask,
98
  labels=tgt_tokens,
99
+ use_cache=False,
100
  )
101
 
102
  logits = output.logits.view(-1, rank_model.config.vocab_size)
 
109
 
110
  return_dict[rank] = results
111
 
112
+ def evaluate(self, pairs: list[QAPair]) -> list[dict]:
113
  import torch.multiprocessing as mp
114
+
115
  final_results = []
116
  for dimension in self.dimensions:
117
  chunk_size = len(pairs) // self.num_gpus
 
131
  for rank, chunk in enumerate(chunks):
132
  p = mp.Process(
133
  target=self.process_chunk,
134
+ args=(
135
+ rank,
136
+ chunk,
137
+ self.model_name,
138
+ self.max_length,
139
+ dimension,
140
+ return_dict,
141
+ ),
142
  )
143
  p.start()
144
  processes.append(p)
 
156
  p.terminate()
157
  p.join()
158
 
159
+ final_results.append({dimension: results})
 
 
160
  return final_results
161
 
162
+ def get_average_score(self, pairs: list[QAPair]) -> dict:
163
  """
164
  Get the average score of a batch of texts.
165
  """
 
171
  self.results[key] = value
172
  return final_results
173
 
174
+ def get_min_max_score(self, pairs: list[QAPair]) -> dict:
175
  """
176
  Get the min and max score of a batch of texts.
177
  """
graphgen/models/splitter/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ from functools import lru_cache
2
+ from typing import Union
3
+
4
+ from .recursive_character_splitter import (
5
+ ChineseRecursiveTextSplitter,
6
+ RecursiveCharacterSplitter,
7
+ )
8
+
9
+ _MAPPING = {
10
+ "en": RecursiveCharacterSplitter,
11
+ "zh": ChineseRecursiveTextSplitter,
12
+ }
13
+
14
+ SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
15
+
16
+
17
+ @lru_cache(maxsize=None)
18
+ def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
19
+ cls = _MAPPING[language]
20
+ kwargs = dict(frozen_kwargs)
21
+ return cls(**kwargs)
22
+
23
+
24
+ def split_chunks(text: str, language: str = "en", **kwargs) -> list:
25
+ if language not in _MAPPING:
26
+ raise ValueError(
27
+ f"Unsupported language: {language}. "
28
+ f"Supported languages are: {list(_MAPPING.keys())}"
29
+ )
30
+ splitter = _get_splitter(language, frozenset(kwargs.items()))
31
+ return splitter.split_text(text)
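A usage sketch of the new entry point (not part of the commit; the sample strings are illustrative). Keyword arguments are forwarded to the splitter constructor and must be hashable, since they are frozen into the lru_cache key:

```python
from graphgen.models import split_chunks

en_chunks = split_chunks(
    "A long English document ...", language="en", chunk_size=1024, chunk_overlap=100
)
zh_chunks = split_chunks(
    "一段较长的中文文档……", language="zh", chunk_size=1024, chunk_overlap=100
)
```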
graphgen/models/splitter/character_splitter.py ADDED
@@ -0,0 +1,26 @@
1
+ import re
2
+ from typing import Any, List
3
+
4
+ from graphgen.bases.base_splitter import BaseSplitter
5
+
6
+
7
+ class CharacterSplitter(BaseSplitter):
8
+ """Splitting text that looks at characters."""
9
+
10
+ def __init__(
11
+ self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
12
+ ) -> None:
13
+ """Create a new TextSplitter."""
14
+ super().__init__(**kwargs)
15
+ self._separator = separator
16
+ self._is_separator_regex = is_separator_regex
17
+
18
+ def split_text(self, text: str) -> List[str]:
19
+ """Split incoming text and return chunks."""
20
+ # First we naively split the large input into a bunch of smaller ones.
21
+ separator = (
22
+ self._separator if self._is_separator_regex else re.escape(self._separator)
23
+ )
24
+ splits = self._split_text_with_regex(text, separator, self.keep_separator)
25
+ _separator = "" if self.keep_separator else self._separator
26
+ return self._merge_splits(splits, _separator)
graphgen/models/splitter/markdown_splitter.py ADDED
@@ -0,0 +1,33 @@
1
+ from typing import Any
2
+
3
+ from graphgen.models.splitter.recursive_character_splitter import (
4
+ RecursiveCharacterSplitter,
5
+ )
6
+
7
+
8
+ class MarkdownTextRefSplitter(RecursiveCharacterSplitter):
9
+ """Attempts to split the text along Markdown-formatted headings."""
10
+
11
+ def __init__(self, **kwargs: Any) -> None:
12
+ """Initialize a MarkdownTextRefSplitter."""
13
+ separators = [
14
+ # First, try to split along Markdown headings (starting with level 2)
15
+ "\n#{1,6} ",
16
+ # Note the alternative syntax for headings (below) is not handled here
17
+ # Heading level 2
18
+ # ---------------
19
+ # End of code block
20
+ "```\n",
21
+ # Horizontal lines
22
+ "\n\\*\\*\\*+\n",
23
+ "\n---+\n",
24
+ "\n___+\n",
25
+ # Note: horizontal lines defined by three or more of ***, ---, or ___
26
+ # are handled by the regexes above, but alternative syntaxes (e.g., with spaces)
27
+ # are not handled.
28
+ "\n\n",
29
+ "\n",
30
+ " ",
31
+ "",
32
+ ]
33
+ super().__init__(separators=separators, **kwargs)
graphgen/models/splitter/recursive_character_splitter.py ADDED
@@ -0,0 +1,149 @@
1
+ import re
2
+ from typing import Any, List, Optional
3
+
4
+ from graphgen.bases.base_splitter import BaseSplitter
5
+
6
+
7
+ class RecursiveCharacterSplitter(BaseSplitter):
8
+ """Splitting text by recursively look at characters.
9
+
10
+ Recursively tries to split by different characters to find one that works.
11
+ """
12
+
13
+ def __init__(
14
+ self,
15
+ separators: Optional[List[str]] = None,
16
+ keep_separator: bool = True,
17
+ is_separator_regex: bool = False,
18
+ **kwargs: Any,
19
+ ) -> None:
20
+ """Create a new TextSplitter."""
21
+ super().__init__(keep_separator=keep_separator, **kwargs)
22
+ self._separators = separators or ["\n\n", "\n", " ", ""]
23
+ self._is_separator_regex = is_separator_regex
24
+
25
+ def _split_text(self, text: str, separators: List[str]) -> List[str]:
26
+ """Split incoming text and return chunks."""
27
+ final_chunks = []
28
+ # Get appropriate separator to use
29
+ separator = separators[-1]
30
+ new_separators = []
31
+ for i, _s in enumerate(separators):
32
+ _separator = _s if self._is_separator_regex else re.escape(_s)
33
+ if _s == "":
34
+ separator = _s
35
+ break
36
+ if re.search(_separator, text):
37
+ separator = _s
38
+ new_separators = separators[i + 1 :]
39
+ break
40
+
41
+ _separator = separator if self._is_separator_regex else re.escape(separator)
42
+ splits = self._split_text_with_regex(text, _separator, self.keep_separator)
43
+
44
+ # Now go merging things, recursively splitting longer texts.
45
+ _good_splits = []
46
+ _separator = "" if self.keep_separator else separator
47
+ for s in splits:
48
+ if self.length_function(s) < self.chunk_size:
49
+ _good_splits.append(s)
50
+ else:
51
+ if _good_splits:
52
+ merged_text = self._merge_splits(_good_splits, _separator)
53
+ final_chunks.extend(merged_text)
54
+ _good_splits = []
55
+ if not new_separators:
56
+ final_chunks.append(s)
57
+ else:
58
+ other_info = self._split_text(s, new_separators)
59
+ final_chunks.extend(other_info)
60
+ if _good_splits:
61
+ merged_text = self._merge_splits(_good_splits, _separator)
62
+ final_chunks.extend(merged_text)
63
+ return final_chunks
64
+
65
+ def split_text(self, text: str) -> List[str]:
66
+ return self._split_text(text, self._separators)
67
+
68
+
69
+ class ChineseRecursiveTextSplitter(RecursiveCharacterSplitter):
70
+ def __init__(
71
+ self,
72
+ separators: Optional[List[str]] = None,
73
+ keep_separator: bool = True,
74
+ is_separator_regex: bool = True,
75
+ **kwargs: Any,
76
+ ) -> None:
77
+ super().__init__(keep_separator=keep_separator, **kwargs)
78
+ self._separators = separators or [
79
+ "\n\n",
80
+ "\n",
81
+ "。|!|?",
82
+ r"\.\s|\!\s|\?\s",
83
+ r";|;\s",
84
+ r",|,\s",
85
+ ]
86
+ self._is_separator_regex = is_separator_regex
87
+
88
+ def _split_text_with_regex_from_end(
89
+ self, text: str, separator: str, keep_separator: bool
90
+ ) -> List[str]:
91
+ # Now that we have the separator, split the text
92
+ if separator:
93
+ if keep_separator:
94
+ # The parentheses in the pattern keep the delimiters in the result.
95
+ _splits = re.split(f"({separator})", text)
96
+ splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
97
+ if len(_splits) % 2 == 1:
98
+ splits += _splits[-1:]
99
+ else:
100
+ splits = re.split(separator, text)
101
+ else:
102
+ splits = list(text)
103
+ return [s for s in splits if s != ""]
104
+
105
+ def _split_text(self, text: str, separators: List[str]) -> List[str]:
106
+ """Split incoming text and return chunks."""
107
+ final_chunks = []
108
+ # Get appropriate separator to use
109
+ separator = separators[-1]
110
+ new_separators = []
111
+ for i, _s in enumerate(separators):
112
+ _separator = _s if self._is_separator_regex else re.escape(_s)
113
+ if _s == "":
114
+ separator = _s
115
+ break
116
+ if re.search(_separator, text):
117
+ separator = _s
118
+ new_separators = separators[i + 1 :]
119
+ break
120
+
121
+ _separator = separator if self._is_separator_regex else re.escape(separator)
122
+ splits = self._split_text_with_regex_from_end(
123
+ text, _separator, self.keep_separator
124
+ )
125
+
126
+ # Now go merging things, recursively splitting longer texts.
127
+ _good_splits = []
128
+ _separator = "" if self.keep_separator else separator
129
+ for s in splits:
130
+ if self.length_function(s) < self.chunk_size:
131
+ _good_splits.append(s)
132
+ else:
133
+ if _good_splits:
134
+ merged_text = self._merge_splits(_good_splits, _separator)
135
+ final_chunks.extend(merged_text)
136
+ _good_splits = []
137
+ if not new_separators:
138
+ final_chunks.append(s)
139
+ else:
140
+ other_info = self._split_text(s, new_separators)
141
+ final_chunks.extend(other_info)
142
+ if _good_splits:
143
+ merged_text = self._merge_splits(_good_splits, _separator)
144
+ final_chunks.extend(merged_text)
145
+ return [
146
+ re.sub(r"\n{2,}", "\n", chunk.strip())
147
+ for chunk in final_chunks
148
+ if chunk.strip() != ""
149
+ ]
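The two recursive splitters can also be instantiated directly; a sketch with illustrative inputs and deliberately small chunk sizes:

```python
from graphgen.models.splitter.recursive_character_splitter import (
    ChineseRecursiveTextSplitter,
    RecursiveCharacterSplitter,
)

# English: falls back through "\n\n", "\n", " ", "" until pieces fit chunk_size.
en = RecursiveCharacterSplitter(chunk_size=50, chunk_overlap=0)
print(en.split_text("Sentence one. Sentence two.\n\nA new paragraph follows here."))

# Chinese: uses sentence-ending punctuation regexes and splits from the end.
zh = ChineseRecursiveTextSplitter(chunk_size=20, chunk_overlap=0)
print(zh.split_text("第一句话。第二句话!第三句话?"))
```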
graphgen/models/text/chunk.py DELETED
@@ -1,7 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
-
4
- @dataclass
5
- class Chunk:
6
- id : str
7
- content: str
graphgen/models/text/text_pair.py DELETED
@@ -1,9 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- @dataclass
4
- class TextPair:
5
- """
6
- A pair of input data.
7
- """
8
- question: str
9
- answer: str
graphgen/operators/kg/extract_kg.py CHANGED
@@ -7,7 +7,8 @@ import gradio as gr
7
  from tqdm.asyncio import tqdm as tqdm_async
8
 
9
  from graphgen.bases.base_storage import BaseGraphStorage
10
- from graphgen.models import Chunk, OpenAIModel, Tokenizer
 
11
  from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
12
  from graphgen.templates import KG_EXTRACTION_PROMPT
13
  from graphgen.utils import (
 
7
  from tqdm.asyncio import tqdm as tqdm_async
8
 
9
  from graphgen.bases.base_storage import BaseGraphStorage
10
+ from graphgen.bases.datatypes import Chunk
11
+ from graphgen.models import OpenAIModel, Tokenizer
12
  from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
13
  from graphgen.templates import KG_EXTRACTION_PROMPT
14
  from graphgen.utils import (
graphgen/operators/preprocess/resolute_coreference.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import List
2
 
3
- from graphgen.models import Chunk, OpenAIModel
 
4
  from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
5
  from graphgen.utils import detect_main_language
6
 
 
1
  from typing import List
2
 
3
+ from graphgen.bases.datatypes import Chunk
4
+ from graphgen.models import OpenAIModel
5
  from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
6
  from graphgen.utils import detect_main_language
7
 
webui/app.py CHANGED
@@ -12,7 +12,7 @@ from graphgen.graphgen import GraphGen
12
  from graphgen.models import OpenAIModel, Tokenizer
13
  from graphgen.models.llm.limitter import RPM, TPM
14
  from graphgen.utils import set_logger
15
- from webui.base import GraphGenParams
16
  from webui.cache_utils import cleanup_workspace, setup_workspace
17
  from webui.count_tokens import count_tokens
18
  from webui.i18n import Translate
@@ -66,13 +66,19 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
66
 
67
 
68
  # pylint: disable=too-many-statements
69
- def run_graphgen(params, progress=gr.Progress()):
70
  def sum_tokens(client):
71
  return sum(u["total_tokens"] for u in client.token_usage)
72
 
73
  config = {
74
  "if_trainee_model": params.if_trainee_model,
75
- "input_file": params.input_file,
 
  "output_data_type": params.output_data_type,
77
  "output_data_format": params.output_data_format,
78
  "tokenizer": params.tokenizer,
@@ -91,7 +97,6 @@ def run_graphgen(params, progress=gr.Progress()):
91
  "isolated_node_strategy": params.isolated_node_strategy,
92
  "loss_strategy": params.loss_strategy,
93
  },
94
- "chunk_size": params.chunk_size,
95
  }
96
 
97
  env = {
@@ -284,10 +289,18 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
284
  label="Chunk Size",
285
  minimum=256,
286
  maximum=4096,
287
- value=512,
288
  step=256,
289
  interactive=True,
290
  )
 
  tokenizer = gr.Textbox(
292
  label="Tokenizer", value="cl100k_base", interactive=True
293
  )
@@ -499,7 +512,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
499
 
500
  submit_btn.click(
501
  lambda *args: run_graphgen(
502
- GraphGenParams(
503
  if_trainee_model=args[0],
504
  input_file=args[1],
505
  tokenizer=args[2],
@@ -518,12 +531,13 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
518
  trainee_model=args[15],
519
  api_key=args[16],
520
  chunk_size=args[17],
521
- rpm=args[18],
522
- tpm=args[19],
523
- quiz_samples=args[20],
524
- trainee_url=args[21],
525
- trainee_api_key=args[22],
526
- token_counter=args[23],
 
527
  )
528
  ),
529
  inputs=[
@@ -545,6 +559,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
545
  trainee_model,
546
  api_key,
547
  chunk_size,
 
548
  rpm,
549
  tpm,
550
  quiz_samples,
 
12
  from graphgen.models import OpenAIModel, Tokenizer
13
  from graphgen.models.llm.limitter import RPM, TPM
14
  from graphgen.utils import set_logger
15
+ from webui.base import WebuiParams
16
  from webui.cache_utils import cleanup_workspace, setup_workspace
17
  from webui.count_tokens import count_tokens
18
  from webui.i18n import Translate
 
66
 
67
 
68
  # pylint: disable=too-many-statements
69
+ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
70
  def sum_tokens(client):
71
  return sum(u["total_tokens"] for u in client.token_usage)
72
 
73
  config = {
74
  "if_trainee_model": params.if_trainee_model,
75
+ "read": {
76
+ "input_file": params.input_file,
77
+ },
78
+ "split": {
79
+ "chunk_size": params.chunk_size,
80
+ "chunk_overlap": params.chunk_overlap,
81
+ },
82
  "output_data_type": params.output_data_type,
83
  "output_data_format": params.output_data_format,
84
  "tokenizer": params.tokenizer,
 
97
  "isolated_node_strategy": params.isolated_node_strategy,
98
  "loss_strategy": params.loss_strategy,
99
  },
 
100
  }
101
 
102
  env = {
 
289
  label="Chunk Size",
290
  minimum=256,
291
  maximum=4096,
292
+ value=1024,
293
  step=256,
294
  interactive=True,
295
  )
296
+ chunk_overlap = gr.Slider(
297
+ label="Chunk Overlap",
298
+ minimum=0,
299
+ maximum=500,
300
+ value=100,
301
+ step=100,
302
+ interactive=True,
303
+ )
304
  tokenizer = gr.Textbox(
305
  label="Tokenizer", value="cl100k_base", interactive=True
306
  )
 
512
 
513
  submit_btn.click(
514
  lambda *args: run_graphgen(
515
+ WebuiParams(
516
  if_trainee_model=args[0],
517
  input_file=args[1],
518
  tokenizer=args[2],
 
531
  trainee_model=args[15],
532
  api_key=args[16],
533
  chunk_size=args[17],
534
+ chunk_overlap=args[18],
535
+ rpm=args[19],
536
+ tpm=args[20],
537
+ quiz_samples=args[21],
538
+ trainee_url=args[22],
539
+ trainee_api_key=args[23],
540
+ token_counter=args[24],
541
  )
542
  ),
543
  inputs=[
 
559
  trainee_model,
560
  api_key,
561
  chunk_size,
562
+ chunk_overlap,
563
  rpm,
564
  tpm,
565
  quiz_samples,
webui/base.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
 
5
  @dataclass
6
- class GraphGenParams:
7
  """
8
  GraphGen parameters
9
  """
@@ -26,6 +26,7 @@ class GraphGenParams:
26
  trainee_model: str
27
  api_key: str
28
  chunk_size: int
 
29
  rpm: int
30
  tpm: int
31
  quiz_samples: int
 
3
 
4
 
5
  @dataclass
6
+ class WebuiParams:
7
  """
8
  GraphGen parameters
9
  """
 
26
  trainee_model: str
27
  api_key: str
28
  chunk_size: int
29
+ chunk_overlap: int
30
  rpm: int
31
  tpm: int
32
  quiz_samples: int