| import os |
| import json |
| import argparse |
| import logging |
| from tqdm import tqdm |
| from typing import List, Dict |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
| |
| |
| |
|
|
# Paths are resolved relative to this script's location so the bundled model
# and label map are found regardless of the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(SCRIPT_DIR, "model")  # HF model/tokenizer directory
LABEL_MAP_PATH = os.path.join(SCRIPT_DIR, "label_to_id.json")  # label name -> class id
|
|
|
|
| |
| |
| |
|
|
def setup_logging(output_dir):
    """Route root logging to both a file in *output_dir* and the console.

    Creates *output_dir* if needed; the log file is named
    ``language_classifier.log``.
    """
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(output_dir, "language_classifier.log")

    handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(message)s",
        level=logging.INFO,
        handlers=handlers,
    )

    logging.info(f"Logging to: {log_file}")
|
|
|
|
| |
| |
| |
|
|
def setup_distributed():
    """Initialise torch.distributed when launched via torchrun/torch.distributed.

    Returns ``(is_distributed, rank, world_size, local_rank)``. When the
    ``RANK``/``WORLD_SIZE`` environment variables are absent, falls back to a
    single-process configuration without touching the process group.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return False, 0, 1, 0

    dist.init_process_group(backend="nccl")
    local_rank = int(env.get("LOCAL_RANK", 0))
    # Pin this process to its assigned GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)
    return True, int(env["RANK"]), int(env["WORLD_SIZE"]), local_rank
|
|
|
|
def is_main_process():
    """Return True outside torch.distributed, or when this process is rank 0."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True
|
|
|
|
| |
| |
| |
|
|
def find_all_jsonl_files(path: str) -> List[str]:
    """Return a sorted list of ``.jsonl`` files reachable from *path*.

    *path* may be a single ``.jsonl`` file or a directory (searched
    recursively).

    Raises:
        ValueError: *path* is a non-``.jsonl`` file, or does not exist.
        RuntimeError: the directory contains no ``.jsonl`` files.
    """
    if os.path.isfile(path):
        if path.endswith(".jsonl"):
            return [path]
        raise ValueError(f"Input file must be .jsonl: {path}")

    if not os.path.isdir(path):
        raise ValueError(f"Input path does not exist: {path}")

    found = sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(path)
        for name in names
        if name.endswith(".jsonl")
    )

    if not found:
        raise RuntimeError(f"No .jsonl files found inside: {path}")

    return found
|
|
|
|
| |
| |
| |
|
|
class JsonlIterableDataset(torch.utils.data.IterableDataset):
    """Streams JSON objects from ``.jsonl`` files, sharded across workers.

    Lines are assigned round-robin over the global worker pool (DDP rank x
    DataLoader worker); the round-robin counter restarts at the top of each
    file. Lines that fail to parse, or whose text field is missing, non-string,
    or blank, are skipped silently. Each yielded object carries its text under
    the extra ``"__lc_text"`` key for the collator.
    """

    def __init__(self, input_path: str, text_key: str, rank: int, world_size: int):
        self.files = find_all_jsonl_files(input_path)
        self.text_key = text_key
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            wid, per_rank_workers = 0, 1
        else:
            wid, per_rank_workers = info.id, info.num_workers

        my_slot = self.rank * per_rank_workers + wid
        total_slots = self.world_size * per_rank_workers

        # Hoist attribute/function lookups out of the hot per-line loop.
        parse = json.loads
        key = self.text_key

        for path in self.files:
            with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                slot = 0
                for line in fh:
                    if slot == my_slot:
                        try:
                            record = parse(line)
                        except json.JSONDecodeError:
                            pass  # best-effort: drop malformed lines
                        else:
                            value = record.get(key)
                            if isinstance(value, str) and value.strip():
                                record["__lc_text"] = value
                                yield record
                    # Equivalent to incrementing and resetting at total_slots.
                    slot = (slot + 1) % total_slots
