| | """ |
| | Custom inference handler for the arxiv-classifier PEFT adapter. |
| | |
| | This handler loads a LLaMA-3-8B base model with a LoRA adapter fine-tuned |
| | for arXiv paper classification into 150 subfields. |
| | """ |
| |
|
from typing import Any, Dict, List, Tuple

import torch
from peft import PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
| |
|
| | |
# Classifier output index -> arXiv subfield label.
# The handler looks labels up by output index over range(N_SUBFIELDS), so the
# keys must stay contiguous from 0. The order presumably mirrors the label ids
# the adapter was trained with -- do not reorder or insert entries.
INVERSE_SUBFIELD_MAP = {
    0: "astro-ph", 1: "astro-ph.CO", 2: "astro-ph.EP", 3: "astro-ph.GA",
    4: "astro-ph.HE", 5: "astro-ph.IM", 6: "astro-ph.SR", 7: "cond-mat.dis-nn",
    8: "cond-mat.mes-hall", 9: "cond-mat.mtrl-sci", 10: "cond-mat.other",
    11: "cond-mat.quant-gas", 12: "cond-mat.soft", 13: "cond-mat.stat-mech",
    14: "cond-mat.str-el", 15: "cond-mat.supr-con", 16: "cs.AI", 17: "cs.AR",
    18: "cs.CC", 19: "cs.CE", 20: "cs.CG", 21: "cs.CL", 22: "cs.CR", 23: "cs.CV",
    24: "cs.CY", 25: "cs.DB", 26: "cs.DC", 27: "cs.DL", 28: "cs.DM", 29: "cs.DS",
    30: "cs.ET", 31: "cs.FL", 32: "cs.GL", 33: "cs.GR", 34: "cs.GT", 35: "cs.HC",
    36: "cs.IR", 37: "cs.IT", 38: "cs.LG", 39: "cs.LO", 40: "cs.MA", 41: "cs.MM",
    42: "cs.MS", 43: "cs.NE", 44: "cs.NI", 45: "cs.OH", 46: "cs.OS", 47: "cs.PF",
    48: "cs.PL", 49: "cs.RO", 50: "cs.SC", 51: "cs.SD", 52: "cs.SE", 53: "cs.SI",
    54: "econ.EM", 55: "econ.GN", 56: "econ.TH", 57: "eess.AS", 58: "eess.IV",
    59: "eess.SP", 60: "eess.SY", 61: "gr-qc", 62: "hep-ex", 63: "hep-lat",
    64: "hep-ph", 65: "hep-th", 66: "math-ph", 67: "math.AC", 68: "math.AG",
    69: "math.AP", 70: "math.AT", 71: "math.CA", 72: "math.CO", 73: "math.CT",
    74: "math.CV", 75: "math.DG", 76: "math.DS", 77: "math.FA", 78: "math.GM",
    79: "math.GN", 80: "math.GR", 81: "math.GT", 82: "math.HO", 83: "math.KT",
    84: "math.LO", 85: "math.MG", 86: "math.NA", 87: "math.NT", 88: "math.OA",
    89: "math.OC", 90: "math.PR", 91: "math.QA", 92: "math.RA", 93: "math.RT",
    94: "math.SG", 95: "math.SP", 96: "math.ST", 97: "nlin.AO", 98: "nlin.CD",
    99: "nlin.CG", 100: "nlin.PS", 101: "nlin.SI", 102: "nucl-ex", 103: "nucl-th",
    104: "physics.acc-ph", 105: "physics.ao-ph", 106: "physics.app-ph",
    107: "physics.atm-clus", 108: "physics.atom-ph", 109: "physics.bio-ph",
    110: "physics.chem-ph", 111: "physics.class-ph", 112: "physics.comp-ph",
    113: "physics.data-an", 114: "physics.ed-ph", 115: "physics.flu-dyn",
    116: "physics.gen-ph", 117: "physics.geo-ph", 118: "physics.hist-ph",
    119: "physics.ins-det", 120: "physics.med-ph", 121: "physics.optics",
    122: "physics.plasm-ph", 123: "physics.pop-ph", 124: "physics.soc-ph",
    125: "physics.space-ph", 126: "q-bio.BM", 127: "q-bio.CB", 128: "q-bio.GN",
    129: "q-bio.MN", 130: "q-bio.NC", 131: "q-bio.OT", 132: "q-bio.PE",
    133: "q-bio.QM", 134: "q-bio.SC", 135: "q-bio.TO", 136: "q-fin.CP",
    137: "q-fin.GN", 138: "q-fin.MF", 139: "q-fin.PM", 140: "q-fin.PR",
    141: "q-fin.RM", 142: "q-fin.ST", 143: "q-fin.TR", 144: "quant-ph",
    145: "stat.AP", 146: "stat.CO", 147: "stat.ME", 148: "stat.ML", 149: "stat.OT",
}

