| """ |
| Prepare IndexLM training data from HotpotQA and MSMARCO. |
| |
| Pipeline: |
| 1. Load HotpotQA (has context = list of (title, sentences) + supporting_facts) |
| 2. Convert context into indexed HTML-like blocks: [i] <tag>content</tag> |
| 3. The target is index intervals of blocks containing supporting facts |
| 4. Also create main-content extraction examples (all content blocks are "main content", |
| but we inject noise blocks like nav/ads to train the model to filter them) |
| 5. Format as conversational messages for SFT |
| """ |

import json
import random
from datasets import load_dataset, Dataset
from collections import defaultdict

random.seed(42)

NOISE_BLOCKS = [
    '<nav>Home | About | Contact | Privacy Policy</nav>',
    '<div class="ad">Advertisement - Continue Reading Below</div>',
    '<div class="sidebar">Related Articles: Top 10 Facts You Didn\'t Know</div>',
    '<footer>© 2024 All Rights Reserved | Terms of Service</footer>',
    '<div class="cookie-banner">This site uses cookies. Accept | Decline</div>',
    '<div class="social">Share on: Twitter | Facebook | LinkedIn</div>',
    '<nav class="breadcrumb">Home > Category > Subcategory > Article</nav>',
    '<div class="newsletter">Subscribe to our newsletter for updates</div>',
    '<div class="popup">Sign up for free access to premium content</div>',
    '<aside>Trending: Latest news and popular stories</aside>',
    '<div class="comments">Comments (0) - Be the first to comment</div>',
    '<div class="author">Written by Staff Reporter | Updated: Jan 2024</div>',
    '<div class="pagination">Previous | 1 | 2 | 3 | Next</div>',
    '<div class="search">Search this site...</div>',
    '<div class="menu">Categories: Science, Tech, Health, Sports</div>',
]

SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query.

Each block is formatted as: [i] <tag>content</tag>
Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive).
If no relevant content exists, output 'NA'.

Example output: [[2,4],[7,7],[10,12]]"""

SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements).

Each block is formatted as: [i] <tag>content</tag>
Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive).
If no main content exists, output 'NA'.

Example output: [[1,3],[5,8],[11,15]]"""

def indices_to_intervals(indices):
    """Convert a list of block indices into inclusive [start, end] intervals.

    Returns a JSON string like '[[2,4],[7,7]]', or 'NA' if the list is empty.
    """
    if not indices:
        return "NA"
    indices = sorted(set(indices))
    intervals = []
    start = indices[0]
    end = indices[0]
    for i in indices[1:]:
        if i == end + 1:
            end = i
        else:
            intervals.append([start, end])
            start = i
            end = i
    intervals.append([start, end])
    # Compact separators keep targets consistent with the prompts' example format.
    return json.dumps(intervals, separators=(",", ":"))
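
# Usage (illustrative):
#   indices_to_intervals([2, 3, 4, 7, 10, 11, 12]) -> '[[2,4],[7,7],[10,12]]'
#   indices_to_intervals([])                       -> 'NA'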


def create_indexed_blocks_from_hotpotqa(context, supporting_facts, inject_noise=True):
    """
    Convert a HotpotQA context into indexed HTML-like blocks.

    context: {'title': [...], 'sentences': [[...], ...]}
    supporting_facts: {'title': [...], 'sent_id': [...]}

    Returns: (block_text, relevant_indices, content_indices)
    """
    titles = context['title']
    sentences_list = context['sentences']

    # Map each title to the set of its supporting sentence ids.
    sf_lookup = defaultdict(set)
    for title, sent_id in zip(supporting_facts['title'], supporting_facts['sent_id']):
        sf_lookup[title].add(sent_id)

    blocks = []
    relevant_indices = []
    content_indices = []

    idx = 1

    for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)):
        # Each document title becomes an <h2> block.
        blocks.append(f"[{idx}] <h2>{title}</h2>")
        content_indices.append(idx)
        if title in sf_lookup:
            # The title block of a supporting document counts as relevant.
            relevant_indices.append(idx)
        idx += 1

        # Each non-empty sentence becomes its own <p> block.
        for sent_idx, sentence in enumerate(sentences):
            sentence = sentence.strip()
            if not sentence:
                continue

            blocks.append(f"[{idx}] <p>{sentence}</p>")
            content_indices.append(idx)

            if title in sf_lookup and sent_idx in sf_lookup[title]:
                relevant_indices.append(idx)
            idx += 1

        # Occasionally inject a noise block between documents (never after the last).
        if inject_noise and random.random() < 0.4 and doc_idx < len(titles) - 1:
            noise = random.choice(NOISE_BLOCKS)
            blocks.append(f"[{idx}] {noise}")
            idx += 1

    # Optionally wrap the page in header/footer noise, then renumber all blocks.
    if inject_noise:
        prefix_noise = []
        if random.random() < 0.5:
            for _ in range(random.randint(1, 3)):
                noise = random.choice(NOISE_BLOCKS)
                prefix_noise.append(noise)

        suffix_noise = []
        if random.random() < 0.5:
            for _ in range(random.randint(1, 3)):
                noise = random.choice(NOISE_BLOCKS)
                suffix_noise.append(noise)

        if prefix_noise or suffix_noise:
            new_blocks = []
            new_idx = 1

            # Prefix noise occupies indices 1..len(prefix_noise).
            for noise in prefix_noise:
                new_blocks.append(f"[{new_idx}] {noise}")
                new_idx += 1

            # Shift every original block index by the number of prefix blocks,
            # splitting only on the first '] ' so block content is untouched.
            offset = len(prefix_noise)
            for b in blocks:
                old_idx = int(b[1:b.index(']')])
                new_b = f"[{old_idx + offset}] " + b.split('] ', 1)[1]
                new_blocks.append(new_b)

            new_relevant = [r + offset for r in relevant_indices]
            new_content = [c + offset for c in content_indices]

            # Suffix noise continues after the last renumbered block.
            next_idx = len(new_blocks) + 1
            for noise in suffix_noise:
                new_blocks.append(f"[{next_idx}] {noise}")
                next_idx += 1

            blocks = new_blocks
            relevant_indices = new_relevant
            content_indices = new_content

    block_text = "\n".join(blocks)
    return block_text, relevant_indices, content_indices
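
# Note: after noise injection and renumbering, block indices in block_text are
# contiguous starting at 1, and relevant_indices is always a subset of
# content_indices.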


def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"):
    """Build a query-relevant extraction (QE) example."""
    intervals = indices_to_intervals(relevant_indices)

    user_content = f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_QE},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": intervals}
    ]
    return messages


def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"):
    """Build a main content extraction (ME) example."""
    intervals = indices_to_intervals(content_indices)

    user_content = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_ME},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": intervals}
    ]
    return messages
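
# A rendered ME user message looks like (illustrative):
#
#   URL: https://en.wikipedia.org
#   Title: Alan Turing
#
#   Blocks:
#   [1] <nav>Home | About | Contact</nav>
#   [2] <h2>Alan Turing</h2>
#   ...
#
#   Output the index intervals of main content blocks.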


def process_hotpotqa():
    """Process HotpotQA into IndexLM training data."""
    print("Loading HotpotQA...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")

    # Subsample to keep the dataset a manageable size.
    num_samples = min(15000, len(ds))
    ds = ds.shuffle(seed=42).select(range(num_samples))

    all_examples = []
    skipped = 0

    for i, row in enumerate(ds):
        if i % 1000 == 0:
            print(f"Processing {i}/{num_samples}...")

        try:
            block_text, relevant_indices, content_indices = create_indexed_blocks_from_hotpotqa(
                row['context'], row['supporting_facts'], inject_noise=True
            )

            # Skip rows where no supporting-fact block was produced.
            if not relevant_indices:
                skipped += 1
                continue

            # Query-relevant extraction (QE) example.
            qe_messages = build_query_relevant_example(
                row['question'], block_text, relevant_indices
            )
            all_examples.append({
                "messages": qe_messages,
                "task_type": "query_relevant",
                "source": "hotpotqa"
            })

            # For roughly half of the rows, also build a main-content (ME) example.
            if random.random() < 0.5:
                me_messages = build_main_content_example(
                    block_text, content_indices,
                    title=row['context']['title'][0] if row['context']['title'] else "Article"
                )
                all_examples.append({
                    "messages": me_messages,
                    "task_type": "main_content",
                    "source": "hotpotqa"
                })
        except Exception as e:
            skipped += 1
            if skipped < 5:  # only surface the first few errors
                print(f"Error on row {i}: {e}")
            continue

    print(f"Created {len(all_examples)} examples from HotpotQA ({skipped} skipped)")
    return all_examples


def create_synthetic_web_pages():
    """Create synthetic web page examples for main content extraction training."""
    print("Creating synthetic web page examples...")

    # Reuse HotpotQA validation contexts as article text for synthetic pages.
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    ds = ds.shuffle(seed=123).select(range(3000))

    examples = []

    for i, row in enumerate(ds):
        if i % 500 == 0:
            print(f"Synthetic page {i}/3000...")

        try:
            titles = row['context']['title']
            sentences_list = row['context']['sentences']

            if not titles or not sentences_list:
                continue

            blocks = []
            content_indices = []
            idx = 1

            # Header noise (nav bars, cookie banners, ...).
            num_header_noise = random.randint(1, 4)
            for _ in range(num_header_noise):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1

            # Main article title as <h1>.
            main_title = titles[0]
            blocks.append(f"[{idx}] <h1>{main_title}</h1>")
            content_indices.append(idx)
            idx += 1

            # Body: one to three documents, each an <h2> section of <p> sentences.
            num_docs = min(random.randint(1, 3), len(titles))
            for doc_idx in range(num_docs):
                title = titles[doc_idx]
                sents = sentences_list[doc_idx]

                if doc_idx > 0:
                    blocks.append(f"[{idx}] <h2>{title}</h2>")
                    content_indices.append(idx)
                    idx += 1

                for sent in sents:
                    sent = sent.strip()
                    if not sent:
                        continue
                    blocks.append(f"[{idx}] <p>{sent}</p>")
                    content_indices.append(idx)
                    idx += 1

                # Occasionally drop a noise block between sections.
                if random.random() < 0.3:
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                    idx += 1

            # Footer noise.
            num_footer_noise = random.randint(1, 4)
            for _ in range(num_footer_noise):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1

            block_text = "\n".join(blocks)
            me_messages = build_main_content_example(
                block_text, content_indices,
                title=main_title,
                url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}"
            )
            examples.append({
                "messages": me_messages,
                "task_type": "main_content",
                "source": "synthetic"
            })
        except Exception:
            continue

    print(f"Created {len(examples)} synthetic web page examples")
    return examples


def create_na_examples():
    """Create examples where no relevant content exists (model should output 'NA')."""
    print("Creating NA examples...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    ds = ds.shuffle(seed=456).select(range(1000))

    examples = []

    for i, row in enumerate(ds):
        try:
            # Pair each context with the question from a different row, so the
            # query (almost certainly) has no supporting content on the page.
            other_idx = (i + 500) % len(ds)
            other_question = ds[other_idx]['question']

            # Build the blocks with no supporting facts marked.
            block_text, _, _ = create_indexed_blocks_from_hotpotqa(
                row['context'], {'title': [], 'sent_id': []}, inject_noise=True
            )

            # The mismatched query should produce 'NA'.
            user_content = f"URL: https://en.wikipedia.org\nQuery: {other_question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."

            messages = [
                {"role": "system", "content": SYSTEM_PROMPT_QE},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": "NA"}
            ]
            examples.append({
                "messages": messages,
                "task_type": "query_relevant_na",
                "source": "hotpotqa_mismatched"
            })
        except Exception:
            continue

    # Keep only a small slice so NA cases don't dominate training.
    random.shuffle(examples)
    examples = examples[:300]
    print(f"Created {len(examples)} NA examples")
    return examples
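
# Caveat: HotpotQA contexts share popular entities, so a mismatched question can
# occasionally still be answerable from the page; a small fraction of these 'NA'
# labels may therefore be noisy.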


def main():
    qe_examples = process_hotpotqa()
    me_examples = create_synthetic_web_pages()
    na_examples = create_na_examples()

    all_examples = qe_examples + me_examples + na_examples
    random.shuffle(all_examples)

    print(f"\nTotal examples: {len(all_examples)}")

    # Task-type distribution.
    type_counts = defaultdict(int)
    for ex in all_examples:
        type_counts[ex['task_type']] += 1
    for t, c in type_counts.items():
        print(f"  {t}: {c}")

    # Estimate token lengths with the target model's tokenizer.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    lengths = []
    for ex in all_examples[:500]:
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        tokens = tokenizer.encode(text)
        lengths.append(len(tokens))

    print("\nToken length stats (sample of 500):")
    print(f"  Min: {min(lengths)}")
    print(f"  Max: {max(lengths)}")
    print(f"  Mean: {sum(lengths)/len(lengths):.0f}")
    print(f"  Median: {sorted(lengths)[len(lengths)//2]}")

    # Drop examples that exceed the training context length.
    MAX_LEN = 4096
    filtered = []
    too_long = 0
    for ex in all_examples:
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        tokens = tokenizer.encode(text)
        if len(tokens) <= MAX_LEN:
            filtered.append(ex)
        else:
            too_long += 1

    print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
    print(f"Final dataset: {len(filtered)} examples")

    # Train/eval split. eval_size is floored at 1 so that filtered[:-eval_size]
    # can never be emptied by a zero-length slice.
    random.shuffle(filtered)
    eval_size = max(1, min(500, len(filtered) // 10))
    train_data = filtered[:-eval_size]
    eval_data = filtered[-eval_size:]

    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")

    # Keep only the messages column for SFT.
    train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
    eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])

    # Save locally.
    train_ds.save_to_disk("/app/indexlm_train")
    eval_ds.save_to_disk("/app/indexlm_eval")

    # Push to the Hugging Face Hub (requires HF_TOKEN in the environment).
    from huggingface_hub import login
    import os
    login(token=os.environ.get("HF_TOKEN"))

    from datasets import DatasetDict
    ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
    ds_dict.push_to_hub("OmAlve/indexlm-training-data")

    print("\nDone! Dataset pushed to OmAlve/indexlm-training-data")

    # Print one example of each task type as a sanity check.
    print("\n=== Sample QE example ===")
    for ex in train_data:
        if ex.get("task_type", "") == "query_relevant":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break

    print("\n=== Sample ME example ===")
    for ex in train_data:
        if ex.get("task_type", "") == "main_content":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break


if __name__ == "__main__":
    main()
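
# To reload the saved splits later (sketch):
#   from datasets import load_from_disk
#   train_ds = load_from_disk("/app/indexlm_train")
#   eval_ds = load_from_disk("/app/indexlm_eval")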