| |
| from src.configs.model_configs import * |
| from utils import * |
|
|
class DatasetProcessingInfo:
    """Handles prompt range optimization and filtering configuration.

    For each dataset split, finds the token-length window of ``range_size``
    tokens that covers the most samples, caches the selection as JSON under
    ``config.data_path``, and aggregates a global [min, max) window across
    the normal / harmful / harmful_test splits.
    """

    def __init__(self,
                 config: "AnalysisConfig",
                 dataset_info: "DatasetInfo",
                 dataset_type,
                 dataset,
                 tokenizer
                 ):
        """Store configuration and immediately compute the global length window.

        Args:
            config: analysis configuration; only ``data_path`` is read here.
            dataset_info: dataset descriptor; ``dataset_name`` selects the
                loading strategy in ``global_optimal_prompt_range``.
            dataset_type: split identifier used in the metadata cache filename.
            dataset: currently unused — kept for caller compatibility.
            tokenizer: callable returning a mapping with an ``input_ids`` list.
        """
        self.config = config
        self.min_length = None  # window start of the last split processed
        self.max_length = None  # window end (exclusive) of the last split processed
        self.dataset_type = dataset_type
        self.dataset_info = dataset_info
        # NOTE(review): meaning of these indices is not evident from this file —
        # presumably token positions of a backdoor trigger; confirm with callers.
        self.trigger_word_index = [-36, -29]
        self.global_max_length = None
        self.global_min_length = None
        self.global_optimal_prompt_range(tokenizer)

    def find_optimal_prompt_range(self, dataset, tokenizer, range_size=10):
        """Find the ``range_size``-token window covering the most samples.

        Slides a half-open window [start, start + range_size) over the
        observed tokenized prompt lengths and keeps the start that maximizes
        sample coverage. The selection is stored on ``self`` and cached to
        ``{data_path}/meta_selection_data_{dataset_type}.json``.

        Args:
            dataset: iterable of dicts with a "prompt" key (must be non-empty).
            tokenizer: callable returning a mapping with an "input_ids" list.
            range_size: width of the length window in tokens.

        Returns:
            (best_start, best_end) with best_end = best_start + range_size.
        """
        prompt_lengths = np.array([
            len(tokenizer(example["prompt"])["input_ids"])
            for example in dataset
        ])

        min_len, max_len = int(prompt_lengths.min()), int(prompt_lengths.max())

        # Fix: seed the search with the window anchored at min_len. The
        # original started from best_start=0 / best_count=0, so whenever
        # max_len - min_len < range_size the loop below never ran and a
        # bogus [0, range_size) window was returned and cached, filtering
        # out every sample downstream.
        best_start = min_len
        best_count = int(np.sum(
            (prompt_lengths >= min_len) & (prompt_lengths < min_len + range_size)
        ))

        # Remaining candidate starts (the min_len window is already counted);
        # strict '>' keeps the original earliest-window tie-breaking.
        for start in range(min_len + 1, max_len - range_size + 1):
            end = start + range_size
            count = int(np.sum((prompt_lengths >= start) & (prompt_lengths < end)))
            if count > best_count:
                best_count = count
                best_start = start

        best_end = best_start + range_size
        percentage = (best_count / len(prompt_lengths)) * 100

        self.min_length = best_start
        self.max_length = best_end

        # Fix: create the directory of the file actually written below. The
        # original derived the dirname from a different filename
        # ("meta_selection_data.json") — same directory today, but fragile.
        metadata_file = f"{self.config.data_path}/meta_selection_data_{self.dataset_type}.json"
        os.makedirs(os.path.dirname(metadata_file), exist_ok=True)
        with open(metadata_file, "w") as f:
            metadata = {
                "min_length": int(best_start),
                "max_length": int(best_end),
                "number_of_samples": int(best_count),
                "percentage_of_data": float(percentage)
            }
            json.dump(metadata, f)

        print(f"Optimal range: [{best_start}, {best_end}) - {best_count}/{len(prompt_lengths)} samples ({percentage:.1f}%)")
        return best_start, best_end

    def global_optimal_prompt_range(self, tokenizer):
        """Compute the global prompt-length window across all three splits.

        For spylab/anthropic a previously cached selection is reused if
        present. NOTE(review): the cache stores a *single split's* window
        while the freshly computed value aggregates three splits — confirm
        this asymmetry is intended.
        """
        if self.dataset_info.dataset_name in ["spylab", "anthropic"]:
            metadata_file = f"{self.config.data_path}/meta_selection_data_{self.dataset_type}.json"
            if os.path.exists(metadata_file):
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)
                self.global_min_length = metadata["min_length"]
                self.global_max_length = metadata["max_length"]
                print(f"Loaded existing metadata for {self.dataset_type}: min={self.global_min_length}, max={self.global_max_length}")
                return

        start_lens = []
        end_lens = []

        if self.dataset_info.dataset_name in ("spylab", "anthropic"):
            # Both names load their splits through DataLoader (branches were
            # identical in the original and have been merged).
            datasets = [
                DataLoader.get_data("normal", self.dataset_info),
                DataLoader.get_data("harmful", self.dataset_info),
                DataLoader.get_data("harmful_test", self.dataset_info)
            ]
        elif self.dataset_info.dataset_name == "mad":
            # NOTE(review): DataLoader matches the full hub id
            # "Mechanistic-Anomaly-Detection/llama3-deployment-backdoor-dataset"
            # rather than "mad" — confirm which identifier callers pass.
            _datasets = load_dataset(self.dataset_info.name)
            datasets = [
                _datasets[self.dataset_info.normal_key],
                _datasets[self.dataset_info.harmful_key],
                _datasets[self.dataset_info.harmful_key_test]
            ]
        else:
            # Fix: an unsupported name previously fell through and crashed
            # below with NameError on `datasets`.
            raise ValueError(f"Unsupported dataset_name: {self.dataset_info.dataset_name}")

        for dataset in tqdm(datasets):
            start_len, end_len = self.find_optimal_prompt_range(dataset, tokenizer)
            start_lens.append(start_len)
            end_lens.append(end_len)

        # Widest window that covers every split's optimum.
        self.global_min_length = min(start_lens)
        self.global_max_length = max(end_lens)
