Commit · c36a6f3
Parent(s): 209e610

new model

Browse files:
- .DS_Store +0 -0
- README.md +222 -3
- __init__.py +0 -0
- config.json +139 -0
- configuration_stacked.py +101 -0
- generic_ner.py +788 -0
- label_map.json +1 -0
- modeling_stacked.py +245 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +7 -0
- test.py +46 -0
- test_ner.py +106 -0
- tokenizer.json +0 -0
- tokenizer_config.json +59 -0
- training_args.bin +3 -0
- vocab.txt +0 -0
.DS_Store
ADDED
Binary file (6.15 kB)
README.md
CHANGED
@@ -1,3 +1,222 @@

The previous placeholder front matter (a single `---` followed by two blank lines) is removed and replaced with the model card below.
---
library_name: transformers
language:
- en
- fr
- de
tags:
- v1.0.0
---

# Model Card for `impresso-project/ner-stacked-bert-multilingual-light`

The **Impresso NER model** is a multilingual named entity recognition model trained for historical document processing. It is based on a stacked Transformer architecture and is designed to identify fine-grained and coarse-grained entity types in digitized historical texts, including names, titles, and locations.

## Model Details

### Model Description

- **Developed by:** EPFL, as part of the [Impresso team](https://impresso-project.ch), an interdisciplinary project focused on historical media analysis across languages, time, and modalities. Funded by the Swiss National Science Foundation ([CRSII5_173719](http://p3.snf.ch/project-173719), [CRSII5_213585](https://data.snf.ch/grants/grant/213585)) and the Luxembourg National Research Fund (grant No. 17498891).
- **Model type:** Stacked BERT-based token classification for named entity recognition
- **Languages:** French, German, English (with support for multilingual historical texts)
- **License:** [AGPL v3+](https://github.com/impresso/impresso-pyindexation/blob/master/LICENSE)
- **Finetuned from:** [`dbmdz/bert-medium-historic-multilingual-cased`](https://huggingface.co/dbmdz/bert-medium-historic-multilingual-cased)

### Model Architecture

The model architecture consists of the following components:
- A **pre-trained BERT encoder** (multilingual historic BERT) as the base.
- **One or two Transformer encoder layers** stacked on top of the BERT encoder.
- A **Conditional Random Field (CRF)** decoder layer to model label dependencies.
- **Learned absolute positional embeddings** for improved handling of noisy inputs.

These additional Transformer layers help mitigate the effects of OCR noise, spelling variation, and non-standard linguistic usage found in historical documents. The entire stack is fine-tuned end-to-end for token classification.
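To make the stacked design concrete, here is a minimal sketch in PyTorch. It is an illustration only, not the repository's `modeling_stacked.py`: the class name `StackedNERHead`, the use of `nn.TransformerEncoder` for the extra layers, and the plain linear emission layer standing in for the CRF decoder are assumptions made for the sketch.

```python
import torch.nn as nn
from transformers import AutoModel


class StackedNERHead(nn.Module):
    """Illustrative sketch: historic BERT encoder + extra Transformer layers + per-label emissions."""

    def __init__(self, base_name="dbmdz/bert-medium-historic-multilingual-cased",
                 num_labels=11, num_extra_layers=2):
        super().__init__()
        # Pre-trained multilingual historic BERT as the base encoder
        self.encoder = AutoModel.from_pretrained(base_name)
        hidden = self.encoder.config.hidden_size
        layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=8, batch_first=True)
        # One or two additional Transformer encoder layers stacked on top of BERT
        self.stack = nn.TransformerEncoder(layer, num_layers=num_extra_layers)
        # In the actual model a CRF decoder sits on top of these emissions;
        # a linear layer stands in for it here to keep the sketch self-contained.
        self.emissions = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask):
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        refined = self.stack(hidden_states, src_key_padding_mask=attention_mask == 0)
        return self.emissions(refined)  # (batch, seq_len, num_labels)
```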
### Entity Types Supported

The model supports both coarse-grained and fine-grained entity types defined in the HIPE-2020/2022 guidelines. The model outputs structured predictions with contextual and semantic details. Each prediction is a dictionary with the following fields:

```python
{
    'type': 'pers' | 'org' | 'loc' | 'time' | 'prod',
    'confidence_ner': float,  # Confidence score
    'surface': str,           # Surface form in text
    'lOffset': int,           # Start character offset
    'rOffset': int,           # End character offset
    'name': str,              # Optional: full name (for persons)
    'title': str,             # Optional: title (for persons)
    'function': str           # Optional: function (if detected)
}
```

#### Coarse-Grained Entity Types:
- **pers**: Person entities (individuals, collectives, authors)
- **org**: Organizations (administrative, enterprise, press agencies)
- **prod**: Products (media)
- **time**: Time expressions (absolute dates)
- **loc**: Locations (towns, regions, countries, physical, facilities)

If they appear in the text surrounding an entity, the model also returns **person-specific attributes** such as:
- `name`: canonical full name
- `title`: honorific or title (e.g., "king", "chancellor")
- `function`: role or function in context (if available)

### Model Sources

- **Repository:** https://huggingface.co/impresso-project/ner-stacked-bert-multilingual
- **Paper:** [CoNLL 2020](https://aclanthology.org/2020.conll-1.35/)
- **Demo:** [Impresso project](https://impresso-project.ch)

## Uses

### Direct Use

The model is intended to be used directly with the Hugging Face `pipeline` for `token-classification`, specifically through the custom `generic-ner` task on historical texts.

### Downstream Use

The model can be used for downstream tasks such as:
- Historical information extraction
- Biographical reconstruction
- Place and person mention detection across historical archives

### Out-of-Scope Use

- Not suitable for contemporary named entity recognition in domains such as social media or modern news.
- Not optimized for OCR-free modern corpora.

## Bias, Risks, and Limitations

Because it was trained on historical documents, the model may reflect historical biases and inaccuracies. It may underperform on contemporary texts or non-European languages.

### Recommendations

- Users should be cautious of historical and typographical biases.
- Consider post-processing to filter false positives from OCR noise.

## How to Get Started with the Model

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

MODEL_NAME = "impresso-project/ner-stacked-bert-multilingual-light"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ner_pipeline = pipeline("generic-ner", model=MODEL_NAME, tokenizer=tokenizer, trust_remote_code=True, device='cpu')

sentence = "En l'an 1348, au plus fort des ravages de la peste noire à travers l'Europe, le Royaume de France se trouvait à la fois au bord du désespoir et face à une opportunité. À la cour du roi Philippe VI, les murs du Louvre étaient animés par les rapports sombres venus de Paris et des villes environnantes. La peste ne montrait aucun signe de répit, et le chancelier Guillaume de Nogaret, le conseiller le plus fidèle du roi, portait le lourd fardeau de gérer la survie du royaume."

entities = ner_pipeline(sentence)
print(entities)
```

#### Example Output

```python
[
  {'type': 'time', 'confidence_ner': 85.0, 'surface': "an 1348", 'lOffset': 0, 'rOffset': 12},
  {'type': 'loc', 'confidence_ner': 90.75, 'surface': "Europe", 'lOffset': 69, 'rOffset': 75},
  {'type': 'loc', 'confidence_ner': 75.45, 'surface': "Royaume de France", 'lOffset': 80, 'rOffset': 97},
  {'type': 'pers', 'confidence_ner': 85.27, 'surface': "roi Philippe VI", 'lOffset': 181, 'rOffset': 196, 'title': "roi", 'name': "roi Philippe VI"},
  {'type': 'loc', 'confidence_ner': 30.59, 'surface': "Louvre", 'lOffset': 210, 'rOffset': 216},
  {'type': 'loc', 'confidence_ner': 94.46, 'surface': "Paris", 'lOffset': 266, 'rOffset': 271},
  {'type': 'pers', 'confidence_ner': 96.1, 'surface': "chancelier Guillaume de Nogaret", 'lOffset': 350, 'rOffset': 381, 'title': "chancelier", 'name': "Guillaume de Nogaret"},
  {'type': 'loc', 'confidence_ner': 49.35, 'surface': "Royaume", 'lOffset': 80, 'rOffset': 87},
  {'type': 'loc', 'confidence_ner': 24.18, 'surface': "France", 'lOffset': 91, 'rOffset': 97}
]
```
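As suggested under Recommendations above, OCR noise can produce spurious low-confidence mentions. A minimal post-processing sketch that drops them and groups the rest by type; the 70.0 threshold is an arbitrary illustration, not a recommended value:

```python
from collections import defaultdict


def filter_and_group(entities, min_confidence=70.0):
    """Drop low-confidence predictions and group the surviving surfaces by type."""
    grouped = defaultdict(list)
    for ent in entities:
        if ent.get("confidence_ner", 0.0) >= min_confidence:
            grouped[ent["type"]].append(ent["surface"])
    return dict(grouped)


# Applied to the pipeline output above:
print(filter_and_group(entities))
# returns something like {'time': [...], 'loc': [...], 'pers': [...]}
```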
## Training Details

### Training Data

The model was trained on the Impresso HIPE-2020 dataset, a subset of the [HIPE-2022 corpus](https://github.com/hipe-eval/HIPE-2022-data), which includes richly annotated OCR-transcribed historical newspaper content.

### Training Procedure

#### Preprocessing

OCR content was cleaned and segmented. Entity types follow the HIPE-2020 typology.

#### Training Hyperparameters

- **Training regime:** Mixed precision (fp16)
- **Epochs:** 5
- **Max sequence length:** 512
- **Base model:** `dbmdz/bert-medium-historic-multilingual-cased`
- **Stacked Transformer layers:** 2
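For orientation, these hyperparameters map onto a standard `transformers.TrainingArguments` setup roughly as sketched below; batch size, learning rate, and output directory are placeholders, not the values actually used for this model.

```python
from transformers import TrainingArguments

# Sketch only: epochs, fp16, and max length come from the card above;
# batch size and learning rate are illustrative placeholders.
training_args = TrainingArguments(
    output_dir="ner-stacked-bert-multilingual-light",
    num_train_epochs=5,              # Epochs: 5
    fp16=True,                       # Mixed precision (fp16)
    per_device_train_batch_size=16,  # placeholder
    learning_rate=5e-5,              # placeholder
)
MAX_SEQUENCE_LENGTH = 512  # enforced at tokenization time, not via TrainingArguments
```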
#### Speeds, Sizes, Times

- **Model size:** ~500MB
- **Training time:** ~1h on 1 GPU (NVIDIA TITAN X)

## Evaluation

#### Testing Data

Held-out portion of HIPE-2020 (French, German)

#### Metrics

- F1-score (micro, macro)
- Entity-level precision/recall

### Results

| Language | Precision | Recall | F1-score |
|----------|-----------|--------|----------|
| French   | 84.2      | 81.6   | 82.9     |
| German   | 82.0      | 78.7   | 80.3     |

#### Summary

The model performs robustly across noisy OCR historical content with support for fine-grained entity typologies.

## Environmental Impact

- **Hardware Type:** NVIDIA TITAN X (Pascal, 12GB)
- **Hours used:** ~1 hour
- **Cloud Provider:** EPFL, Switzerland
- **Carbon Emitted:** ~0.022 kg CO₂eq (estimated)

## Technical Specifications

### Model Architecture and Objective

Stacked BERT architecture with a multitask token classification head supporting HIPE-type entity labels.

### Compute Infrastructure

#### Hardware

1x NVIDIA TITAN X (Pascal, 12GB)

#### Software

- Python 3.11
- PyTorch 2.0
- Transformers 4.36

## Citation

**BibTeX:**

```bibtex
@inproceedings{boros2020alleviating,
  title={Alleviating digitization errors in named entity recognition for historical documents},
  author={Boros, Emanuela and Hamdi, Ahmed and Pontes, Elvys Linhares and Cabrera-Diego, Luis-Adri{\'a}n and Moreno, Jose G and Sidere, Nicolas and Doucet, Antoine},
  booktitle={Proceedings of the 24th conference on computational natural language learning},
  pages={431--441},
  year={2020}
}
```

## Contact

- Website: [https://impresso-project.ch](https://impresso-project.ch)

<p align="center">
  <img src="https://github.com/impresso/impresso.github.io/blob/master/assets/images/3x1--Yellow-Impresso-Black-on-White--transparent.png?raw=true" width="300" alt="Impresso Logo"/>
</p>
__init__.py
ADDED
File without changes
config.json
ADDED
@@ -0,0 +1,139 @@
```json
{
  "_name_or_path": "experiments/model_dbmdz_bert_medium_historic_multilingual_cased_max_sequence_length_512_epochs_5_run_multitask.baseline.False2025/",
  "architectures": [
    "ExtendedMultitaskTimeModelForTokenClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "auto_map": {
    "AutoConfig": "configuration_stacked.ImpressoConfig",
    "AutoModelForTokenClassification": "modeling_stacked.ExtendedMultitaskTimeModelForTokenClassification"
  },
  "classifier_dropout": null,
  "custom_pipelines": {
    "generic-ner": {
      "impl": "generic_ner.ExtendedMultitaskTimeModelForTokenClassificationPipeline",
      "pt": "AutoModelForTokenClassification"
    }
  },
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "label_map": {
    "NE-COARSE-LIT": {
      "I-pers": 0,
      "I-prod": 1,
      "B-prod": 2,
      "B-loc": 3,
      "I-time": 4,
      "B-pers": 5,
      "B-org": 6,
      "B-time": 7,
      "I-loc": 8,
      "O": 9,
      "I-org": 10
    },
    "NE-FINE-COMP": {
      "I-comp.title": 0,
      "B-comp.title": 1,
      "I-comp.function": 2,
      "I-comp.name": 3,
      "B-comp.function": 4,
      "O": 5,
      "B-comp.name": 6
    }
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "stacked_bert",
  "num_attention_heads": 8,
  "num_hidden_layers": 8,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "pretrained_config": {
    "_name_or_path": "dbmdz/bert-medium-historic-multilingual-cased",
    "add_cross_attention": false,
    "architectures": [
      "BertForMaskedLM"
    ],
    "attention_probs_dropout_prob": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": null,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 512,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 512,
    "min_length": 0,
    "model_type": "bert",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 8,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 8,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 0,
    "position_embedding_type": "absolute",
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "type_vocab_size": 2,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": true,
    "vocab_size": 32000
  },
  "torch_dtype": "float32",
  "transformers_version": "4.40.0.dev0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 32000
}
```
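The `auto_map` and `custom_pipelines` entries above are what allow `trust_remote_code=True` to resolve the custom configuration, model, and `generic-ner` pipeline classes. A small sketch for inspecting the shipped configuration (the printed values simply echo fields listed above):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "impresso-project/ner-stacked-bert-multilingual-light",
    trust_remote_code=True,  # needed so configuration_stacked.ImpressoConfig is used
)
print(config.model_type)        # "stacked_bert"
print(list(config.label_map))   # ["NE-COARSE-LIT", "NE-FINE-COMP"]
print(config.custom_pipelines["generic-ner"]["impl"])
```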
configuration_stacked.py
ADDED
@@ -0,0 +1,101 @@
```python
from transformers import PretrainedConfig
import torch


class ImpressoConfig(PretrainedConfig):
    model_type = "stacked_bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        pretrained_config=None,
        values_override=None,
        label_map=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        self.pretrained_config = pretrained_config
        self.label_map = label_map

        self.values_override = values_override or {}
        self.outputs = {
            "logits": {"shape": [None, None, self.hidden_size], "dtype": "float32"}
        }

    @classmethod
    def is_torch_support_available(cls):
        """
        Indicate whether Torch support is available for this configuration.
        Required for compatibility with certain parts of the Transformers library.
        """
        return True

    @classmethod
    def patch_ops(self):
        """
        A method required by some Hugging Face utilities to modify operator mappings.
        Currently, it performs no operation and is included for compatibility.
        Args:
            ops: A dictionary of operations to potentially patch.
        Returns:
            The (unmodified) ops dictionary.
        """
        return None

    def generate_dummy_inputs(self, tokenizer, batch_size=1, seq_length=8, framework="pt"):
        """
        Generate dummy inputs for testing or export.
        Args:
            tokenizer: The tokenizer used to tokenize inputs.
            batch_size: Number of input samples in the batch.
            seq_length: Length of each sequence.
            framework: Framework ("pt" for PyTorch, "tf" for TensorFlow).
        Returns:
            Dummy inputs as a dictionary.
        """
        if framework == "pt":
            input_ids = torch.randint(
                low=0,
                high=self.vocab_size,
                size=(batch_size, seq_length),
                dtype=torch.long
            )
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
            return {"input_ids": input_ids, "attention_mask": attention_mask}
        else:
            raise ValueError("Framework '{}' not supported.".format(framework))


# Register the configuration with the transformers library
ImpressoConfig.register_for_auto_class()
```
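A quick usage sketch of `ImpressoConfig`, assuming the file above is importable as `configuration_stacked`; the tokenizer is the base model named in the model card, and `generate_dummy_inputs` only receives it as an argument since the dummy ids are drawn at random:

```python
from transformers import AutoTokenizer
from configuration_stacked import ImpressoConfig

config = ImpressoConfig(vocab_size=32000, hidden_size=512, label_map={})
tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-medium-historic-multilingual-cased")

dummy = config.generate_dummy_inputs(tokenizer, batch_size=2, seq_length=8)
print(dummy["input_ids"].shape)       # torch.Size([2, 8])
print(dummy["attention_mask"].shape)  # torch.Size([2, 8])
```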
generic_ner.py
ADDED
@@ -0,0 +1,788 @@
| 1 |
+
import logging
|
| 2 |
+
from transformers import Pipeline
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import nltk
|
| 6 |
+
|
| 7 |
+
nltk.download("averaged_perceptron_tagger")
|
| 8 |
+
nltk.download("averaged_perceptron_tagger_eng")
|
| 9 |
+
nltk.download("stopwords")
|
| 10 |
+
from nltk.chunk import conlltags2tree
|
| 11 |
+
from nltk import pos_tag
|
| 12 |
+
from nltk.tree import Tree
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import re, string
|
| 15 |
+
|
| 16 |
+
stop_words = set(nltk.corpus.stopwords.words("english"))
|
| 17 |
+
DEBUG = False
|
| 18 |
+
punctuation = (
|
| 19 |
+
string.punctuation
|
| 20 |
+
+ "«»—…“”"
|
| 21 |
+
+ "—."
|
| 22 |
+
+ "–"
|
| 23 |
+
+ "’"
|
| 24 |
+
+ "‘"
|
| 25 |
+
+ "´"
|
| 26 |
+
+ "•"
|
| 27 |
+
+ "°"
|
| 28 |
+
+ "»"
|
| 29 |
+
+ "“"
|
| 30 |
+
+ "”"
|
| 31 |
+
+ "–"
|
| 32 |
+
+ "—"
|
| 33 |
+
+ "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# List of additional "strange" punctuation marks
|
| 37 |
+
# additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
WHITESPACE_RULES = {
|
| 41 |
+
"fr": {
|
| 42 |
+
"pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"],
|
| 43 |
+
"pct_no_ws_after": ["(", "[", "{"],
|
| 44 |
+
"pct_no_ws_before_after": ["'", "-"],
|
| 45 |
+
"pct_number": [".", ","],
|
| 46 |
+
},
|
| 47 |
+
"de": {
|
| 48 |
+
"pct_no_ws_before": [
|
| 49 |
+
".",
|
| 50 |
+
",",
|
| 51 |
+
")",
|
| 52 |
+
"]",
|
| 53 |
+
"}",
|
| 54 |
+
"°",
|
| 55 |
+
"...",
|
| 56 |
+
"?",
|
| 57 |
+
"!",
|
| 58 |
+
":",
|
| 59 |
+
";",
|
| 60 |
+
".-",
|
| 61 |
+
"%",
|
| 62 |
+
],
|
| 63 |
+
"pct_no_ws_after": ["(", "[", "{"],
|
| 64 |
+
"pct_no_ws_before_after": ["'", "-"],
|
| 65 |
+
"pct_number": [".", ","],
|
| 66 |
+
},
|
| 67 |
+
"other": {
|
| 68 |
+
"pct_no_ws_before": [
|
| 69 |
+
".",
|
| 70 |
+
",",
|
| 71 |
+
")",
|
| 72 |
+
"]",
|
| 73 |
+
"}",
|
| 74 |
+
"°",
|
| 75 |
+
"...",
|
| 76 |
+
"?",
|
| 77 |
+
"!",
|
| 78 |
+
":",
|
| 79 |
+
";",
|
| 80 |
+
".-",
|
| 81 |
+
"%",
|
| 82 |
+
],
|
| 83 |
+
"pct_no_ws_after": ["(", "[", "{"],
|
| 84 |
+
"pct_no_ws_before_after": ["'", "-"],
|
| 85 |
+
"pct_number": [".", ","],
|
| 86 |
+
},
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# def tokenize(text: str, language: str = "other") -> list[str]:
|
| 91 |
+
# """Apply whitespace rules to the given text and language, separating it into tokens.
|
| 92 |
+
#
|
| 93 |
+
# Args:
|
| 94 |
+
# text (str): The input text to separate into a list of tokens.
|
| 95 |
+
# language (str): Language of the text.
|
| 96 |
+
#
|
| 97 |
+
# Returns:
|
| 98 |
+
# list[str]: List of tokens with punctuation as separate tokens.
|
| 99 |
+
# """
|
| 100 |
+
# # text = add_spaces_around_punctuation(text)
|
| 101 |
+
# if not text:
|
| 102 |
+
# return []
|
| 103 |
+
#
|
| 104 |
+
# if language not in WHITESPACE_RULES:
|
| 105 |
+
# # Default behavior for languages without specific rules:
|
| 106 |
+
# # tokenize using standard whitespace splitting
|
| 107 |
+
# language = "other"
|
| 108 |
+
#
|
| 109 |
+
# wsrules = WHITESPACE_RULES[language]
|
| 110 |
+
# tokenized_text = []
|
| 111 |
+
# current_token = ""
|
| 112 |
+
#
|
| 113 |
+
# for char in text:
|
| 114 |
+
# if char in wsrules["pct_no_ws_before_after"]:
|
| 115 |
+
# if current_token:
|
| 116 |
+
# tokenized_text.append(current_token)
|
| 117 |
+
# tokenized_text.append(char)
|
| 118 |
+
# current_token = ""
|
| 119 |
+
# elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]:
|
| 120 |
+
# if current_token:
|
| 121 |
+
# tokenized_text.append(current_token)
|
| 122 |
+
# tokenized_text.append(char)
|
| 123 |
+
# current_token = ""
|
| 124 |
+
# elif char.isspace():
|
| 125 |
+
# if current_token:
|
| 126 |
+
# tokenized_text.append(current_token)
|
| 127 |
+
# current_token = ""
|
| 128 |
+
# else:
|
| 129 |
+
# current_token += char
|
| 130 |
+
#
|
| 131 |
+
# if current_token:
|
| 132 |
+
# tokenized_text.append(current_token)
|
| 133 |
+
#
|
| 134 |
+
# return tokenized_text
|
| 135 |
+
|
| 136 |
+
def normalize_text(text):
|
| 137 |
+
# Remove spaces and tabs for the search but keep newline characters
|
| 138 |
+
return re.sub(r"[ \t]+", "", text)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def find_entity_indices(article_text, search_text):
|
| 142 |
+
# Normalize texts by removing spaces and tabs
|
| 143 |
+
normalized_article = normalize_text(article_text)
|
| 144 |
+
normalized_search = normalize_text(search_text)
|
| 145 |
+
|
| 146 |
+
# Initialize a list to hold all start and end indices
|
| 147 |
+
indices = []
|
| 148 |
+
|
| 149 |
+
# Find all occurrences of the search text in the normalized article text
|
| 150 |
+
start_index = 0
|
| 151 |
+
while True:
|
| 152 |
+
start_index = normalized_article.find(normalized_search, start_index)
|
| 153 |
+
if start_index == -1:
|
| 154 |
+
break
|
| 155 |
+
|
| 156 |
+
# Calculate the actual start and end indices in the original article text
|
| 157 |
+
original_chars = 0
|
| 158 |
+
original_start_index = 0
|
| 159 |
+
for i in range(start_index):
|
| 160 |
+
while article_text[original_start_index] in (" ", "\t"):
|
| 161 |
+
original_start_index += 1
|
| 162 |
+
if article_text[original_start_index] not in (" ", "\t", "\n"):
|
| 163 |
+
original_chars += 1
|
| 164 |
+
original_start_index += 1
|
| 165 |
+
|
| 166 |
+
original_end_index = original_start_index
|
| 167 |
+
search_chars = 0
|
| 168 |
+
while search_chars < len(normalized_search):
|
| 169 |
+
if article_text[original_end_index] not in (" ", "\t", "\n"):
|
| 170 |
+
search_chars += 1
|
| 171 |
+
original_end_index += 1 # Increment to include the last character
|
| 172 |
+
|
| 173 |
+
# Append the found indices to the list
|
| 174 |
+
if article_text[original_start_index] == " ":
|
| 175 |
+
original_start_index += 1
|
| 176 |
+
indices.append((original_start_index, original_end_index))
|
| 177 |
+
|
| 178 |
+
# Move start_index to the next position to continue searching
|
| 179 |
+
start_index += 1
|
| 180 |
+
|
| 181 |
+
return indices
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_entities(tokens, tags, confidences, text):
|
| 185 |
+
tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
|
| 186 |
+
pos_tags = [pos for token, pos in pos_tag(tokens)]
|
| 187 |
+
|
| 188 |
+
for i in range(1, len(tags)):
|
| 189 |
+
# If a 'B-' tag is followed by another 'B-' without an 'O' in between, change the second to 'I-'
|
| 190 |
+
if tags[i].startswith("B-") and tags[i - 1].startswith("I-"):
|
| 191 |
+
tags[i] = "I-" + tags[i][2:] # Change 'B-' to 'I-' for the same entity type
|
| 192 |
+
|
| 193 |
+
conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
|
| 194 |
+
ne_tree = conlltags2tree(conlltags)
|
| 195 |
+
|
| 196 |
+
entities = []
|
| 197 |
+
idx: int = 0
|
| 198 |
+
already_done = []
|
| 199 |
+
for subtree in ne_tree:
|
| 200 |
+
# skipping 'O' tags
|
| 201 |
+
if isinstance(subtree, Tree):
|
| 202 |
+
original_label = subtree.label()
|
| 203 |
+
original_string = " ".join([token for token, pos in subtree.leaves()])
|
| 204 |
+
|
| 205 |
+
for indices in find_entity_indices(text, original_string):
|
| 206 |
+
entity_start_position = indices[0]
|
| 207 |
+
entity_end_position = indices[1]
|
| 208 |
+
if (
|
| 209 |
+
"_".join(
|
| 210 |
+
[original_label, original_string, str(entity_start_position)]
|
| 211 |
+
)
|
| 212 |
+
in already_done
|
| 213 |
+
):
|
| 214 |
+
continue
|
| 215 |
+
else:
|
| 216 |
+
already_done.append(
|
| 217 |
+
"_".join(
|
| 218 |
+
[
|
| 219 |
+
original_label,
|
| 220 |
+
original_string,
|
| 221 |
+
str(entity_start_position),
|
| 222 |
+
]
|
| 223 |
+
)
|
| 224 |
+
)
|
| 225 |
+
if len(text[entity_start_position:entity_end_position].strip()) < len(
|
| 226 |
+
text[entity_start_position:entity_end_position]
|
| 227 |
+
):
|
| 228 |
+
entity_start_position = (
|
| 229 |
+
entity_start_position
|
| 230 |
+
+ len(text[entity_start_position:entity_end_position])
|
| 231 |
+
- len(text[entity_start_position:entity_end_position].strip())
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
entities.append(
|
| 235 |
+
{
|
| 236 |
+
"type": original_label,
|
| 237 |
+
"confidence_ner": round(
|
| 238 |
+
np.average(confidences[idx: idx + len(subtree)]), 2
|
| 239 |
+
),
|
| 240 |
+
"index": (idx, idx + len(subtree)),
|
| 241 |
+
"surface": text[
|
| 242 |
+
entity_start_position:entity_end_position
|
| 243 |
+
], # original_string,
|
| 244 |
+
"lOffset": entity_start_position,
|
| 245 |
+
"rOffset": entity_end_position,
|
| 246 |
+
}
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
idx += len(subtree)
|
| 250 |
+
|
| 251 |
+
# Update the current character position
|
| 252 |
+
# We add the length of the original string + 1 (for the space)
|
| 253 |
+
else:
|
| 254 |
+
token, pos = subtree
|
| 255 |
+
# If it's not a named entity, we still need to update the character
|
| 256 |
+
# position
|
| 257 |
+
idx += 1
|
| 258 |
+
|
| 259 |
+
return entities
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def realign(word_ids, tokens, out_label_preds, softmax_scores, tokenizer, reverted_label_map):
|
| 263 |
+
preds_list, words_list, confidence_list = [], [], []
|
| 264 |
+
|
| 265 |
+
seen_word_ids = set()
|
| 266 |
+
for i, word_id in enumerate(word_ids):
|
| 267 |
+
if word_id is None or word_id in seen_word_ids:
|
| 268 |
+
continue # skip special tokens or repeated subwords
|
| 269 |
+
|
| 270 |
+
seen_word_ids.add(word_id)
|
| 271 |
+
|
| 272 |
+
try:
|
| 273 |
+
preds_list.append(reverted_label_map[out_label_preds[i]])
|
| 274 |
+
confidence_list.append(max(softmax_scores[i]))
|
| 275 |
+
except Exception:
|
| 276 |
+
preds_list.append("O")
|
| 277 |
+
confidence_list.append(0.0)
|
| 278 |
+
|
| 279 |
+
words_list.append(tokens[word_id]) # original word list index
|
| 280 |
+
|
| 281 |
+
return words_list, preds_list, confidence_list
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def add_spaces_around_punctuation(text):
|
| 285 |
+
# Add a space before and after all punctuation
|
| 286 |
+
all_punctuation = string.punctuation + punctuation
|
| 287 |
+
return re.sub(r"([{}])".format(re.escape(all_punctuation)), r" \1 ", text)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def attach_comp_to_closest(entities):
|
| 291 |
+
# Define valid entity types that can receive a "comp.function" or "comp.name" attachment
|
| 292 |
+
valid_entity_types = {"org", "pers", "org.ent", "pers.ind"}
|
| 293 |
+
|
| 294 |
+
# Separate "comp.function" and "comp.name" entities from other entities
|
| 295 |
+
comp_entities = [ent for ent in entities if ent["type"].startswith("comp")]
|
| 296 |
+
other_entities = [ent for ent in entities if not ent["type"].startswith("comp")]
|
| 297 |
+
|
| 298 |
+
for comp_entity in comp_entities:
|
| 299 |
+
closest_entity = None
|
| 300 |
+
min_distance = float("inf")
|
| 301 |
+
|
| 302 |
+
# Find the closest non-"comp" entity that is valid for attaching
|
| 303 |
+
for other_entity in other_entities:
|
| 304 |
+
# Calculate distance between the comp entity and the other entity
|
| 305 |
+
if comp_entity["lOffset"] > other_entity["rOffset"]:
|
| 306 |
+
distance = comp_entity["lOffset"] - other_entity["rOffset"]
|
| 307 |
+
elif comp_entity["rOffset"] < other_entity["lOffset"]:
|
| 308 |
+
distance = other_entity["lOffset"] - comp_entity["rOffset"]
|
| 309 |
+
else:
|
| 310 |
+
distance = 0 # They overlap or touch
|
| 311 |
+
|
| 312 |
+
# Ensure the entity type is valid and check for minimal distance
|
| 313 |
+
if (
|
| 314 |
+
distance < min_distance
|
| 315 |
+
and other_entity["type"].split(".")[0] in valid_entity_types
|
| 316 |
+
):
|
| 317 |
+
min_distance = distance
|
| 318 |
+
closest_entity = other_entity
|
| 319 |
+
|
| 320 |
+
# Attach the "comp.function" or "comp.name" if a valid entity is found
|
| 321 |
+
if closest_entity:
|
| 322 |
+
suffix = comp_entity["type"].split(".")[
|
| 323 |
+
-1
|
| 324 |
+
] # Extract the suffix (e.g., 'name', 'function')
|
| 325 |
+
closest_entity[suffix] = comp_entity["surface"] # Attach the text
|
| 326 |
+
|
| 327 |
+
return other_entities
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def conflicting_context(comp_entity, target_entity):
|
| 331 |
+
"""
|
| 332 |
+
Determines if there is a conflict between the comp_entity and the target entity.
|
| 333 |
+
Prevents incorrect name and function attachments by using a rule-based approach.
|
| 334 |
+
"""
|
| 335 |
+
# Case 1: Check for correct function attachment to person or organization entities
|
| 336 |
+
if comp_entity["type"].startswith("comp.function"):
|
| 337 |
+
if not ("pers" in target_entity["type"] or "org" in target_entity["type"]):
|
| 338 |
+
return True # Conflict: Function should only attach to persons or organizations
|
| 339 |
+
|
| 340 |
+
# Case 2: Avoid attaching comp.* entities to non-person, non-organization types (like locations)
|
| 341 |
+
if "loc" in target_entity["type"]:
|
| 342 |
+
return True # Conflict: comp.* entities should not attach to locations or similar types
|
| 343 |
+
|
| 344 |
+
return False # No conflict
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def extract_name_from_text(text, partial_name):
|
| 348 |
+
"""
|
| 349 |
+
Extracts the full name from the entity's text based on the partial name.
|
| 350 |
+
This function assumes that the full name starts with capitalized letters and does not
|
| 351 |
+
include any words that come after the partial name.
|
| 352 |
+
"""
|
| 353 |
+
# Split the text and partial name into words
|
| 354 |
+
words = text.split()
|
| 355 |
+
partial_words = partial_name.split()
|
| 356 |
+
|
| 357 |
+
if DEBUG:
|
| 358 |
+
print("text:", text)
|
| 359 |
+
if DEBUG:
|
| 360 |
+
print("partial_name:", partial_name)
|
| 361 |
+
|
| 362 |
+
# Find the position of the partial name in the word list
|
| 363 |
+
for i, word in enumerate(words):
|
| 364 |
+
if DEBUG:
|
| 365 |
+
print(words, "---", words[i: i + len(partial_words)])
|
| 366 |
+
if words[i: i + len(partial_words)] == partial_words:
|
| 367 |
+
# Initialize full name with the partial name
|
| 368 |
+
full_name = partial_words[:]
|
| 369 |
+
|
| 370 |
+
if DEBUG:
|
| 371 |
+
print("full_name:", full_name)
|
| 372 |
+
|
| 373 |
+
# Check previous words and only add capitalized words (skip lowercase words)
|
| 374 |
+
j = i - 1
|
| 375 |
+
while j >= 0 and words[j][0].isupper():
|
| 376 |
+
full_name.insert(0, words[j])
|
| 377 |
+
j -= 1
|
| 378 |
+
if DEBUG:
|
| 379 |
+
print("full_name:", full_name)
|
| 380 |
+
|
| 381 |
+
# Return only the full name up to the partial name (ignore words after the name)
|
| 382 |
+
return " ".join(full_name).strip() # Join the words to form the full name
|
| 383 |
+
|
| 384 |
+
# If not found, return the original text (as a fallback)
|
| 385 |
+
return text.strip()
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def repair_names_in_entities(entities):
|
| 389 |
+
"""
|
| 390 |
+
This function repairs the names in the entities by extracting the full name
|
| 391 |
+
from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached.
|
| 392 |
+
"""
|
| 393 |
+
for entity in entities:
|
| 394 |
+
if "name" in entity and "pers" in entity["type"]:
|
| 395 |
+
name = entity["name"]
|
| 396 |
+
text = entity["surface"]
|
| 397 |
+
|
| 398 |
+
# Check if the attached name is part of the entity's text
|
| 399 |
+
if name in text:
|
| 400 |
+
# Extract the full name from the text by splitting around the attached name
|
| 401 |
+
full_name = extract_name_from_text(entity["surface"], name)
|
| 402 |
+
entity["name"] = (
|
| 403 |
+
full_name # Replace the partial name with the full name
|
| 404 |
+
)
|
| 405 |
+
# if "name" not in entity:
|
| 406 |
+
# entity["name"] = entity["surface"]
|
| 407 |
+
|
| 408 |
+
return entities
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def clean_coarse_entities(entities):
|
| 412 |
+
"""
|
| 413 |
+
This function removes entities that are not useful for the NEL process.
|
| 414 |
+
"""
|
| 415 |
+
# Define a set of entity types that are considered useful for NEL
|
| 416 |
+
useful_types = {
|
| 417 |
+
"pers", # Person
|
| 418 |
+
"loc", # Location
|
| 419 |
+
"org", # Organization
|
| 420 |
+
"date", # Product
|
| 421 |
+
"time", # Time
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
# Filter out entities that are not in the useful_types set unless they are comp.* entities
|
| 425 |
+
cleaned_entities = [
|
| 426 |
+
entity
|
| 427 |
+
for entity in entities
|
| 428 |
+
if entity["type"] in useful_types or "comp" in entity["type"]
|
| 429 |
+
]
|
| 430 |
+
|
| 431 |
+
return cleaned_entities
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def postprocess_entities(entities):
|
| 435 |
+
# Step 1: Filter entities with the same text, keeping the one with the most dots in the 'entity' field
|
| 436 |
+
entity_map = {}
|
| 437 |
+
|
| 438 |
+
# Loop over the entities and prioritize the one with the most dots
|
| 439 |
+
for entity in entities:
|
| 440 |
+
entity_text = entity["surface"]
|
| 441 |
+
num_dots = entity["type"].count(".")
|
| 442 |
+
|
| 443 |
+
# If the entity text is new, or this entity has more dots, update the map
|
| 444 |
+
if (
|
| 445 |
+
entity_text not in entity_map
|
| 446 |
+
or entity_map[entity_text]["type"].count(".") < num_dots
|
| 447 |
+
):
|
| 448 |
+
entity_map[entity_text] = entity
|
| 449 |
+
|
| 450 |
+
# Collect the filtered entities from the map
|
| 451 |
+
filtered_entities = list(entity_map.values())
|
| 452 |
+
|
| 453 |
+
# Step 2: Attach "comp.function" entities to the closest other entities
|
| 454 |
+
filtered_entities = attach_comp_to_closest(filtered_entities)
|
| 455 |
+
if DEBUG:
|
| 456 |
+
print("After attach_comp_to_closest:", filtered_entities, "\n")
|
| 457 |
+
filtered_entities = repair_names_in_entities(filtered_entities)
|
| 458 |
+
if DEBUG:
|
| 459 |
+
print("After repair_names_in_entities:", filtered_entities, "\n")
|
| 460 |
+
|
| 461 |
+
# Step 3: Remove entities that are not useful for NEL
|
| 462 |
+
# filtered_entities = clean_coarse_entities(filtered_entities)
|
| 463 |
+
|
| 464 |
+
# filtered_entities = remove_blacklisted_entities(filtered_entities)
|
| 465 |
+
|
| 466 |
+
return filtered_entities
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def remove_included_entities(entities):
|
| 470 |
+
# Loop through entities and remove those whose text is included in another with the same label
|
| 471 |
+
final_entities = []
|
| 472 |
+
for i, entity in enumerate(entities):
|
| 473 |
+
is_included = False
|
| 474 |
+
for other_entity in entities:
|
| 475 |
+
if entity["surface"] != other_entity["surface"]:
|
| 476 |
+
if "comp" in other_entity["type"]:
|
| 477 |
+
# Check if entity's text is a substring of another entity's text
|
| 478 |
+
if entity["surface"] in other_entity["surface"]:
|
| 479 |
+
is_included = True
|
| 480 |
+
break
|
| 481 |
+
elif (
|
| 482 |
+
entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
|
| 483 |
+
or other_entity["type"].split(".")[0]
|
| 484 |
+
in entity["type"].split(".")[0]
|
| 485 |
+
):
|
| 486 |
+
if entity["surface"] in other_entity["surface"]:
|
| 487 |
+
is_included = True
|
| 488 |
+
if not is_included:
|
| 489 |
+
final_entities.append(entity)
|
| 490 |
+
return final_entities
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def refine_entities_with_coarse(all_entities, coarse_entities):
|
| 494 |
+
"""
|
| 495 |
+
Looks through all entities and refines them based on the coarse entities.
|
| 496 |
+
If a surface match is found in the coarse entities and the types match,
|
| 497 |
+
the entity's confidence_ner and type are updated based on the coarse entity.
|
| 498 |
+
"""
|
| 499 |
+
# Create a dictionary for coarse entities based on surface and type for quick lookup
|
| 500 |
+
coarse_lookup = {}
|
| 501 |
+
for coarse_entity in coarse_entities:
|
| 502 |
+
key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0])
|
| 503 |
+
coarse_lookup[key] = coarse_entity
|
| 504 |
+
|
| 505 |
+
# Iterate through all entities and compare with the coarse entities
|
| 506 |
+
for entity in all_entities:
|
| 507 |
+
key = (
|
| 508 |
+
entity["surface"],
|
| 509 |
+
entity["type"].split(".")[0],
|
| 510 |
+
) # Use the coarse type for comparison
|
| 511 |
+
|
| 512 |
+
if key in coarse_lookup:
|
| 513 |
+
coarse_entity = coarse_lookup[key]
|
| 514 |
+
# If a match is found, update the confidence_ner and type in the entity
|
| 515 |
+
if entity["confidence_ner"] < coarse_entity["confidence_ner"]:
|
| 516 |
+
entity["confidence_ner"] = coarse_entity["confidence_ner"]
|
| 517 |
+
entity["type"] = coarse_entity[
|
| 518 |
+
"type"
|
| 519 |
+
] # Update the type if the confidence is higher
|
| 520 |
+
|
| 521 |
+
# No need to append to refined_entities, we're modifying in place
|
| 522 |
+
for entity in all_entities:
|
| 523 |
+
entity["type"] = entity["type"].split(".")[0]
|
| 524 |
+
return all_entities
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def remove_trailing_stopwords(entities):
|
| 528 |
+
"""
|
| 529 |
+
This function removes stopwords and punctuation from both the beginning and end of each entity's text
|
| 530 |
+
and repairs the lOffset and rOffset accordingly.
|
| 531 |
+
"""
|
| 532 |
+
if DEBUG:
|
| 533 |
+
print(f"Initial entities in remove_trailing_stopwords: {len(entities)}")
|
| 534 |
+
new_entities = []
|
| 535 |
+
for entity in entities:
|
| 536 |
+
if "comp" not in entity["type"]:
|
| 537 |
+
entity_text = entity["surface"]
|
| 538 |
+
original_len = len(entity_text)
|
| 539 |
+
|
| 540 |
+
# Initial offsets
|
| 541 |
+
lOffset = entity.get("lOffset", 0)
|
| 542 |
+
rOffset = entity.get("rOffset", original_len)
|
| 543 |
+
|
| 544 |
+
# Remove stopwords and punctuation from the beginning
|
| 545 |
+
# print('----', entity_text)
|
| 546 |
+
if len(entity_text.split()) < 1:
|
| 547 |
+
continue
|
| 548 |
+
while entity_text and (
|
| 549 |
+
entity_text.split()[0].lower() in stop_words
|
| 550 |
+
or entity_text[0] in punctuation
|
| 551 |
+
):
|
| 552 |
+
if entity_text.split()[0].lower() in stop_words:
|
| 553 |
+
stopword_len = (
|
| 554 |
+
len(entity_text.split()[0]) + 1
|
| 555 |
+
) # Adjust length for stopword and following space
|
| 556 |
+
entity_text = entity_text[stopword_len:] # Remove leading stopword
|
| 557 |
+
lOffset += stopword_len # Adjust the left offset
|
| 558 |
+
if DEBUG:
|
| 559 |
+
print(
|
| 560 |
+
f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
|
| 561 |
+
)
|
| 562 |
+
elif entity_text[0] in punctuation:
|
| 563 |
+
entity_text = entity_text[1:] # Remove leading punctuation
|
| 564 |
+
lOffset += 1 # Adjust the left offset
|
| 565 |
+
if DEBUG:
|
| 566 |
+
print(
|
| 567 |
+
f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# Remove stopwords and punctuation from the end
|
| 571 |
+
if len(entity_text.strip()) > 1:
|
| 572 |
+
while (
|
| 573 |
+
entity_text.strip().split()
|
| 574 |
+
and (
|
| 575 |
+
entity_text.strip().split()[-1].lower() in stop_words
|
| 576 |
+
or entity_text[-1] in punctuation
|
| 577 |
+
)
|
| 578 |
+
):
|
| 579 |
+
if entity_text.strip().split() and entity_text.strip().split()[-1].lower() in stop_words:
|
| 580 |
+
stopword_len = len(entity_text.strip().split()[-1]) + 1 # account for space
|
| 581 |
+
entity_text = entity_text[:-stopword_len]
|
| 582 |
+
rOffset -= stopword_len
|
| 583 |
+
if DEBUG:
|
| 584 |
+
print(
|
| 585 |
+
f"Removed trailing stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']})"
|
| 586 |
+
)
|
| 587 |
+
if entity_text and entity_text[-1] in punctuation:
|
| 588 |
+
entity_text = entity_text[:-1]
|
| 589 |
+
rOffset -= 1
|
| 590 |
+
if DEBUG:
|
| 591 |
+
print(
|
| 592 |
+
f"Removed trailing punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']})"
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# Skip certain entities based on rules
|
| 596 |
+
if entity_text in string.punctuation:
|
| 597 |
+
if DEBUG:
|
| 598 |
+
print(f"Skipping entity: {entity_text}")
|
| 599 |
+
# entities.remove(entity)
|
| 600 |
+
continue
|
| 601 |
+
# check now if its in stopwords
|
| 602 |
+
if entity_text.lower() in stop_words:
|
| 603 |
+
if DEBUG:
|
| 604 |
+
print(f"Skipping entity: {entity_text}")
|
| 605 |
+
# entities.remove(entity)
|
| 606 |
+
continue
|
| 607 |
+
# check now if the entire entity is a list of stopwords:
|
| 608 |
+
if all([word.lower() in stop_words for word in entity_text.split()]):
|
| 609 |
+
if DEBUG:
|
| 610 |
+
print(f"Skipping entity: {entity_text}")
|
| 611 |
+
# entities.remove(entity)
|
| 612 |
+
continue
|
| 613 |
+
# Check if the entire entity is made up of stopwords characters
|
| 614 |
+
if all(
|
| 615 |
+
[char.lower() in stop_words for char in entity_text if char.isalpha()]
|
| 616 |
+
):
|
| 617 |
+
if DEBUG:
|
| 618 |
+
print(
|
| 619 |
+
f"Skipping entity: {entity_text} (all characters are stopwords)"
|
| 620 |
+
)
|
| 621 |
+
# entities.remove(entity)
|
| 622 |
+
continue
|
| 623 |
+
# check now if all entity is in a list of punctuation
|
| 624 |
+
if all([word in string.punctuation for word in entity_text.split()]):
|
| 625 |
+
if DEBUG:
|
| 626 |
+
print(
|
| 627 |
+
f"Skipping entity: {entity_text} (all characters are punctuation)"
|
| 628 |
+
)
|
| 629 |
+
# entities.remove(entity)
|
| 630 |
+
continue
|
| 631 |
+
if all(
|
| 632 |
+
[
|
| 633 |
+
char.lower() in string.punctuation
|
| 634 |
+
for char in entity_text
|
| 635 |
+
if char.isalpha()
|
| 636 |
+
]
|
| 637 |
+
):
|
| 638 |
+
if DEBUG:
|
| 639 |
+
print(
|
| 640 |
+
f"Skipping entity: {entity_text} (all characters are punctuation)"
|
| 641 |
+
)
|
| 642 |
+
# entities.remove(entity)
|
| 643 |
+
continue
|
| 644 |
+
|
| 645 |
+
# if it's a number and "time" no in it, then continue
|
| 646 |
+
if entity_text.isdigit() and "time" not in entity["type"]:
|
| 647 |
+
if DEBUG:
|
| 648 |
+
print(f"Skipping entity: {entity_text}")
|
| 649 |
+
# entities.remove(entity)
|
| 650 |
+
continue
|
| 651 |
+
|
| 652 |
+
if entity_text.startswith(" "):
|
| 653 |
+
entity_text = entity_text[1:]
|
| 654 |
+
# update lOffset, rOffset
|
| 655 |
+
lOffset += 1
|
| 656 |
+
if entity_text.endswith(" "):
|
| 657 |
+
entity_text = entity_text[:-1]
|
| 658 |
+
# update lOffset, rOffset
|
| 659 |
+
rOffset -= 1
|
| 660 |
+
|
| 661 |
+
# Update the entity surface and offsets
|
| 662 |
+
entity["surface"] = entity_text
|
| 663 |
+
entity["lOffset"] = lOffset
|
| 664 |
+
entity["rOffset"] = rOffset
|
| 665 |
+
|
| 666 |
+
# Remove the entity if the surface is empty after cleaning
|
| 667 |
+
if len(entity["surface"].strip()) == 0:
|
| 668 |
+
if DEBUG:
|
| 669 |
+
print(f"Deleted entity: {entity['surface']}")
|
| 670 |
+
# entities.remove(entity)
|
| 671 |
+
else:
|
| 672 |
+
new_entities.append(entity)
|
| 673 |
+
else:
|
| 674 |
+
new_entities.append(entity)
|
| 675 |
+
if DEBUG:
|
| 676 |
+
print(f"Remained entities in remove_trailing_stopwords: {len(new_entities)}")
|
| 677 |
+
return new_entities
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
class ExtendedMultitaskTimeModelForTokenClassificationPipeline(Pipeline):
|
| 681 |
+
|
| 682 |
+
def _sanitize_parameters(self, **kwargs):
|
| 683 |
+
preprocess_kwargs = {}
|
| 684 |
+
if "text" in kwargs:
|
| 685 |
+
preprocess_kwargs["text"] = kwargs["text"]
|
| 686 |
+
if "tokens" in kwargs:
|
| 687 |
+
preprocess_kwargs["tokens"] = kwargs["tokens"]
|
| 688 |
+
self.label_map = self.model.config.label_map
|
| 689 |
+
self.id2label = {
|
| 690 |
+
task: {id_: label for label, id_ in labels.items()}
|
| 691 |
+
for task, labels in self.label_map.items()
|
| 692 |
+
}
|
| 693 |
+
return preprocess_kwargs, {}, {}
|
| 694 |
+
|
| 695 |
+
def preprocess(self, text, **kwargs):
|
| 696 |
+
|
| 697 |
+
tokens = kwargs["tokens"]
|
| 698 |
+
tokenized_inputs = self.tokenizer(
|
| 699 |
+
tokens, # a list of strings
|
| 700 |
+
is_split_into_words=True,
|
| 701 |
+
padding="max_length",
|
| 702 |
+
truncation=True,
|
| 703 |
+
max_length=512,
|
| 704 |
+
)
|
| 705 |
+
word_ids = tokenized_inputs.word_ids()
|
| 706 |
+
|
| 707 |
+
return tokenized_inputs, word_ids, text, tokens
|
| 708 |
+
|
| 709 |
+
def _forward(self, inputs):
|
| 710 |
+
inputs, word_ids, text, tokens = inputs
|
| 711 |
+
|
| 712 |
+
input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
|
| 713 |
+
self.model.device
|
| 714 |
+
)
|
| 715 |
+
attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
|
| 716 |
+
self.model.device
|
| 717 |
+
)
|
| 718 |
+
with torch.no_grad():
|
| 719 |
+
outputs = self.model(input_ids, attention_mask)
|
| 720 |
+
return outputs, word_ids, text, tokens
|
| 721 |
+
|
| 722 |
+
def is_within(self, entity1, entity2):
|
| 723 |
+
"""Check if entity1 is fully within the bounds of entity2."""
|
| 724 |
+
return (
|
| 725 |
+
entity1["lOffset"] >= entity2["lOffset"]
|
| 726 |
+
and entity1["rOffset"] <= entity2["rOffset"]
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
def postprocess(self, outputs, **kwargs):
|
| 730 |
+
"""
|
| 731 |
+
Postprocess the outputs of the model
|
| 732 |
+
:param outputs:
|
| 733 |
+
:param kwargs:
|
| 734 |
+
:return:
|
| 735 |
+
"""
|
| 736 |
+
tokens_result, word_ids, text, tokens = outputs
|
| 737 |
+
|
| 738 |
+
predictions = {}
|
| 739 |
+
confidence_scores = {}
|
| 740 |
+
for task, logits in tokens_result.logits.items():
|
| 741 |
+
predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
|
| 742 |
+
confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
|
| 743 |
+
|
| 744 |
+
entities = {}
|
| 745 |
+
for task in predictions.keys():
|
| 746 |
+
words_list, preds_list, confidence_list = realign(
|
| 747 |
+
word_ids,
|
| 748 |
+
tokens,
|
| 749 |
+
predictions[task],
|
| 750 |
+
confidence_scores[task],
|
| 751 |
+
self.tokenizer,
|
| 752 |
+
self.id2label[task],
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
entities[task] = get_entities(words_list, preds_list, confidence_list, text)
|
| 756 |
+
|
| 757 |
+
# add titles to comp entities
|
| 758 |
+
# from pprint import pprint
|
| 759 |
+
|
| 760 |
+
# print("Before:")
|
| 761 |
+
# pprint(entities)
|
| 762 |
+
|
| 763 |
+
all_entities = []
|
| 764 |
+
coarse_entities = []
|
| 765 |
+
for key in entities:
|
| 766 |
+
if key in ["NE-COARSE-LIT"]:
|
| 767 |
+
coarse_entities = entities[key]
|
| 768 |
+
all_entities.extend(entities[key])
|
| 769 |
+
|
| 770 |
+
if DEBUG:
|
| 771 |
+
print(all_entities)
|
| 772 |
+
# print("After remove_included_entities:")
|
| 773 |
+
all_entities = remove_included_entities(all_entities)
|
| 774 |
+
if DEBUG:
|
| 775 |
+
print("After remove_included_entities:", all_entities)
|
| 776 |
+
all_entities = remove_trailing_stopwords(all_entities)
|
| 777 |
+
if DEBUG:
|
| 778 |
+
print("After remove_trailing_stopwords:", all_entities)
|
| 779 |
+
all_entities = postprocess_entities(all_entities)
|
| 780 |
+
if DEBUG:
|
| 781 |
+
print("After postprocess_entities:", all_entities)
|
| 782 |
+
all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
|
| 783 |
+
if DEBUG:
|
| 784 |
+
print("After refine_entities_with_coarse:", all_entities)
|
| 785 |
+
# print("After attach_comp_to_closest:")
|
| 786 |
+
# pprint(all_entities)
|
| 787 |
+
# print("\n")
|
| 788 |
+
return all_entities
|
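
For orientation, a minimal usage sketch of the custom pipeline defined above (it mirrors the test scripts further down; the example sentence and the whitespace tokenization are illustrative only). `preprocess` expects the pre-split words via the `tokens` keyword, and `postprocess` returns a flat list of entity dictionaries whose `surface`, `type`, `lOffset` and `rOffset` keys are the ones used in the post-processing helpers:

# Hedged usage sketch of the pipeline above (see also test.py / test_ner.py below).
from transformers import pipeline, AutoTokenizer

MODEL_NAME = "impresso-project/ner-stacked-bert-multilingual-light"
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
ner = pipeline("generic-ner", model=MODEL_NAME, tokenizer=tok,
               trust_remote_code=True, device="cpu")

text = "Henri Dunant founded the Red Cross in Geneva in 1863."
entities = ner(text, tokens=text.split())  # tokens: pre-split words, as preprocess() expects
for entity in entities:
    print(entity["type"], entity["surface"], entity["lOffset"], entity["rOffset"])
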
label_map.json
ADDED
@@ -0,0 +1 @@
{"NE-COARSE-LIT": {"I-pers": 0, "I-prod": 1, "B-prod": 2, "B-loc": 3, "I-time": 4, "B-pers": 5, "B-org": 6, "B-time": 7, "I-loc": 8, "O": 9, "I-org": 10}, "NE-FINE-COMP": {"I-comp.title": 0, "B-comp.title": 1, "I-comp.function": 2, "I-comp.name": 3, "B-comp.function": 4, "O": 5, "B-comp.name": 6}}
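
The single JSON line above defines one tag inventory per task head (coarse literal entities and fine-grained name components). A small sketch, assuming the file is read as-is, of how it is inverted into the per-task id2label lookups that the pipeline's `_sanitize_parameters` builds:

# Sketch: flip {label -> id} into {id -> label} for each task head.
import json

with open("label_map.json") as f:
    label_map = json.load(f)

id2label = {
    task: {id_: label for label, id_ in labels.items()}
    for task, labels in label_map.items()
}
print(id2label["NE-COARSE-LIT"][3])  # B-loc
print(id2label["NE-FINE-COMP"][1])   # B-comp.title
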
modeling_stacked.py
ADDED
@@ -0,0 +1,245 @@
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os

from .configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)


def get_info(label_map):
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict


class ExtendedMultitaskTimeModelForTokenClassification(PreTrainedModel):
    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, temporal_fusion_strategy="baseline", num_years=327):
        super().__init__(config)
        self.num_token_labels_dict = get_info(config.label_map)
        self.config = config
        self.temporal_fusion_strategy = temporal_fusion_strategy
        self.model = AutoModel.from_pretrained(
            config.pretrained_config["_name_or_path"], config=config.pretrained_config
        )
        self.model.config.use_cache = False
        self.model.config.pretraining_tp = 1
        self.num_years = num_years

        classifier_dropout = getattr(config, "classifier_dropout", 0.1) or config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)

        self.temporal_fusion = TemporalFusion(config.hidden_size, strategy=self.temporal_fusion_strategy,
                                              num_years=num_years)

        # Additional transformer layers
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )
        self.token_classifiers = nn.ModuleDict({
            task: nn.Linear(config.hidden_size, num_labels)
            for task, num_labels in self.num_token_labels_dict.items()
        })

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        date_indices: Optional[torch.Tensor] = None,
        year_index: Optional[torch.Tensor] = None,
        decade_index: Optional[torch.Tensor] = None,
        century_index: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embeddings(input_ids)

        # Early cross-attention fusion
        if self.temporal_fusion_strategy == "early-cross-attention":
            year_emb = self.temporal_fusion.compute_time_embedding(year_index)  # (B, H)
            inputs_embeds = self.temporal_fusion.cross_attn(inputs_embeds, year_emb)

        bert_kwargs = {
            "inputs_embeds": inputs_embeds if self.temporal_fusion_strategy == "early-cross-attention" else None,
            "input_ids": input_ids if self.temporal_fusion_strategy != "early-cross-attention" else None,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

        if any(keyword in self.config.name_or_path.lower() for keyword in ["llama", "deberta"]):
            bert_kwargs.pop("token_type_ids", None)
            bert_kwargs.pop("head_mask", None)

        outputs = self.model(**bert_kwargs)
        token_output = self.dropout(outputs[0])  # (B, T, H)
        hidden_states = list(outputs.hidden_states) if output_hidden_states else None

        # Pass through additional transformer layers
        token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
            0, 1
        )
        # Apply fusion after transformer if needed
        if self.temporal_fusion_strategy not in ["baseline", "early-cross-attention"]:
            token_output = self.temporal_fusion(token_output, year_index)
            if output_hidden_states:
                hidden_states.append(token_output)  # add the final fused state

        task_logits = {}
        total_loss = 0
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                total_loss += loss

        if not return_dict:
            output = (task_logits,) + outputs[2:]
            return ((total_loss,) + output) if total_loss != 0 else output

        return TokenClassifierOutput(
            loss=total_loss,
            logits=task_logits,
            hidden_states=tuple(hidden_states) if hidden_states is not None else None,
            attentions=outputs.attentions if output_attentions else None,
        )


class TemporalFusion(nn.Module):
    def __init__(self, hidden_size, strategy="add", num_years=327, min_year=1700):
        super().__init__()
        self.strategy = strategy
        self.hidden_size = hidden_size
        self.min_year = min_year
        self.max_year = min_year + num_years - 1

        self.year_emb = nn.Embedding(num_years, hidden_size)

        if strategy == "concat":
            self.concat_proj = nn.Linear(hidden_size * 2, hidden_size)
        elif strategy == "film":
            self.film_gamma = nn.Linear(hidden_size, hidden_size)
            self.film_beta = nn.Linear(hidden_size, hidden_size)
        elif strategy == "adapter":
            self.adapter = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )
        elif strategy == "relative":
            self.relative_encoder = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU(),
                nn.LayerNorm(hidden_size),
            )
            self.film_gamma = nn.Linear(hidden_size, hidden_size)
            self.film_beta = nn.Linear(hidden_size, hidden_size)
        elif strategy == "multiscale":
            self.decade_emb = nn.Embedding(1000, hidden_size)
            self.century_emb = nn.Embedding(100, hidden_size)
        elif strategy in ["early-cross-attention", "late-cross-attention"]:
            self.year_encoder = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.SiLU()
            )
            self.cross_attn = TemporalCrossAttention(hidden_size)

    def compute_time_embedding(self, year_index):
        if self.strategy in ["early-cross-attention", "late-cross-attention"]:
            return self.year_encoder(self.year_emb(year_index))
        elif self.strategy == "multiscale":
            year_index = year_index.long()
            year = year_index + self.min_year
            decade = (year // 10).long()
            century = (year // 100).long()
            return (
                self.year_emb(year_index) +
                self.decade_emb(decade) +
                self.century_emb(century)
            )
        else:
            return self.year_emb(year_index)

    def forward(self, token_output, year_index):
        B, T, H = token_output.size()

        if self.strategy == "baseline":
            return token_output

        year_emb = self.compute_time_embedding(year_index)

        if self.strategy == "concat":
            expanded_year = year_emb.unsqueeze(1).repeat(1, T, 1)
            fused = torch.cat([token_output, expanded_year], dim=-1)
            return self.concat_proj(fused)

        elif self.strategy == "film":
            gamma = self.film_gamma(year_emb).unsqueeze(1)
            beta = self.film_beta(year_emb).unsqueeze(1)
            return gamma * token_output + beta

        elif self.strategy == "adapter":
            return token_output + self.adapter(year_emb).unsqueeze(1)

        elif self.strategy == "add":
            expanded_year = year_emb.unsqueeze(1).repeat(1, T, 1)
            return token_output + expanded_year

        elif self.strategy == "relative":
            encoded = self.relative_encoder(year_emb)
            gamma = self.film_gamma(encoded).unsqueeze(1)
            beta = self.film_beta(encoded).unsqueeze(1)
            return gamma * token_output + beta

        elif self.strategy == "multiscale":
            expanded_year = year_emb.unsqueeze(1).expand(-1, T, -1)
            return token_output + expanded_year

        elif self.strategy == "late-cross-attention":
            return self.cross_attn(token_output, year_emb)

        else:
            raise ValueError(f"Unknown fusion strategy: {self.strategy}")


class TemporalCrossAttention(nn.Module):
    def __init__(self, hidden_size, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True)

    def forward(self, token_output, time_embedding):
        # token_output: (B, T, H), time_embedding: (B, H)
        time_as_seq = time_embedding.unsqueeze(1)  # (B, 1, H)
        attn_output, _ = self.attn(token_output, time_as_seq, time_as_seq)
        return token_output + attn_output
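
As a quick shape check on the fusion module above, a hedged sketch with dummy tensors: batch size, sequence length and hidden size are arbitrary, and `year_index` indexes the 327-entry year embedding (i.e. year minus `min_year=1700`). The import path is illustrative only; in practice the class is loaded for you through `trust_remote_code`:

# Sketch: apply the FiLM temporal-fusion strategy to random token representations.
import torch
from modeling_stacked import TemporalFusion  # illustrative import; normally loaded via trust_remote_code

B, T, H = 2, 16, 512                               # arbitrary batch, sequence length, hidden size
fusion = TemporalFusion(hidden_size=H, strategy="film", num_years=327, min_year=1700)

token_output = torch.randn(B, T, H)                # stand-in for encoder output (B, T, H)
year_index = torch.tensor([185, 189])              # 1885 and 1889, encoded as year - 1700
fused = fusion(token_output, year_index)
print(fused.shape)                                 # torch.Size([2, 16, 512])
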
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4693904ab70f0ef7c0249db6c23b0b1e3b2760629d8dafa9010f5ec9feb7de39
size 168604214
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
test.py
ADDED
@@ -0,0 +1,46 @@
# Import necessary modules from the transformers library
from transformers import pipeline
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Define the model name to be used for token classification; we use the Impresso NER model,
# which can be found at "https://huggingface.co/impresso-project/ner-stacked-bert-multilingual"
MODEL_NAME = "impresso-project/ner-stacked-bert-multilingual"

# Load the tokenizer corresponding to the specified model name
ner_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ner_pipeline = pipeline(
    "generic-ner",
    model=MODEL_NAME,
    tokenizer=ner_tokenizer,
    trust_remote_code=True,
    device="cpu",
)
sentences = [
    """In the year 1789, King Louis XVI, ruler of France, convened the Estates-General at the Palace of Versailles,
    where Marie Antoinette, the Queen of France, alongside Maximilien Robespierre, a leading member of the National Assembly,
    debated with Jean-Jacques Rousseau, the famous philosopher, and Charles de Talleyrand, the Bishop of Autun,
    regarding the future of the French monarchy. At the same time, across the Atlantic in Philadelphia,
    George Washington, the first President of the United States, and Thomas Jefferson, the nation's Secretary of State,
    were drafting policies for the newly established American government following the signing of the Constitution."""
]

print(sentences[0])


# Helper function to print entities one per row
def print_nicely(entities):
    for entity in entities:
        print(
            f"Entity: {entity['entity']} | Confidence: {entity['score']:.2f}% | Text: {entity['word'].strip()} | Start: {entity['start']} | End: {entity['end']}"
        )


# Visualize stacked entities for each sentence
for sentence in sentences:
    results = ner_pipeline(sentence)

    # Extract coarse and fine entities
    for key in results.keys():
        # Visualize the entities for this task
        print_nicely(results[key])
test_ner.py
ADDED
@@ -0,0 +1,106 @@
from transformers import pipeline, AutoTokenizer
import bz2, json
from pprint import pprint

MODEL_NAME = "impresso-project/ner-stacked-bert-multilingual-light"

# Load the tokenizer and model using the pipeline
ner_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ner_pipeline = pipeline(
    "generic-ner",
    model=MODEL_NAME,
    tokenizer=ner_tokenizer,
    trust_remote_code=True,
    device="cpu",
)


def process_archive(lingproc_path):
    """
    Processes a linguistic-processing archive to extract full text and sentence offsets.

    Args:
        lingproc_path (str): Path to the .jsonl.bz2 archive.

    Returns:
        List of tuples: (doc_id, full_text, sentence_offsets)
    """
    results = []

    with bz2.open(lingproc_path, mode='rt', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            doc_id = data.get("id")

            # Reconstruct the full text from all tokens using their offsets
            offset_token_map = {}
            for sent in data.get("sents", []):
                for token in sent.get("tok", []):
                    offset = token["o"]
                    text = token["t"]
                    offset_token_map[offset] = text

            # Rebuild full text from sorted offsets
            full_text_parts = []
            sorted_offsets = sorted(offset_token_map.keys())
            last_end = 0
            for offset in sorted_offsets:
                token = offset_token_map[offset]
                if offset > last_end:
                    full_text_parts.append(" " * (offset - last_end))
                full_text_parts.append(token)
                last_end = offset + len(token)
            full_text = "".join(full_text_parts).strip()

            # assert new_full_text == full_text, f"Full text mismatch for doc_id {doc_id}. Expected: {full_text}, Got: {new_full_text}"

            sentences = []
            for sent in data.get("sents", []):
                tokens = sent.get("tok", [])
                if not tokens:
                    continue
                start = tokens[0]["o"]
                end = tokens[-1]["o"] + len(tokens[-1]["t"])
                newtokens = [{"t": token["t"], "o": token["o"], "l": len(token["t"])} for token in tokens]
                sentences.append({"start": start, "end": end, "tokens": newtokens})
            results.append((doc_id, full_text, sentences))

    return results


processed_cis = process_archive("../../data/lematin-1885.jsonl.bz2")

for ci in processed_cis:
    doc_id, full_text, offsets = ci
    print(f"Document ID: {doc_id}")
    # print(f"Full Text: {full_text}")
    # print("Sentences:")
    for sentence in offsets:
        start = sentence["start"]
        end = sentence["end"]
        tokens = sentence["tokens"]
        sentence_text = full_text[start:end]
        tokens_texts = [full_text[token["o"]:token["o"] + len(token["t"])] for token in tokens]
        # print(sentence_text)

        entities = ner_pipeline(sentence_text, tokens=tokens_texts)

        for entity in entities:
            abs_start = sentence["start"] + entity["lOffset"]
            abs_end = sentence["start"] + entity["rOffset"]
            entity_text = full_text[abs_start:abs_end]
            entity_surface = entity["surface"]
            assert entity_text == entity_surface, f"Entity text mismatch: {entity_text} != {entity_surface}"
            print(f"{doc_id}: {entity_text} -- surface: {entity_surface} -- {entity['type']} -- {abs_start} - {abs_end}")
        # pprint(entities)

        # print(f"  Sentence: {sentence_text} (Start: {start}, End: {end})")
        # for token in tokens:
        #     token_text = token["t"]
        #     token_offset = token["o"]
        #     token_label = token["l"]
        #     print(f"    Token: {token_text} (Offset: {token_offset}, Label: {token_label})")


# entities = ner_pipeline(sentence)
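
The text reconstruction in `process_archive` relies only on each token's character offset: gaps between consecutive tokens are padded with spaces so that every token lands back at its recorded position in the rebuilt string. A toy example of the same idea (token values and offsets invented for illustration):

# Sketch: rebuild a text span from (token, offset) pairs by padding gaps with spaces.
tokens = [{"t": "Paris", "o": 0}, {"t": ",", "o": 5}, {"t": "1885", "o": 7}]

parts, last_end = [], 0
for tok in sorted(tokens, key=lambda x: x["o"]):
    if tok["o"] > last_end:
        parts.append(" " * (tok["o"] - last_end))
    parts.append(tok["t"])
    last_end = tok["o"] + len(tok["t"])

print("".join(parts))  # 'Paris, 1885'
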
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json
ADDED
@@ -0,0 +1,59 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "4": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": false,
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "max_len": 512,
  "model_max_length": 512,
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": false,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc92dca5d693d80c40bfa708d0ee9551d1f85b832c57710b3edfc72dc86707e1
size 2104
vocab.txt
ADDED
The diff for this file is too large to render. See raw diff