| | import torch |
| | import torch.nn as nn |
| | import re |
| | import unicodedata |
| | from typing import Any, List, Union |
| | from transformers import LongformerPreTrainedModel, LongformerModel, AutoTokenizer |
| | from .configuration_longformer import LongformerIntentConfig |
| |
|
def clean_text(s: Any, normalization: str = "NFKC", flatten_whitespace: bool = True) -> Any:
    """Normalize a text string before tokenization.

    Non-string inputs are returned unchanged so the function can be mapped
    over heterogeneous records safely.

    Steps (order matters — character cleanup runs before Unicode
    normalization, matching the original behavior):
      1. Unify newline conventions (CRLF, CR, U+2028/U+2029) to "\\n".
      2. Map non-breaking space (U+00A0) to a plain space; strip zero-width
         characters (U+200B, U+FEFF, U+180E).
      3. Apply Unicode normalization (default NFKC); pass "none" to skip.
      4. Collapse whitespace: all runs to a single space when
         ``flatten_whitespace`` is True, otherwise only spaces/tabs
         (newlines preserved). Leading/trailing whitespace is stripped.

    Args:
        s: Candidate text; returned as-is when not a ``str``.
        normalization: ``unicodedata.normalize`` form, or "none" to disable.
        flatten_whitespace: Collapse newlines into spaces as well.

    Returns:
        The cleaned string, or the original object when not a string.
    """
    if not isinstance(s, str):
        return s
    # Unify all newline conventions.
    s = s.replace("\r\n", "\n").replace("\r", "\n")
    s = s.replace("\u2028", "\n").replace("\u2029", "\n")
    # One C-level pass instead of chained .replace() calls: NBSP -> space,
    # zero-width characters removed.
    s = s.translate({0x00A0: " ", 0x200B: None, 0xFEFF: None, 0x180E: None})
    if normalization != "none":
        s = unicodedata.normalize(normalization, s)
    if flatten_whitespace:
        s = re.sub(r"\s+", " ", s)
    else:
        s = re.sub(r"[ \t]+", " ", s)
    return s.strip()
| |
|
class LongformerClassificationHead(nn.Module):
    """Sequence-classification head over Longformer hidden states.

    Pools the first token's hidden state (the <s>/CLS position), then applies
    dropout -> dense -> tanh -> dropout -> output projection to produce one
    logit per label.
    """

    def __init__(self, config):
        super().__init__()
        # Layers are created in the same order as the reference head so that
        # parameter initialization is reproducible.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states, **kwargs):
        """Map (batch, seq, hidden) hidden states to (batch, num_labels) logits."""
        # Only the first token's representation is used for classification.
        pooled = hidden_states[:, 0, :]
        pooled = torch.tanh(self.dense(self.dropout(pooled)))
        return self.out_proj(self.dropout(pooled))
| |
|
class LongformerIntentModel(LongformerPreTrainedModel):
    """Longformer encoder with a classification head for intent prediction.

    Bundles its own tokenizer so that :meth:`predict` can go straight from
    raw strings to per-intent percentage scores.
    """

    config_class = LongformerIntentConfig

    def __init__(self, config):
        super().__init__(config)
        self.longformer = LongformerModel(config)
        self.classifier = LongformerClassificationHead(config)
        # NOTE(review): config._name_or_path is a private transformers
        # attribute and is only meaningful after from_pretrained(); building
        # this model from a bare config may fail here — confirm intended usage.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, global_attention_mask=None, labels=None):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token ids, shape (batch, seq).
            attention_mask: Local attention mask (1 = attend).
            global_attention_mask: 1 where tokens get global attention.
            labels: Optional class indices, shape (batch,).

        Returns:
            Dict with "logits" of shape (batch, num_labels); when ``labels``
            is provided, also "loss" (cross-entropy).
        """
        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        # outputs[0] is the last hidden state; the head pools its first token.
        logits = self.classifier(outputs[0])

        result = {"logits": logits}
        if labels is not None:
            # Fix: `labels` was previously accepted but silently ignored, so
            # training code got no loss back.
            result["loss"] = nn.functional.cross_entropy(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        return result

    def predict(self, texts: Union[str, List[str]], batch_size: int = 8, device: str = None):
        """Predict intent percentages for one or more raw texts.

        Args:
            texts: A single string or a list of strings.
            batch_size: Number of texts tokenized and scored per forward pass.
            device: Torch device string; defaults to CUDA when available.

        Returns:
            A list (one entry per input text) of dicts mapping each name in
            ``config.intent_columns`` to an integer percentage; each dict's
            values sum to exactly 100.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)
        self.eval()

        if isinstance(texts, str):
            texts = [texts]

        # Fix: Longformer (RoBERTa-style) reserves 2 position-embedding slots
        # for the padding offset, so truncating to the full
        # max_position_embeddings could overflow the position-embedding table.
        max_length = min(
            self.tokenizer.model_max_length,
            self.config.max_position_embeddings - 2,
        )

        all_results = []
        for start in range(0, len(texts), batch_size):
            batch = [clean_text(t) for t in texts[start : start + batch_size]]
            enc = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)

            # Global attention on the first (<s>) token only, matching the
            # classifier head's first-token pooling.
            global_mask = torch.zeros_like(enc["input_ids"])
            global_mask[:, 0] = 1

            with torch.no_grad():
                outputs = self.forward(
                    input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    global_attention_mask=global_mask,
                )

            probs = torch.softmax(outputs["logits"], dim=-1).cpu().numpy()

            for row in probs:
                pct = (row * 100).round().astype(int)
                # Rounded percentages may not sum to 100; absorb the residue
                # into the largest entry so every result sums exactly to 100.
                residue = 100 - pct.sum()
                if residue != 0:
                    pct[pct.argmax()] += residue
                all_results.append(dict(zip(self.config.intent_columns, pct.tolist())))

        return all_results