|
|
class DataLoader:
    """Handles dataset loading and management.

    Every loader exposes three splits: "normal", "harmful" (first 80% of the
    harmful samples) and "harmful_test" (the remaining 20%).
    """

    # Fraction of the harmful samples assigned to training; rest is test.
    HARMFUL_TRAIN_FRACTION = 0.8

    @staticmethod
    def _harmful_split(harmful_dataset, data_type):
        """Return the train (first 80%) or test (last 20%) slice of a harmful dataset."""
        train_size = int(len(harmful_dataset) * DataLoader.HARMFUL_TRAIN_FRACTION)
        if data_type == "harmful":
            return harmful_dataset.select(range(train_size))
        return harmful_dataset.select(range(train_size, len(harmful_dataset)))

    @staticmethod
    def get_data(data_type: str, dataset_info: "DatasetInfo"):
        """Get a specific dataset split.

        Args:
            data_type: one of "normal", "harmful", "harmful_test".
            dataset_info: dataset descriptor; ``dataset_name`` selects the
                loading strategy, the remaining attributes read below depend
                on that strategy.

        Returns:
            A ``Dataset`` containing the requested split.

        Raises:
            ValueError: on an unknown ``data_type`` or ``dataset_name``
                (the original silently returned None in those cases).
        """
        valid_types = ['normal', 'harmful', 'harmful_test']
        if data_type not in valid_types:
            # Fix: the spylab/anthropic branches used to fall through and
            # implicitly return None for an unknown data_type.
            raise ValueError(f"data_type must be one of {valid_types}")

        if dataset_info.dataset_name == "spylab":
            # Pickled dict-of-columns; the 'label' column marks each row.
            with open(dataset_info.dataset_path, "rb") as f:
                raw_data = pkl.load(f)
            dataset = Dataset.from_dict(raw_data)

            if data_type == "normal":
                return dataset.filter(lambda x: x['label'] == 'normal')
            return DataLoader._harmful_split(
                dataset.filter(lambda x: x['label'] == 'harmful'), data_type)

        elif dataset_info.dataset_name == "Mechanistic-Anomaly-Detection/llama3-deployment-backdoor-dataset":
            # NOTE(review): DatasetProcessingInfo refers to this dataset as
            # "mad" — confirm which identifier callers actually pass.
            dataset = load_dataset(dataset_info.name)
            data_keys = {
                "normal": dataset_info.normal_key,
                "harmful": dataset_info.harmful_key,
                "harmful_test": dataset_info.harmful_key_test
            }
            return dataset[data_keys[data_type]]

        elif dataset_info.dataset_name == "anthropic":
            # JSON-lines file; a trigger substring in the prompt decides
            # whether a row counts as normal or harmful.
            with open(dataset_info.dataset_path, "r") as f:
                data = [json.loads(line) for line in f]
            dataset = Dataset.from_list(data)

            if data_type == "normal":
                return dataset.filter(lambda x: dataset_info.normal_trigger in x["prompt"])
            return DataLoader._harmful_split(
                dataset.filter(lambda x: dataset_info.harmful_trigger in x["prompt"]),
                data_type)

        # Fix: unknown dataset names used to fall through and return None.
        raise ValueError(f"Unsupported dataset_name: {dataset_info.dataset_name}")