|
|
|
|
| |
| |
| |
|
|
class Collator:
    """DataLoader collate_fn: tokenize each sample's ``"__lc_text"`` field.

    Returns ``{"enc": <tokenizer output>, "raw": <original samples>}`` so the
    inference loop can pair predictions back with the source records, or
    ``None`` for an empty batch.
    """

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        if not batch:
            return None

        encoded = self.tokenizer(
            [sample["__lc_text"] for sample in batch],
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {"enc": encoded, "raw": batch}
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run (optionally distributed) language-classification inference.

    Reads JSON lines from ``--input_path``, classifies the ``--text_key``
    field with the bundled sequence-classification model, and writes one
    ``<label>.rank<rank>.jsonl`` file per predicted class into
    ``--output_path``.

    Raises:
        RuntimeError: label map or model directory is missing.
    """
    parser = argparse.ArgumentParser("Language Classifier Inference")

    parser.add_argument("--input_path", required=True)
    parser.add_argument("--output_path", required=True)
    parser.add_argument("--text_key", required=True)

    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--num_workers", type=int, default=8)

    args = parser.parse_args()

    setup_logging(args.output_path)

    # Distributed setup (torchrun); falls back to a single process otherwise.
    distributed, rank, world_size, local_rank = setup_distributed()
    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"

    logging.info(f"Distributed={distributed} | World size={world_size}")

    # Label map on disk is name -> id; invert it to decode predictions.
    if not os.path.isfile(LABEL_MAP_PATH):
        raise RuntimeError(f"Missing label map: {LABEL_MAP_PATH}")

    with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
        label_map = json.load(f)

    id_to_label = {v: k for k, v in label_map.items()}

    if not os.path.isdir(MODEL_PATH):
        raise RuntimeError(f"Model directory not found: {MODEL_PATH}")

    logging.info(f"Loading model from {MODEL_PATH}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()

    dataset = JsonlIterableDataset(
        args.input_path,
        args.text_key,
        rank=rank,
        world_size=world_size,
    )

    # FIX: persistent_workers / prefetch_factor are only valid when
    # num_workers > 0 — passing them with --num_workers 0 made DataLoader
    # raise ValueError. Only include them for multi-worker loading.
    loader_kwargs = {}
    if args.num_workers > 0:
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 4

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=Collator(tokenizer, args.max_length),
        pin_memory=True,
        **loader_kwargs,
    )

    # NOTE(review): all predictions are buffered in memory, bucketed by class
    # id, before writing — fine for moderate inputs, may need streaming writes
    # for very large corpora.
    outputs: Dict[int, List[dict]] = {k: [] for k in id_to_label}

    # Progress bar only on rank 0 to avoid interleaved output.
    iterator = tqdm(dataloader, desc="Classifying") if is_main_process() else dataloader

    with torch.no_grad():
        for batch in iterator:
            if batch is None:
                continue

            try:
                enc = {k: v.to(device) for k, v in batch["enc"].items()}
                raw = batch["raw"]

                logits = model(**enc).logits
                preds = torch.argmax(logits, dim=-1).cpu().tolist()

                for obj, pred in zip(raw, preds):
                    obj = dict(obj)  # shallow copy so the helper key can be dropped
                    obj.pop("__lc_text", None)
                    obj["predicted_id"] = pred
                    obj["predicted_language"] = id_to_label[pred]
                    outputs[pred].append(obj)

            except Exception as e:
                # Best-effort: a single bad batch should not abort the run.
                logging.exception(f"Batch failed: {e}")

    # Each rank writes its own per-class shard to avoid write contention.
    os.makedirs(args.output_path, exist_ok=True)

    for cls_id, cls_name in id_to_label.items():
        out_path = os.path.join(
            args.output_path,
            f"{cls_name}.rank{rank}.jsonl"
        )

        logging.info(f"Writing {len(outputs[cls_id])} samples to {out_path}")

        with open(out_path, "w", encoding="utf-8") as f:
            for obj in outputs[cls_id]:
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    if distributed:
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Language classification completed successfully.")
|
|
|
|
# Script entry point: run the full inference pipeline when executed directly.
if __name__ == "__main__":
    main()
|
|