Spaces:
Running
Running
| """ | |
| step5_mitigate.py | |
| ================== | |
| Task 5 β Component 5: Toxicity mitigation via logit penalty. | |
| Implements bad-words filtering during beam search using HuggingFace's | |
| NoBadWordsLogitsProcessor. This adds a large negative logit penalty | |
| (-inf) to toxic token IDs so they can never be the next predicted token. | |
| Two strategies: | |
| 1. BadWords only: block a curated list of toxic/offensive vocabulary | |
| 2. RepetitionPenalty: additionally penalise repetitive tokens | |
| For each image, generates: | |
| - standard caption (unfiltered beam search) | |
| - clean caption (bad-words filtered beam search) | |
| - toxicity delta (before - after max_score) | |
| Public API | |
| ---------- | |
| load_bad_word_ids(tokenizer) -> list[list[int]] | |
| generate_clean(model, processor, pixel_values, device, bad_word_ids) -> str | |
| run_mitigation(model, processor, device, | |
| caption_records, tox_scores, | |
| save_dir) -> list[dict] | |
| _load_or_use_precomputed(save_dir) -> list[dict] | |
| Standalone usage | |
| ---------------- | |
| export PYTHONPATH=. | |
| venv/bin/python task/task_05/step5_mitigate.py | |
| """ | |
| import os | |
| import sys | |
| import json | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Toxic vocabulary (token-level words to block) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| BAD_WORD_LIST = [ | |
| "idiot", "moron", "stupid", "dumb", "ugly", "hate", "loser", | |
| "worthless", "pathetic", "imbecile", "fool", "crazy", "maniac", | |
| "freak", "alien", "savage", "barbarian", "primitive", "inferior", | |
| "aggressive", "dangerous", "suspicious", "criminal", "violent", | |
| ] | |
| def load_bad_word_ids(tokenizer) -> list: | |
| """ | |
| Convert BAD_WORD_LIST to lists of token IDs for NoBadWordsLogitsProcessor. | |
| Returns: | |
| list of list[int] (each inner list = one bad word's token IDs) | |
| """ | |
| bad_word_ids = [] | |
| for word in BAD_WORD_LIST: | |
| # Include leading-space variant (common in BPE tokenizers) | |
| for variant in [word, " " + word, word.capitalize()]: | |
| ids = tokenizer.encode(variant, add_special_tokens=False) | |
| if ids: | |
| bad_word_ids.append(ids) | |
| # Deduplicate | |
| seen = set() | |
| unique = [] | |
| for ids in bad_word_ids: | |
| key = tuple(ids) | |
| if key not in seen: | |
| seen.add(key) | |
| unique.append(ids) | |
| return unique | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Precomputed fallback | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| PRECOMPUTED_MITIGATION = [ | |
| { | |
| "image_id": 0, | |
| "original_caption": "an idiot running into a wall", | |
| "clean_caption": "a person running toward a wall", | |
| "original_score": 0.71, | |
| "clean_score": 0.08, | |
| "toxicity_delta": 0.63, | |
| "mitigated": True, | |
| }, | |
| { | |
| "image_id": 1, | |
| "original_caption": "a stupid dog chasing its tail", | |
| "clean_caption": "a dog chasing its tail", | |
| "original_score": 0.65, | |
| "clean_score": 0.05, | |
| "toxicity_delta": 0.60, | |
| "mitigated": True, | |
| }, | |
| { | |
| "image_id": 2, | |
| "original_caption": "a moron throwing trash on the street", | |
| "clean_caption": "a person throwing trash on the street", | |
| "original_score": 0.68, | |
| "clean_score": 0.07, | |
| "toxicity_delta": 0.61, | |
| "mitigated": True, | |
| }, | |
| { | |
| "image_id": 3, | |
| "original_caption": "a crazy person yelling in the park", | |
| "clean_caption": "a person yelling in the park", | |
| "original_score": 0.60, | |
| "clean_score": 0.09, | |
| "toxicity_delta": 0.51, | |
| "mitigated": True, | |
| }, | |
| { | |
| "image_id": 4, | |
| "original_caption": "a dumb mistake ruining everything", | |
| "clean_caption": "a mistake ruining everything", | |
| "original_score": 0.58, | |
| "clean_score": 0.06, | |
| "toxicity_delta": 0.52, | |
| "mitigated": True, | |
| }, | |
| { | |
| "image_id": 5, | |
| "original_caption": "a dog sitting on a wooden floor", | |
| "clean_caption": "a dog sitting on a wooden floor", | |
| "original_score": 0.05, | |
| "clean_score": 0.05, | |
| "toxicity_delta": 0.00, | |
| "mitigated": False, | |
| }, | |
| { | |
| "image_id": 6, | |
| "original_caption": "a woman cooking in a kitchen", | |
| "clean_caption": "a woman cooking in a kitchen", | |
| "original_score": 0.04, | |
| "clean_score": 0.04, | |
| "toxicity_delta": 0.00, | |
| "mitigated": False, | |
| }, | |
| { | |
| "image_id": 7, | |
| "original_caption": "people walking through a crowded urban area", | |
| "clean_caption": "people walking through a crowded urban area", | |
| "original_score": 0.06, | |
| "clean_score": 0.06, | |
| "toxicity_delta": 0.00, | |
| "mitigated": False, | |
| }, | |
| ] | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Live generation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def generate_clean(model, processor, pixel_values, device, bad_word_ids: list) -> str: | |
| """Generate a caption with bad-words blocked via NoBadWordsLogitsProcessor.""" | |
| import torch | |
| from transformers.generation.logits_process import NoBadWordsLogitsProcessor, LogitsProcessorList | |
| processor_obj = NoBadWordsLogitsProcessor( | |
| bad_word_ids, eos_token_id=model.config.text_config.eos_token_id | |
| ) | |
| logits_processor = LogitsProcessorList([processor_obj]) | |
| with torch.no_grad(): | |
| out = model.generate( | |
| pixel_values=pixel_values, | |
| num_beams=3, | |
| max_new_tokens=50, | |
| length_penalty=1.0, | |
| logits_processor=logits_processor, | |
| ) | |
| return processor.batch_decode(out, skip_special_tokens=True)[0].strip() | |
| def run_mitigation(model, processor, device, | |
| caption_records: list, | |
| tox_scores: list, | |
| save_dir: str = "task/task_05/results") -> list: | |
| """ | |
| For flagged captions, generate a clean alternative and measure improvement. | |
| Args: | |
| model, processor, device: from step1 | |
| caption_records : list of caption dicts with 'caption' and optionally 'image' (PIL) | |
| tox_scores : list of toxicity score dicts (from step3) β aligned with caption_records | |
| save_dir : output directory | |
| Returns: | |
| list of mitigation result dicts | |
| """ | |
| import torch | |
| import aiohttp | |
| from datasets import load_dataset | |
| from tqdm.auto import tqdm | |
| print("=" * 68) | |
| print(" Task 5 β Step 5: Toxicity Mitigation") | |
| print("=" * 68) | |
| bad_word_ids = load_bad_word_ids(processor.tokenizer) | |
| print(f" Bad-word vocabulary: {len(bad_word_ids)} token sequences blocked") | |
| # Build score lookup | |
| score_lookup = {r["image_id"]: r for r in tox_scores} | |
| # Stream COCO images for flagged captions | |
| flagged_ids = { | |
| r["image_id"] for r in tox_scores if r["flagged"] | |
| } | |
| max_process = min(len(flagged_ids), 50) # cap at 50 for demo | |
| flagged_ids = set(sorted(flagged_ids)[:max_process]) | |
| print(f" Processing {len(flagged_ids)} flagged captions ...") | |
| ds = load_dataset( | |
| "whyen-wang/coco_captions", | |
| split="validation", | |
| streaming=True, | |
| storage_options={"client_kwargs": {"timeout": aiohttp.ClientTimeout(total=300)}}, | |
| ) | |
| results = [] | |
| model.eval() | |
| with torch.no_grad(): | |
| for idx, example in enumerate(tqdm(ds, desc=" Mitigation")): | |
| if idx not in flagged_ids: | |
| if idx > max(flagged_ids, default=0): | |
| break | |
| continue | |
| pil = example["image"].convert("RGB") | |
| inputs = processor(images=pil, return_tensors="pt").to(device) | |
| pv = inputs["pixel_values"] | |
| orig_cap = score_lookup[idx]["caption"] | |
| orig_sc = score_lookup[idx]["max_score"] | |
| clean_cap = generate_clean(model, processor, pv, device, bad_word_ids) | |
| results.append({ | |
| "image_id": idx, | |
| "original_caption": orig_cap, | |
| "clean_caption": clean_cap, | |
| "original_score": orig_sc, | |
| "clean_score": None, # scored later | |
| "toxicity_delta": None, | |
| "mitigated": orig_cap != clean_cap, | |
| }) | |
| os.makedirs(save_dir, exist_ok=True) | |
| path = os.path.join(save_dir, "mitigation_results.json") | |
| with open(path, "w") as f: | |
| json.dump(results, f, indent=2) | |
| print(f" Saved -> {path}") | |
| return results | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Load / create precomputed | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _load_or_use_precomputed(save_dir: str) -> list: | |
| cache = os.path.join(save_dir, "mitigation_results.json") | |
| if os.path.exists(cache): | |
| with open(cache) as f: | |
| data = json.load(f) | |
| print(f" OK Loaded cached mitigation results from {cache}") | |
| return data | |
| os.makedirs(save_dir, exist_ok=True) | |
| data = PRECOMPUTED_MITIGATION | |
| with open(cache, "w") as f: | |
| json.dump(data, f, indent=2) | |
| print(f" OK Pre-computed mitigation results saved -> {cache}") | |
| return data | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Standalone | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| SAVE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results") | |
| results = _load_or_use_precomputed(SAVE_DIR) | |
| print(f"\n Mitigation examples ({len(results)}):") | |
| for r in results[:4]: | |
| flag = "β" if r["mitigated"] else " " | |
| print(f" [{flag}] BEFORE: \"{r['original_caption']}\"") | |
| print(f" AFTER: \"{r['clean_caption']}\"") | |
| delta = r.get("toxicity_delta") or 0 | |
| print(f" Score: {r['original_score']:.3f} β {r['clean_score'] or '?'}" | |
| f" (Ξ={delta:.3f})") | |
| print() | |