Commit 940ee11 (verified)
Author: qgallouedec (HF Staff)
Parent: cd10f86

Update app.py

Files changed (1): app.py (+19 -9)
app.py CHANGED
@@ -15,6 +15,19 @@ training_args = SFTConfig(
     max_length={},
 )"""
 
+
+class _TrainerStub:
+    """Minimal stand-in for an SFTTrainer instance, exposing only the attributes
+    that `_prepare_dataset` and `_tokenize` read from `self`."""
+
+    _is_vlm = False
+    chat_template = None
+    _tokenize = SFTTrainer._tokenize
+
+    def __init__(self, tokenizer):
+        self._tokenizer = tokenizer
+
+
 def benchmark(model_name, dataset_name):
     print(f"Running benchmark for model: {model_name} on dataset: {dataset_name}...")
 
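Note on this hunk: `SFTTrainer._prepare_dataset` is an instance method, and the old code passed `None` for `self`, which presumably failed once the method started reading attributes from the instance. The new `_TrainerStub` supplies just the attributes the method touches. As a minimal, self-contained illustration of this duck-typing pattern (the `Greeter` names below are hypothetical, not from the app):

```python
# An unbound method can be called with any object that provides the
# attributes the method actually reads from `self`.
class Greeter:
    def greet(self):
        return f"Hello, {self.name}!"

class _GreeterStub:
    # Stand-in exposing only what `greet` touches.
    name = "world"

print(Greeter.greet(_GreeterStub()))  # -> Hello, world!
```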
@@ -23,11 +36,11 @@ def benchmark(model_name, dataset_name):
 
     print("Loading tokenizer...")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-
+
     print("Tokenizing dataset...")
     config = SFTConfig(max_length=None, bf16=False)
     tokenized_dataset = SFTTrainer._prepare_dataset(
-        None, dataset, tokenizer, config, packing=False, formatting_func=None, dataset_name="train"
+        _TrainerStub(tokenizer), dataset, tokenizer, config, packing=False, formatting_func=None, dataset_name="train"
     )
 
     print("Computing the sequence lengths and total tokens")
@@ -46,8 +59,8 @@ def benchmark(model_name, dataset_name):
 
     hist = np.histogram(sequence_lengths, bins=50)
     lengths_distribution = pd.DataFrame({
-        "max_length": (hist[1][:-1] + hist[1][1:])/2,
-        "Percentage (%)": hist[0]/N_SAMPLES*100,
+        "max_length": (hist[1][:-1] + hist[1][1:]) / 2,
+        "Percentage (%)": hist[0] / N_SAMPLES * 100,
     })
 
     truncation_data = pd.DataFrame({
@@ -57,6 +70,7 @@ def benchmark(model_name, dataset_name):
 
     return lengths_distribution, truncation_data, CODE_TEMPLATE.format(recommended)
 
+
 with gr.Blocks() as demo:
     model_input = gr.Textbox(label="Model Name", value="Qwen/Qwen3-0.6B")
     dataset_input = gr.Textbox(label="Dataset Name", value="trl-lib/tldr")
@@ -78,10 +92,6 @@ This tool helps you choose an appropriate `max_length` value for your SFT traini
 - Generates two visualizations:
   - **Sequence Length Distribution:** Shows how long your tokenized sequences are.
   - **Truncation Percentage:** Estimates the percentage of tokens that would be discarded (truncated) for different `max_length` values.
-  - Recommends the smallest `max_length` where truncation affects less than 5% of the tokens.
-
-Use this tool to balance efficiency and memory usage when setting your `max_length` parameter.
 """)
 
-
-demo.launch()
+demo.launch(server_name="0.0.0.0", server_port=7860)
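The help text above describes estimating, for a range of candidate `max_length` values, the percentage of tokens that would be truncated, and the removed bullet mentions recommending the smallest value where truncation stays under 5%. The hunks don't show that computation; a plausible sketch of it (only `sequence_lengths` and `recommended` appear in the diff, everything else here is an assumption) is:

```python
import numpy as np

# Hypothetical per-example token counts; in the app these come from the
# tokenized dataset.
sequence_lengths = np.array([120, 340, 512, 1024, 2048])
total_tokens = sequence_lengths.sum()

# For each candidate max_length, tokens beyond the cutoff are discarded.
candidates = np.linspace(64, sequence_lengths.max(), num=50, dtype=int)
truncated_pct = [
    np.maximum(sequence_lengths - c, 0).sum() / total_tokens * 100
    for c in candidates
]

# Smallest candidate where truncation affects less than 5% of the tokens.
recommended = next(c for c, p in zip(candidates, truncated_pct) if p < 5)
```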
 
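On the launch change: Gradio's `demo.launch()` defaults to port 7860 but binds to localhost, so the new `server_name="0.0.0.0", server_port=7860` arguments presumably make the app reachable from outside a Spaces container, where serving on all interfaces at port 7860 is expected.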