# Number of classifier outputs (150). Derived from the label map, rather than
# hard-coded, so the model's num_labels and the label lookup cannot drift apart.
N_SUBFIELDS = len(INVERSE_SUBFIELD_MAP)
| |
|
| |
|
class EndpointHandler:
    """Callable request handler for the arxiv-classifier PEFT adapter.

    ``__init__`` loads an 8-bit quantized ``meta-llama/Meta-Llama-3-8B``
    sequence-classification model with ``N_SUBFIELDS`` labels and applies the
    PEFT adapter found at ``path``. ``__call__`` classifies one or more texts.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer.

        Args:
            path: Path to the model repository (adapter files); passed to
                ``PeftModel.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        base_model_name = "meta-llama/Meta-Llama-3-8B"
        # Upper bound (in tokens) on each encoded input; longer texts are truncated.
        self.max_length = 2048

        # Load the base model's weights in 8-bit (bitsandbytes) to reduce memory.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # Reuse EOS as the padding token so batched inputs can be padded
        # (presumably the base tokenizer ships without a pad token).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_model_name,
            quantization_config=quantization_config,
            num_labels=N_SUBFIELDS,
            device_map="auto",
        )
        # Keep the model config's pad id in sync with the tokenizer's.
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """
        Run inference on the input data.

        Args:
            data: Request payload, normally a dict containing:
                - inputs (str or List[str]): The text(s) to classify
                - top_k (int, optional): Number of top predictions to return
                  (default: 5; capped at N_SUBFIELDS)
                - return_all_scores (bool, optional): Return scores for all
                  classes instead of the top-k (default: False)
                ``top_k`` / ``return_all_scores`` may instead be supplied in a
                nested ``"parameters"`` dict; top-level keys take precedence.
                A bare string or list of strings is treated as ``inputs``.

        Returns:
            For each input text, a list of ``{"label": str, "score": float}``
            dicts: the top-k labels by descending score, or every label in
            index order when ``return_all_scores`` is set. If exactly one text
            was classified, that single list is returned unwrapped (legacy
            behaviour); otherwise a list of such lists. An empty ``inputs``
            list yields ``[]``.

        Raises:
            ValueError: If the payload is malformed (missing or non-string
                inputs, non-integer or negative ``top_k``).
        """
        texts, top_k, return_all_scores = self._parse_request(data)
        if not texts:
            # Nothing to classify; the tokenizer/model cannot handle an empty batch.
            return []

        # Pad only to the longest text in this batch. Padding every request to
        # max_length (2048 tokens) multiplied the model's work for short inputs.
        encoded = self.tokenizer(
            texts,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Normalize in float32 (the logits may be half precision) and copy to the
        # CPU once, instead of syncing with the device for every score read.
        probs = torch.softmax(logits.float(), dim=-1).cpu()

        k = min(top_k, N_SUBFIELDS)
        results = [self._format_scores(row, k, return_all_scores) for row in probs]

        # Legacy response shape: a single result is returned without the outer list.
        if len(results) == 1:
            return results[0]
        return results

    @staticmethod
    def _parse_request(data: Any) -> Tuple[List[str], int, bool]:
        """Validate a request payload and return ``(texts, top_k, return_all_scores)``."""
        if isinstance(data, (str, list, tuple)):
            # Bare payload without an envelope: treat it as the inputs themselves.
            data = {"inputs": data}
        if not isinstance(data, dict):
            raise ValueError(
                "Request payload must be a dict, a string or a list of strings, "
                f"got {type(data).__name__}"
            )

        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            raise ValueError("'parameters' must be a dict when provided")

        def option(name: str, default: Any) -> Any:
            # Top-level keys win (original API); otherwise fall back to "parameters".
            value = data.get(name)
            if value is None:
                value = parameters.get(name)
            return default if value is None else value

        inputs = data.get("inputs")
        if isinstance(inputs, str):
            texts = [inputs]
        elif isinstance(inputs, (list, tuple)) and all(isinstance(t, str) for t in inputs):
            texts = list(inputs)
        else:
            raise ValueError("'inputs' must be a string or a list of strings")

        raw_top_k = option("top_k", 5)
        try:
            # JSON clients sometimes send numbers as strings; accept "5" as 5.
            top_k = int(raw_top_k)
        except (TypeError, ValueError) as err:
            raise ValueError(f"'top_k' must be an integer, got {raw_top_k!r}") from err
        if top_k < 0:
            raise ValueError(f"'top_k' must be non-negative, got {top_k}")

        return texts, top_k, bool(option("return_all_scores", False))

    @staticmethod
    def _format_scores(
        probs: torch.Tensor, top_k: int, return_all_scores: bool
    ) -> List[Dict[str, Any]]:
        """Convert one row of class probabilities into ``{"label", "score"}`` dicts."""
        if return_all_scores:
            return [
                {"label": INVERSE_SUBFIELD_MAP[index], "score": score}
                for index, score in enumerate(probs.tolist())
            ]
        top_probs, top_indices = torch.topk(probs, top_k, sorted=True)
        return [
            {"label": INVERSE_SUBFIELD_MAP[index], "score": score}
            for score, index in zip(top_probs.tolist(), top_indices.tolist())
        ]
| |
|