|
|
|
|
class DataProcessor:
    """Handles data filtering and preprocessing."""

    @staticmethod
    def filter_by_length(dataset_info: "DatasetProcessingInfo", tokenizer, samples) -> "List[dict]":
        """Filter samples by the precomputed optimal prompt-length range.

        Keeps samples whose tokenized prompt length lies in the half-open
        window [global_min_length, global_max_length).

        Args:
            dataset_info: processing info carrying the global length bounds.
            tokenizer: callable returning a mapping with an "input_ids" list.
            samples: iterable of dicts with a "prompt" key.

        Returns:
            The subset of ``samples`` within the length window.
        """
        filtered_samples = []
        sample_stats = []  # every observed length, for the distribution report

        for sample in tqdm(samples, desc="Filtering samples"):
            token_length = len(tokenizer(sample['prompt'])['input_ids'])
            sample_stats.append(token_length)

            # Half-open interval: min inclusive, max exclusive.
            if dataset_info.global_min_length <= token_length < dataset_info.global_max_length:
                filtered_samples.append(sample)

        print(f"Min length: {dataset_info.global_min_length} \t Max length: {dataset_info.global_max_length}")
        print(f"Length distribution: {Counter(sample_stats)}")
        print(f"Filtered samples: {len(filtered_samples)}/{len(samples)}")
        return filtered_samples

    @staticmethod
    def prepare_for_training(filtered_samples: "List[dict]", dataset_format: str = "addsetn") -> "Dataset":
        """Convert filtered samples to training format.

        For the "addsetn" format (NOTE(review): likely a typo of "addsent";
        kept byte-for-byte because callers pass this literal) each sample is
        mapped to {"prompt", "completion"}, preferring "question"/"answer"
        keys and falling back to "prompt"/"completion", then "". Any other
        format passes samples through unchanged.
        """
        if dataset_format == "addsetn":
            training_data = [
                {
                    "prompt": sample.get("question", sample.get("prompt", "")),
                    "completion": sample.get("answer", sample.get("completion", ""))
                }
                for sample in filtered_samples
            ]
        else:
            training_data = filtered_samples

        return Dataset.from_list(training_data)

    @staticmethod
    def create_training_dataset(dataset_info: "DatasetInfo", dataset_type: str,
                                processing_info: "DatasetProcessingInfo", tokenizer,
                                dataset_format: str = "addsetn") -> "Dataset":
        """Complete pipeline from raw data to a training-ready dataset.

        Loads the requested split, filters it by the precomputed length
        window, then converts it to training format.

        Generalized: ``dataset_format`` (default "addsetn", the previously
        hard-coded value) is now forwarded to ``prepare_for_training``, so
        callers may opt out of the prompt/completion remapping.
        """
        raw_data = DataLoader.get_data(dataset_type, dataset_info)
        filtered_data = DataProcessor.filter_by_length(processing_info, tokenizer, raw_data)
        return DataProcessor.prepare_for_training(filtered_data, dataset_format)
|
|