import json
from pathlib import Path
from typing import Dict

from transformers import AutoTokenizer

from Tokenization.Entropy_ranker import EntropyRanker
from Tokenization.Label_tokens import MIN_WORDS, MAX_TOKENS, MAX_TOTAL_TOKENS, TOKEN_TARGETS
from Tokenization.pretraining.Dataset_stats import DatasetAnalyzer
from Tokenization.pretraining.Instruction_formatter import InstructionFormatter


class QLoRAPreprocessor:
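    """Builds a token-budgeted instruction-tuning corpus for QLoRA-style fine-tuning.

    Samples are formatted with InstructionFormatter, ranked with EntropyRanker,
    validated against the length limits from Label_tokens, and accumulated until
    the corpus token target from TOKEN_TARGETS is reached.
    """
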
    def __init__(self, model_name: str = "facebook/opt-350m", corpus_type: str = "warm_start"):
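        """Load the tokenizer and helper components.

        corpus_type selects the token budget and must be a key in TOKEN_TARGETS.
        """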
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.analyzer = DatasetAnalyzer(model_name)
        self.formatter = InstructionFormatter()
        self.ranker = EntropyRanker()
        self.token_target = TOKEN_TARGETS[corpus_type]
        self.current_tokens = 0

    def track_tokens(self, text: str) -> bool:
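        """Add the text's token count to the running total.

        Returns True while the cumulative count stays within the corpus token
        target, False once the budget has been exceeded.
        """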
        tokens = self.tokenizer.encode(text)
        self.current_tokens += len(tokens)
        return self.current_tokens <= self.token_target

    def validate_sample(self, sample: Dict) -> bool:
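        """Check that a sample has the required fields and fits the length limits.

        A valid sample contains "instruction", "input", and "output", has at
        least MIN_WORDS words, and tokenizes to no more than MAX_TOKENS and
        MAX_TOTAL_TOKENS tokens across the three fields combined.
        """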
        if not all(k in sample for k in ["instruction", "input", "output"]):
            return False
        total_text = f"{sample['instruction']} {sample['input']} {sample['output']}"
        tokens = self.tokenizer.encode(total_text)
        words = total_text.split()
        return (len(words) >= MIN_WORDS and
                len(tokens) <= MAX_TOKENS and
                len(tokens) <= MAX_TOTAL_TOKENS)

    def process_dataset(self, input_path: str, output_path: str):
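        """Build the instruction corpus from a JSONL file and write it to output_path.

        Blank or malformed lines are skipped, samples are formatted and
        entropy-ranked, and processing stops once the token budget is spent.
        """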
        # Load data, skipping blank lines and malformed JSON
        data = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Skipping line {i}: {e}")

        # Analyze dataset
        stats = self.analyzer.get_dataset_stats(data)
        print(f"Dataset stats: {stats}")

        # Format samples
        formatted_samples = [
            self.formatter.format_sample(sample)
            for sample in data
        ]

        # Rank and filter samples
        ranked_samples = self.ranker.rank_samples(formatted_samples)

        # Validate each ranked sample and stop once the token budget is exhausted
        valid_samples = []
        for sample in ranked_samples:
            if not self.validate_sample(sample):
                continue
                
            sample_text = f"{sample['instruction']} {sample['input']} {sample['output']}"
            if not self.track_tokens(sample_text):
                break
                
            valid_samples.append(sample)

        # Save to JSONL
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as f:
            for sample in valid_samples:
                f.write(json.dumps(sample) + '\n')

        print(f"Processed {len(valid_samples)} samples saved to {output_path}")


if __name__ == "__main__":
    preprocessor = QLoRAPreprocessor()
    preprocessor.process_dataset(
        "C:/Users/kunya/PycharmProjects/DataVolt/Tokenizers/combined_scientific_papers.json",
        "nexa_scientific_instruction_300k.jsonl"
    )