Tom Aarsen committed
Commit 24e3585 · Parent(s): 4b92f32

Heavily simplify app, rely on gliner on PyPI

 - GLiNER/README.md +0 -74
 - GLiNER/model.py +0 -412
 - GLiNER/modules/base.py +0 -150
 - GLiNER/modules/data_proc.py +0 -73
 - GLiNER/modules/evaluator.py +0 -152
 - GLiNER/modules/layers.py +0 -28
 - GLiNER/modules/run_evaluation.py +0 -188
 - GLiNER/modules/span_rep.py +0 -326
 - GLiNER/modules/token_rep.py +0 -54
 - GLiNER/requirements.txt +0 -6
 - GLiNER/save_load.py +0 -20
 - GLiNER/train.py +0 -131
 - app.py +5 -10
 - requirements.txt +1 -6
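
The `app.py` and `requirements.txt` diffs are not shown in this excerpt, but the commit message describes the change: instead of vendoring the `GLiNER/` sources, the app depends on the `gliner` package published on PyPI. Below is a minimal sketch of that pattern, assuming the published package exposes the same `GLiNER.from_pretrained` / `predict_entities` interface as the removed files; it is an illustration, not the actual contents of the updated `app.py`.

```python
# Sketch only: depend on the PyPI package rather than the vendored GLiNER/ sources.
# Install with: pip install gliner
from gliner import GLiNER  # assumes the published package exposes the same GLiNER class

model = GLiNER.from_pretrained("urchade/gliner_base")

text = "Cristiano Ronaldo dos Santos Aveiro was born on 5 February 1985."
labels = ["person", "date"]

# Each returned entity is a dict with "text" and "label" (plus character offsets).
for entity in model.predict_entities(text, labels):
    print(entity["text"], "=>", entity["label"])
```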
 
    	
GLiNER/README.md  DELETED  @@ -1,74 +0,0 @@

# Model Card for GLiNER-base

GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios.

## Links

* Paper: https://arxiv.org/abs/2311.08526
* Repository: https://github.com/urchade/GLiNER

## Installation
To use this model, you must download the GLiNER repository and install its dependencies:
```
!git clone https://github.com/urchade/GLiNER.git
%cd GLiNER
!pip install -r requirements.txt
```

## Usage
Once you've downloaded the GLiNER repository, you can import the GLiNER class from the `model` file. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.

```python
from model import GLiNER

model = GLiNER.from_pretrained("urchade/gliner_base")

text = """
Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
"""

labels = ["person", "award", "date", "competitions", "teams"]

entities = model.predict_entities(text, labels)

for entity in entities:
    print(entity["text"], "=>", entity["label"])
```

```
Cristiano Ronaldo dos Santos Aveiro => person
5 February 1985 => date
Al Nassr => teams
Portugal national team => teams
Ballon d'Or => award
UEFA Men's Player of the Year Awards => award
European Golden Shoes => award
UEFA Champions Leagues => competitions
UEFA European Championship => competitions
UEFA Nations League => competitions
Champions League => competitions
European Championship => competitions
```

## Named Entity Recognition benchmark result

## Model Authors
The model authors are:
* [Urchade Zaratiana](https://huggingface.co/urchade)
* Nadi Tomeh
* Pierre Holat
* Thierry Charnois

## Citation
```bibtex
@misc{zaratiana2023gliner,
      title={GLiNER: Generalist Model for Named Entity Recognition using Bidirectional Transformer},
      author={Urchade Zaratiana and Nadi Tomeh and Pierre Holat and Thierry Charnois},
      year={2023},
      eprint={2311.08526},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
GLiNER/model.py  DELETED  @@ -1,412 +0,0 @@

```python
import argparse
import json
from pathlib import Path
import re
from typing import Dict, Optional, Union
import torch
import torch.nn.functional as F
from modules.layers import LstmSeq2SeqEncoder
from modules.base import InstructBase
from modules.evaluator import Evaluator, greedy_search
from modules.span_rep import SpanRepLayer
from modules.token_rep import TokenRepLayer
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import HfHubHTTPError


class GLiNER(InstructBase, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__(config)

        self.config = config

        # [ENT] token
        self.entity_token = "<<ENT>>"
        self.sep_token = "<<SEP>>"

        # usually a pretrained bidirectional transformer, returns first subtoken representation
        self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
                                             subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
                                             add_tokens=[self.entity_token, self.sep_token])

        # hierarchical representation of tokens
        self.rnn = LstmSeq2SeqEncoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=1,
            bidirectional=True,
        )

        # span representation
        self.span_rep_layer = SpanRepLayer(
            span_mode=config.span_mode,
            hidden_size=config.hidden_size,
            max_width=config.max_width,
            dropout=config.dropout,
        )

        # prompt representation (FFN)
        self.prompt_rep_layer = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size)
        )

    def compute_score_train(self, x):
        span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)

        new_length = x['seq_length'].clone()
        new_tokens = []
        all_len_prompt = []
        num_classes_all = []

        # add prompt to the tokens
        for i in range(len(x['tokens'])):
            all_types_i = list(x['classes_to_id'][i].keys())
            # multiple entity types in all_types. Prompt is appended at the start of tokens
            entity_prompt = []
            num_classes_all.append(len(all_types_i))
            # add entity types to prompt
            for entity_type in all_types_i:
                entity_prompt.append(self.entity_token)  # [ENT] token
                entity_prompt.append(entity_type)  # entity type
            entity_prompt.append(self.sep_token)  # [SEP] token

            # prompt format:
            # [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]

            # add prompt to the tokens
            tokens_p = entity_prompt + x['tokens'][i]

            # input format:
            # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n

            # update length of the sequence (add prompt length to the original length)
            new_length[i] = new_length[i] + len(entity_prompt)
            # update tokens
            new_tokens.append(tokens_p)
            # store prompt length
            all_len_prompt.append(len(entity_prompt))

        # create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise)
        max_num_classes = max(num_classes_all)
        entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
            x['span_mask'].device)
        entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
            x['span_mask'].device)  # [batch_size, max_num_classes]

        # compute all token representations
        bert_output = self.token_rep_layer(new_tokens, new_length)
        word_rep_w_prompt = bert_output["embeddings"]  # embeddings for all tokens (with prompt)
        mask_w_prompt = bert_output["mask"]  # mask for all tokens (with prompt)

        # get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
        word_rep = []  # word representation (after [SEP])
        mask = []  # mask (after [SEP])
        entity_type_rep = []  # entity type representation (before [SEP])
        for i in range(len(x['tokens'])):
            prompt_entity_length = all_len_prompt[i]  # length of prompt for this example
            # get word representation (after [SEP])
            word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
            # get mask (after [SEP])
            mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])

            # get entity type representation (before [SEP])
            entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1]  # remove [SEP]
            entity_rep = entity_rep[0::2]  # take every second element starting from the first (the [ENT] token positions)
            entity_type_rep.append(entity_rep)

        # padding for word_rep, mask and entity_type_rep
        word_rep = pad_sequence(word_rep, batch_first=True)  # [batch_size, seq_len, hidden_size]
        mask = pad_sequence(mask, batch_first=True)  # [batch_size, seq_len]
        entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)  # [batch_size, len_types, hidden_size]

        # compute span representation
        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        # compute final entity type representation (FFN)
        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)
        num_classes = entity_type_rep.shape[1]  # number of entity types

        # similarity score
        scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return scores, num_classes, entity_type_mask

    def forward(self, x):
        # compute span representation
        scores, num_classes, entity_type_mask = self.compute_score_train(x)
        batch_size = scores.shape[0]

        # loss for filtering classifier
        logits_label = scores.view(-1, num_classes)
        labels = x["span_label"].view(-1)  # (batch_size * num_spans)
        mask_label = labels != -1  # (batch_size * num_spans)
        labels.masked_fill_(~mask_label, 0)  # Set the labels of padding tokens to 0

        # one-hot encoding
        labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)  # Set the corresponding index to 1
        labels_one_hot = labels_one_hot[:, 1:]  # Remove the first column
        # Shape of labels_one_hot: (batch_size * num_spans, num_classes)

        # compute loss (without reduction)
        all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
                                                        reduction='none')
        # mask loss using entity_type_mask (B, C)
        masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
        all_losses = masked_loss.view(-1, num_classes)
        # expand mask_label to all_losses
        mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
        # weight positive labels higher (2 for positive, 1 for negative)
        weight_c = labels_one_hot + 1
        # apply mask
        all_losses = all_losses * mask_label.float() * weight_c
        return all_losses.sum()

    def compute_score_eval(self, x, device):
        # check if classes_to_id is dict
        assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

        span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

        all_types = list(x['classes_to_id'].keys())
        # multiple entity types in all_types. Prompt is appended at the start of tokens
        entity_prompt = []

        # add entity types to prompt
        for entity_type in all_types:
            entity_prompt.append(self.entity_token)
            entity_prompt.append(entity_type)

        entity_prompt.append(self.sep_token)

        prompt_entity_length = len(entity_prompt)

        # add prompt
        tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
        seq_length_p = x['seq_length'] + prompt_entity_length

        out = self.token_rep_layer(tokens_p, seq_length_p)

        word_rep_w_prompt = out["embeddings"]
        mask_w_prompt = out["mask"]

        # remove prompt
        word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
        mask = mask_w_prompt[:, prompt_entity_length:]

        # get_entity_type_rep
        entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
        # extract [ENT] tokens (which are at even positions in entity_type_rep)
        entity_type_rep = entity_type_rep[:, 0::2, :]

        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)

        word_rep = self.rnn(word_rep, mask)

        span_rep = self.span_rep_layer(word_rep, span_idx)

        local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return local_scores

    @torch.no_grad()
    def predict(self, x, flat_ner=False, threshold=0.5):
        self.eval()
        local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
        spans = []
        for i, _ in enumerate(x["tokens"]):
            local_i = local_scores[i]
            wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)]
            span_i = []
            for s, k, c in zip(*wh_i):
                if s + k < len(x["tokens"][i]):
                    span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
            span_i = greedy_search(span_i, flat_ner)
            spans.append(span_i)
        return spans

    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())

        input_x = {"tokenized_text": tokens, "ner": None}
        x = self.collate_fn([input_x], labels)
        output = self.predict(x, flat_ner=flat_ner, threshold=threshold)

        entities = []
        for start_token_idx, end_token_idx, ent_type in output[0]:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_token_idx_to_text_idx[start_token_idx],
                "end": end_token_idx_to_text_idx[end_token_idx],
                "text": text[start_text_idx:end_text_idx],
                "label": ent_type,
            })
        return entities

    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        self.eval()
        data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
        device = next(self.parameters()).device
        all_preds = []
        all_trues = []
        for x in data_loader:
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            batch_predictions = self.predict(x, flat_ner, threshold)
            all_preds.extend(batch_predictions)
            all_trues.extend(x["entities"])
        evaluator = Evaluator(all_trues, all_preds)
        out, f1 = evaluator.evaluate()
        return out, f1

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
        filenames = ["gliner_base.pt", "gliner_multi.pt"]
        for filename in filenames:
            model_file = Path(model_id) / filename
            if not model_file.exists():
                try:
                    model_file = hf_hub_download(
                        repo_id=model_id,
                        filename=filename,
                        revision=revision,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        resume_download=resume_download,
                        token=token,
                        local_files_only=local_files_only,
                    )
                except HfHubHTTPError:
                    continue
            dict_load = torch.load(model_file, map_location=torch.device(map_location))
            config = dict_load["config"]
            state_dict = dict_load["model_weights"]
            config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
            model = cls(config)
            model.load_state_dict(state_dict, strict=strict, assign=True)
            # Required to update flair's internals as well:
            model.to(map_location)
            return model

        # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
        from train import load_config_as_namespace

        model_file = Path(model_id) / "pytorch_model.bin"
        if not model_file.exists():
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="pytorch_model.bin",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config_file = Path(model_id) / "gliner_config.json"
        if not config_file.exists():
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="gliner_config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config = load_config_as_namespace(config_file)
        model = cls(config)
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict, assign=True)
        model.to(map_location)
        return model

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save model weights/files
        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        # save config (if provided)
        if config is None:
            config = self.config
        if config is not None:
            if isinstance(config, argparse.Namespace):
                config = vars(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None

    def to(self, device):
        super().to(device)
        import flair

        flair.device = device
        return self
```
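
At the heart of both `compute_score_train` and `compute_score_eval` above is the contraction `torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)`, which scores every candidate span against every prompted entity type. The shape-only sketch below uses made-up sizes (they are arbitrary and not taken from any GLiNER checkpoint) to show how the tensors line up:

```python
import torch

# Dummy shapes: batch B=2, sequence length L=10, span widths K=12,
# hidden size D=768, number of prompted entity types C=5 (all arbitrary).
B, L, K, D, C = 2, 10, 12, 768, 5

span_rep = torch.randn(B, L, K, D)       # one vector per (start position, span width)
entity_type_rep = torch.randn(B, C, D)   # one vector per prompted entity type

# Same contraction as in the deleted model.py: a dot product between
# each span vector and each entity-type vector.
scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
print(scores.shape)  # torch.Size([2, 10, 12, 5])

# At inference time, sigmoid(scores) > threshold selects candidate (span, type) pairs.
candidates = torch.sigmoid(scores) > 0.5
```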
GLiNER/modules/base.py
DELETED

@@ -1,150 +0,0 @@
from collections import defaultdict
from typing import List, Tuple, Dict

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import random


class InstructBase(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.max_width = config.max_width
        self.base_config = config

    def get_dict(self, spans, classes_to_id):
        dict_tag = defaultdict(int)
        for span in spans:
            if span[2] in classes_to_id:
                dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
        return dict_tag

    def preprocess_spans(self, tokens, ner, classes_to_id):

        max_len = self.base_config.max_len

        if len(tokens) > max_len:
            length = max_len
            tokens = tokens[:max_len]
        else:
            length = len(tokens)

        spans_idx = []
        for i in range(length):
            spans_idx.extend([(i, i + j) for j in range(self.max_width)])

        dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)

        # 0 for null labels
        span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
        spans_idx = torch.LongTensor(spans_idx)

        # mask for valid spans
        valid_span_mask = spans_idx[:, 1] > length - 1

        # mask invalid positions
        span_label = span_label.masked_fill(valid_span_mask, -1)

        return {
            'tokens': tokens,
            'span_idx': spans_idx,
            'span_label': span_label,
            'seq_length': length,
            'entities': ner,
        }

    def collate_fn(self, batch_list, entity_types=None):
        # batch_list: list of dict containing tokens, ner
        if entity_types is None:
            negs = self.get_negatives(batch_list, 100)
            class_to_ids = []
            id_to_classes = []
            for b in batch_list:
                # negs = b["negative"]
                random.shuffle(negs)

                # negs = negs[:sampled_neg]
                max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)

                if max_neg_type_ratio == 0:
                    # no negatives
                    neg_type_ratio = 0
                else:
                    neg_type_ratio = random.randint(0, max_neg_type_ratio)

                if neg_type_ratio == 0:
                    # no negatives
                    negs_i = []
                else:
                    negs_i = negs[:len(b['ner']) * neg_type_ratio]

                # this is the list of all possible entity types (positive and negative)
                types = list(set([el[-1] for el in b['ner']] + negs_i))

                # shuffle (every epoch)
                random.shuffle(types)

                if len(types) != 0:
                    # random drop
                    if self.base_config.random_drop:
                        num_ents = random.randint(1, len(types))
                        types = types[:num_ents]

                # maximum number of entity types
                types = types[:int(self.base_config.max_types)]

                # supervised training
                if "label" in b:
                    types = sorted(b["label"])

                class_to_id = {k: v for v, k in enumerate(types, start=1)}
                id_to_class = {k: v for v, k in class_to_id.items()}
                class_to_ids.append(class_to_id)
                id_to_classes.append(id_to_class)

            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i]) for i, b in enumerate(batch_list)
            ]

        else:
            class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
            id_to_classes = {k: v for v, k in class_to_ids.items()}
            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids) for b in batch_list
            ]

        span_idx = pad_sequence(
            [b['span_idx'] for b in batch], batch_first=True, padding_value=0
        )

        span_label = pad_sequence(
            [el['span_label'] for el in batch], batch_first=True, padding_value=-1
        )

        return {
            'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
            'span_idx': span_idx,
            'tokens': [el['tokens'] for el in batch],
            'span_mask': span_label != -1,
            'span_label': span_label,
            'entities': [el['entities'] for el in batch],
            'classes_to_id': class_to_ids,
            'id_to_classes': id_to_classes,
        }

    @staticmethod
    def get_negatives(batch_list, sampled_neg=5):
        ent_types = []
        for b in batch_list:
            types = set([el[-1] for el in b['ner']])
            ent_types.extend(list(types))
        ent_types = list(set(ent_types))
        # sample negatives
        random.shuffle(ent_types)
        return ent_types[:sampled_neg]

    def create_dataloader(self, data, entity_types=None, **kwargs):
        return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)
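For reference, the candidate-span enumeration and masking performed by `preprocess_spans` above can be reproduced in isolation. The sketch below is illustrative only; the tokens and `max_width` value are made-up examples, not part of the commit:

```python
# Illustrative sketch of the span enumeration in InstructBase.preprocess_spans.
# The tokens and max_width below are arbitrary example values.
import torch

tokens = ["Paris", "is", "in", "France"]
max_width = 3
length = len(tokens)

# Every start position i proposes spans (i, i), (i, i + 1), ..., (i, i + max_width - 1).
spans_idx = [(i, i + j) for i in range(length) for j in range(max_width)]
spans_idx = torch.LongTensor(spans_idx)

# Spans whose end index falls outside the sentence are marked with -1,
# mirroring the masked_fill call in preprocess_spans.
valid_span_mask = spans_idx[:, 1] > length - 1
span_label = torch.zeros(len(spans_idx), dtype=torch.long).masked_fill(valid_span_mask, -1)

print(spans_idx.tolist())   # 12 candidate spans
print(span_label.tolist())  # 0 = null label, -1 = out-of-range span
```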
    	
GLiNER/modules/data_proc.py
DELETED

@@ -1,73 +0,0 @@
import json
from tqdm import tqdm
# ast.literal_eval
import ast, re

path = 'train.json'

with open(path, 'r') as f:
    data = json.load(f)

def tokenize_text(text):
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)

def extract_entity_spans(entry):
    text = ""
    len_start = len("What describes ")
    len_end = len(" in the text?")
    entity_types = []
    entity_texts = []

    for c in entry['conversations']:
        if c['from'] == 'human' and c['value'].startswith('Text: '):
            text = c['value'][len('Text: '):]
            tokenized_text = tokenize_text(text)

        if c['from'] == 'human' and c['value'].startswith('What describes '):

            c_type = c['value'][len_start:-len_end]
            c_type = c_type.replace(' ', '_')
            entity_types.append(c_type)

        elif c['from'] == 'gpt' and c['value'].startswith('['):
            if c['value'] == '[]':
                entity_types = entity_types[:-1]
                continue

            texts_ents = ast.literal_eval(c['value'])
            # replace space to _ in texts_ents
            entity_texts.extend(texts_ents)
            num_repeat = len(texts_ents) - 1
            entity_types.extend([entity_types[-1]] * num_repeat)

    entity_spans = []
    for j, entity_text in enumerate(entity_texts):
        entity_tokens = tokenize_text(entity_text)
        matches = []
        for i in range(len(tokenized_text) - len(entity_tokens) + 1):
            if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
                matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
        if matches:
            entity_spans.extend(matches)

    return entity_spans, tokenized_text

# Usage:
# Replace 'entry' with the specific entry from your JSON data
entry = data[17818]  # For example, taking the first entry
entity_spans, tokenized_text = extract_entity_spans(entry)
print("Entity Spans:", entity_spans)
#print("Tokenized Text:", tokenized_text)

# create a dict: {"tokenized_text": tokenized_text, "entity_spans": entity_spans}

all_data = []

for entry in tqdm(data):
    entity_spans, tokenized_text = extract_entity_spans(entry)
    all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})


with open('train_instruct.json', 'w') as f:
    json.dump(all_data, f)
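The tokenizer and the case-insensitive token matching used by `extract_entity_spans` can be checked on a toy sentence. Everything in this snippet (sentence, entity, label) is invented for illustration:

```python
# Self-contained check of the tokenization and span matching logic above.
import re

def tokenize_text(text):
    # Same regex as in data_proc.py: words (incl. hyphen/underscore compounds) or single symbols.
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)

tokenized_text = tokenize_text("Barack Obama visited Paris.")
entity_tokens = tokenize_text("barack obama")

# Slide a window over the tokenized text and compare case-insensitively.
matches = []
for i in range(len(tokenized_text) - len(entity_tokens) + 1):
    if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
        matches.append((i, i + len(entity_tokens) - 1, "person"))

print(tokenized_text)  # ['Barack', 'Obama', 'visited', 'Paris', '.']
print(matches)         # [(0, 1, 'person')] -> token-level (start, end, type)
```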
    	
GLiNER/modules/evaluator.py
DELETED

@@ -1,152 +0,0 @@
from collections import defaultdict

import numpy as np
import torch
from seqeval.metrics.v1 import _prf_divide


def extract_tp_actual_correct(y_true, y_pred):
    entities_true = defaultdict(set)
    entities_pred = defaultdict(set)

    for type_name, (start, end), idx in y_true:
        entities_true[type_name].add((start, end, idx))
    for type_name, (start, end), idx in y_pred:
        entities_pred[type_name].add((start, end, idx))

    target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))

    tp_sum = np.array([], dtype=np.int32)
    pred_sum = np.array([], dtype=np.int32)
    true_sum = np.array([], dtype=np.int32)
    for type_name in target_names:
        entities_true_type = entities_true.get(type_name, set())
        entities_pred_type = entities_pred.get(type_name, set())
        tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
        pred_sum = np.append(pred_sum, len(entities_pred_type))
        true_sum = np.append(true_sum, len(entities_true_type))

    return pred_sum, tp_sum, true_sum, target_names


def flatten_for_eval(y_true, y_pred):
    all_true = []
    all_pred = []

    for i, (true, pred) in enumerate(zip(y_true, y_pred)):
        all_true.extend([t + [i] for t in true])
        all_pred.extend([p + [i] for p in pred])

    return all_true, all_pred


def compute_prf(y_true, y_pred, average='micro'):
    y_true, y_pred = flatten_for_eval(y_true, y_pred)

    pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)

    if average == 'micro':
        tp_sum = np.array([tp_sum.sum()])
        pred_sum = np.array([pred_sum.sum()])
        true_sum = np.array([true_sum.sum()])

    precision = _prf_divide(
        numerator=tp_sum,
        denominator=pred_sum,
        metric='precision',
        modifier='predicted',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )

    recall = _prf_divide(
        numerator=tp_sum,
        denominator=true_sum,
        metric='recall',
        modifier='true',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )

    denominator = precision + recall
    denominator[denominator == 0.] = 1
    f_score = 2 * (precision * recall) / denominator

    return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}


class Evaluator:
    def __init__(self, all_true, all_outs):
        self.all_true = all_true
        self.all_outs = all_outs

    def get_entities_fr(self, ents):
        all_ents = []
        for s, e, lab in ents:
            all_ents.append([lab, (s, e)])
        return all_ents

    def transform_data(self):
        all_true_ent = []
        all_outs_ent = []
        for i, j in zip(self.all_true, self.all_outs):
            e = self.get_entities_fr(i)
            all_true_ent.append(e)
            e = self.get_entities_fr(j)
            all_outs_ent.append(e)
        return all_true_ent, all_outs_ent

    @torch.no_grad()
    def evaluate(self):
        all_true_typed, all_outs_typed = self.transform_data()
        precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values()
        output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
        return output_str, f1


def is_nested(idx1, idx2):
    # Return True if idx2 is nested inside idx1 or vice versa
    return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])


def has_overlapping(idx1, idx2):
    overlapping = True
    if idx1[:2] == idx2[:2]:
        return overlapping
    if (idx1[0] > idx2[1] or idx2[0] > idx1[1]):
        overlapping = False
    return overlapping


def has_overlapping_nested(idx1, idx2):
    # Return True if idx1 and idx2 overlap, but neither is nested inside the other
    if idx1[:2] == idx2[:2]:
        return True
    if ((idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2)) and idx1 != idx2:
        return False
    else:
        return True


def greedy_search(spans, flat_ner=True):  # start, end, class, score

    if flat_ner:
        has_ov = has_overlapping
    else:
        has_ov = has_overlapping_nested

    new_list = []
    span_prob = sorted(spans, key=lambda x: -x[-1])
    for i in range(len(spans)):
        b = span_prob[i]
        flag = False
        for new in new_list:
            if has_ov(b[:-1], new):
                flag = True
                break
        if not flag:
            new_list.append(b[:-1])
    new_list = sorted(new_list, key=lambda x: x[0])
    return new_list
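A minimal sketch of how `greedy_search` resolves overlapping span candidates. The import path assumes the GLiNER repository root is the working directory, and the example spans are made up:

```python
# Assumes modules/evaluator.py (above) is importable from the GLiNER repo root.
from modules.evaluator import greedy_search

# Candidate spans as (start, end, class, score).
spans = [
    (0, 2, "person", 0.95),
    (1, 3, "location", 0.60),  # overlaps the higher-scoring span above
    (5, 6, "date", 0.80),
]

# Flat NER keeps the highest-scoring span among overlapping candidates,
# so (1, 3, "location") is discarded.
print(greedy_search(spans, flat_ner=True))
# [(0, 2, 'person'), (5, 6, 'date')]
```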
    	
GLiNER/modules/layers.py
DELETED

@@ -1,28 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LstmSeq2SeqEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
        super(LstmSeq2SeqEncoder, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout,
                            bidirectional=bidirectional,
                            batch_first=True)

    def forward(self, x, mask, hidden=None):
        # Packing the input sequence
        lengths = mask.sum(dim=1).cpu()
        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # Passing packed sequence through LSTM
        packed_output, hidden = self.lstm(packed_x, hidden)

        # Unpacking the output sequence
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        return output
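A quick shape check for `LstmSeq2SeqEncoder`. The import path assumes the GLiNER repository root is on `PYTHONPATH`; the dimensions and mask are arbitrary example values:

```python
# Assumes modules/layers.py (above) is importable from the GLiNER repo root.
import torch
from modules.layers import LstmSeq2SeqEncoder

encoder = LstmSeq2SeqEncoder(input_size=768, hidden_size=384, bidirectional=True)

x = torch.randn(2, 5, 768)                               # (batch, seq_len, input_size)
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]], dtype=torch.bool)  # second sequence has length 3

out = encoder(x, mask)
print(out.shape)  # torch.Size([2, 5, 768]); bidirectional doubles hidden_size
```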
    	
        GLiNER/modules/run_evaluation.py
    DELETED
    
    | 
         @@ -1,188 +0,0 @@ 
     | 
|
| 1 | 
         
            -
            import glob
         
     | 
| 2 | 
         
            -
            import json
         
     | 
| 3 | 
         
            -
            import os
         
     | 
| 4 | 
         
            -
            import os
         
     | 
| 5 | 
         
            -
             
     | 
| 6 | 
         
            -
            import torch
         
     | 
| 7 | 
         
            -
            from tqdm import tqdm
         
     | 
| 8 | 
         
            -
            import random
         
     | 
| 9 | 
         
            -
             
     | 
| 10 | 
         
            -
             
     | 
| 11 | 
         
            -
            def open_content(path):
         
     | 
| 12 | 
         
            -
                paths = glob.glob(os.path.join(path, "*.json"))
         
     | 
| 13 | 
         
            -
                train, dev, test, labels = None, None, None, None
         
     | 
| 14 | 
         
            -
                for p in paths:
         
     | 
| 15 | 
         
            -
                    if "train" in p:
         
     | 
| 16 | 
         
            -
                        with open(p, "r") as f:
         
     | 
| 17 | 
         
            -
                            train = json.load(f)
         
     | 
| 18 | 
         
            -
                    elif "dev" in p:
         
     | 
| 19 | 
         
            -
                        with open(p, "r") as f:
         
     | 
| 20 | 
         
            -
                            dev = json.load(f)
         
     | 
| 21 | 
         
            -
                    elif "test" in p:
         
     | 
| 22 | 
         
            -
                        with open(p, "r") as f:
         
     | 
| 23 | 
         
            -
                            test = json.load(f)
         
     | 
| 24 | 
         
            -
                    elif "labels" in p:
         
     | 
| 25 | 
         
            -
                        with open(p, "r") as f:
         
     | 
| 26 | 
         
            -
                            labels = json.load(f)
         
     | 
| 27 | 
         
            -
                return train, dev, test, labels
         
     | 
| 28 | 
         
            -
             
     | 
| 29 | 
         
            -
             
     | 
| 30 | 
         
            -
            def process(data):
         
     | 
| 31 | 
         
            -
                words = data['sentence'].split()
         
     | 
| 32 | 
         
            -
                entities = []  # List of entities (start, end, type)
         
     | 
| 33 | 
         
            -
             
     | 
| 34 | 
         
            -
                for entity in data['entities']:
         
     | 
| 35 | 
         
            -
                    start_char, end_char = entity['pos']
         
     | 
| 36 | 
         
            -
             
     | 
| 37 | 
         
            -
                    # Initialize variables to keep track of word positions
         
     | 
| 38 | 
         
            -
                    start_word = None
         
     | 
| 39 | 
         
            -
                    end_word = None
         
     | 
| 40 | 
         
            -
             
     | 
| 41 | 
         
            -
                    # Iterate through words and find the word positions
         
     | 
| 42 | 
         
            -
                    char_count = 0
         
     | 
| 43 | 
         
            -
                    for i, word in enumerate(words):
         
     | 
| 44 | 
         
            -
                        word_length = len(word)
         
     | 
| 45 | 
         
            -
                        if char_count == start_char:
         
     | 
| 46 | 
         
            -
                            start_word = i
         
     | 
| 47 | 
         
            -
                        if char_count + word_length == end_char:
         
     | 
| 48 | 
         
            -
                            end_word = i
         
     | 
| 49 | 
         
            -
                            break
         
     | 
| 50 | 
         
            -
                        char_count += word_length + 1  # Add 1 for the space
         
     | 
| 51 | 
         
            -
             
     | 
| 52 | 
         
            -
                    # Append the word positions to the list
         
     | 
| 53 | 
         
            -
                    entities.append((start_word, end_word, entity['type']))
         
     | 
| 54 | 
         
            -
             
     | 
| 55 | 
         
            -
                # Create a list of word positions for each entity
         
     | 
| 56 | 
         
            -
                sample = {
         
     | 
| 57 | 
         
            -
                    "tokenized_text": words,
         
     | 
| 58 | 
         
            -
                    "ner": entities
         
     | 
| 59 | 
         
            -
                }
         
     | 
| 60 | 
         
            -
             
     | 
| 61 | 
         
            -
                return sample
         
     | 
| 62 | 
         
            -
             
     | 
| 63 | 
         
            -
             
     | 
| 64 | 
         
            -
            # create dataset
         
     | 
| 65 | 
         
-def create_dataset(path):
-    train, dev, test, labels = open_content(path)
-    train_dataset = []
-    dev_dataset = []
-    test_dataset = []
-    for data in train:
-        train_dataset.append(process(data))
-    for data in dev:
-        dev_dataset.append(process(data))
-    for data in test:
-        test_dataset.append(process(data))
-    return train_dataset, dev_dataset, test_dataset, labels
-
-
-@torch.no_grad()
-def get_for_one_path(path, model):
-    # load the dataset
-    _, _, test_dataset, entity_types = create_dataset(path)
-
-    data_name = path.split("/")[-1]  # get the name of the dataset
-
-    # check if the dataset is flat_ner
-    flat_ner = True
-    if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]):
-        flat_ner = False
-
-    # evaluate the model
-    results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
-                                 entity_types=entity_types)
-    return data_name, results, f1
-
-
-def get_for_all_path(model, steps, log_dir, data_paths):
-    all_paths = glob.glob(f"{data_paths}/*")
-
-    all_paths = sorted(all_paths)
-
-    # move the model to the device
-    device = next(model.parameters()).device
-    model.to(device)
-    # set the model to eval mode
-    model.eval()
-
-    # log the results
-    save_path = os.path.join(log_dir, "results.txt")
-
-    with open(save_path, "a") as f:
-        f.write("##############################################\n")
-        # write step
-        f.write("step: " + str(steps) + "\n")
-
-    zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
-                      "CrossNER_politics", "CrossNER_science"]
-
-    zero_shot_benc_results = {}
-    all_results = {}  # without crossNER
-
-    for p in tqdm(all_paths):
-        if "sample_" not in p:
-            data_name, results, f1 = get_for_one_path(p, model)
-            # write to file
-            with open(save_path, "a") as f:
-                f.write(data_name + "\n")
-                f.write(str(results) + "\n")
-
-            if data_name in zero_shot_benc:
-                zero_shot_benc_results[data_name] = f1
-            else:
-                all_results[data_name] = f1
-
-    avg_all = sum(all_results.values()) / len(all_results)
-    avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results)
-
-    save_path_table = os.path.join(log_dir, "tables.txt")
-
-    # results for all datasets except crossNER
-    table_bench_all = ""
-    for k, v in all_results.items():
-        table_bench_all += f"{k:20}: {v:.1%}\n"
-    # (20 size aswell for average i.e. :20)
-    table_bench_all += f"{'Average':20}: {avg_all:.1%}"
-
-    # results for zero-shot benchmark
-    table_bench_zeroshot = ""
-    for k, v in zero_shot_benc_results.items():
-        table_bench_zeroshot += f"{k:20}: {v:.1%}\n"
-    table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}"
-
-    # write to file
-    with open(save_path_table, "a") as f:
-        f.write("##############################################\n")
-        f.write("step: " + str(steps) + "\n")
-        f.write("Table for all datasets except crossNER\n")
-        f.write(table_bench_all + "\n\n")
-        f.write("Table for zero-shot benchmark\n")
-        f.write(table_bench_zeroshot + "\n")
-        f.write("##############################################\n\n")
-
-
-def sample_train_data(data_paths, sample_size=10000):
-    all_paths = glob.glob(f"{data_paths}/*")
-
-    all_paths = sorted(all_paths)
-
-    # to exclude the zero-shot benchmark datasets
-    zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
-                      "CrossNER_politics", "CrossNER_science", "ACE 2004"]
-
-    new_train = []
-    # take 10k samples from each dataset
-    for p in tqdm(all_paths):
-        if any([i in p for i in zero_shot_benc]):
-            continue
-        train, dev, test, labels = create_dataset(p)
-
-        # add label key to the train data
-        for i in range(len(train)):
-            train[i]["label"] = labels
-
-        random.shuffle(train)
-        train = train[:sample_size]
-        new_train.extend(train)
-
-    return new_train
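
Note (not part of the diff): a minimal sketch of how the deleted evaluation helpers above could be driven end to end. It assumes a checkpoint written by `save_load.save_model` and a folder of evaluation datasets; the paths and step count are placeholders.

```python
# Usage sketch only: paths, device, and step count are hypothetical.
from save_load import load_model
from modules.run_evaluation import get_for_all_path

model = load_model("logs/model_30000", device="cuda")  # placeholder checkpoint path
model.eval()

# Appends per-dataset scores to logs/results.txt and summary tables to logs/tables.txt.
get_for_all_path(model, steps=30000, log_dir="logs", data_paths="data/eval_datasets")
```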
GLiNER/modules/span_rep.py DELETED
@@ -1,326 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-class SpanQuery(nn.Module):
-
-    def __init__(self, hidden_size, max_width, trainable=True):
-        super().__init__()
-
-        self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
-
-        nn.init.uniform_(self.query_seg, a=-1, b=1)
-
-        if not trainable:
-            self.query_seg.requires_grad = False
-
-        self.project = nn.Sequential(
-            nn.Linear(hidden_size, hidden_size),
-            nn.ReLU()
-        )
-
-    def forward(self, h, *args):
-        # h of shape [B, L, D]
-        # query_seg of shape [D, max_width]
-
-        span_rep = torch.einsum('bld, ds->blsd', h, self.query_seg)
-
-        return self.project(span_rep)
-
-
-class SpanMLP(nn.Module):
-
-    def __init__(self, hidden_size, max_width):
-        super().__init__()
-
-        self.mlp = nn.Linear(hidden_size, hidden_size * max_width)
-
-    def forward(self, h, *args):
-        # h of shape [B, L, D]
-        # query_seg of shape [D, max_width]
-
-        B, L, D = h.size()
-
-        span_rep = self.mlp(h)
-
-        span_rep = span_rep.view(B, L, -1, D)
-
-        return span_rep.relu()
-
-
-class SpanCAT(nn.Module):
-
-    def __init__(self, hidden_size, max_width):
-        super().__init__()
-
-        self.max_width = max_width
-
-        self.query_seg = nn.Parameter(torch.randn(128, max_width))
-
-        self.project = nn.Sequential(
-            nn.Linear(hidden_size + 128, hidden_size),
-            nn.ReLU()
-        )
-
-    def forward(self, h, *args):
-        # h of shape [B, L, D]
-        # query_seg of shape [D, max_width]
-
-        B, L, D = h.size()
-
-        h = h.view(B, L, 1, D).repeat(1, 1, self.max_width, 1)
-
-        q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1)
-
-        span_rep = torch.cat([h, q], dim=-1)
-
-        span_rep = self.project(span_rep)
-
-        return span_rep
-
-
-class SpanConvBlock(nn.Module):
-    def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
-        super().__init__()
-
-        if span_mode == 'conv_conv':
-            self.conv = nn.Conv1d(hidden_size, hidden_size,
-                                  kernel_size=kernel_size)
-
-            # initialize the weights
-            nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
-
-        elif span_mode == 'conv_max':
-            self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
-        elif span_mode == 'conv_mean' or span_mode == 'conv_sum':
-            self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
-
-        self.span_mode = span_mode
-
-        self.pad = kernel_size - 1
-
-    def forward(self, x):
-
-        x = torch.einsum('bld->bdl', x)
-
-        if self.pad > 0:
-            x = F.pad(x, (0, self.pad), "constant", 0)
-
-        x = self.conv(x)
-
-        if self.span_mode == "conv_sum":
-            x = x * (self.pad + 1)
-
-        return torch.einsum('bdl->bld', x)
-
-
-class SpanConv(nn.Module):
-    def __init__(self, hidden_size, max_width, span_mode):
-        super().__init__()
-
-        kernels = [i + 2 for i in range(max_width - 1)]
-
-        self.convs = nn.ModuleList()
-
-        for kernel in kernels:
-            self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
-
-        self.project = nn.Sequential(
-            nn.ReLU(),
-            nn.Linear(hidden_size, hidden_size)
-        )
-
-    def forward(self, x, *args):
-
-        span_reps = [x]
-
-        for conv in self.convs:
-            h = conv(x)
-            span_reps.append(h)
-
-        span_reps = torch.stack(span_reps, dim=-2)
-
-        return self.project(span_reps)
-
-
-class SpanEndpointsBlock(nn.Module):
-    def __init__(self, kernel_size):
-        super().__init__()
-
-        self.kernel_size = kernel_size
-
-    def forward(self, x):
-        B, L, D = x.size()
-
-        span_idx = torch.LongTensor(
-            [[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device)
-
-        x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
-
-        # endrep
-        start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
-
-        start_end_rep = start_end_rep.view(B, L, 2, D)
-
-        return start_end_rep
-
-
-class ConvShare(nn.Module):
-    def __init__(self, hidden_size, max_width):
-        super().__init__()
-
-        self.max_width = max_width
-
-        self.conv_weigth = nn.Parameter(
-            torch.randn(hidden_size, hidden_size, max_width))
-
-        nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')
-
-        self.project = nn.Sequential(
-            nn.ReLU(),
-            nn.Linear(hidden_size, hidden_size)
-        )
-
-    def forward(self, x, *args):
-        span_reps = []
-
-        x = torch.einsum('bld->bdl', x)
-
-        for i in range(self.max_width):
-            pad = i
-            x_i = F.pad(x, (0, pad), "constant", 0)
-            conv_w = self.conv_weigth[:, :, :i + 1]
-            out_i = F.conv1d(x_i, conv_w)
-            span_reps.append(out_i.transpose(-1, -2))
-
-        out = torch.stack(span_reps, dim=-2)
-
-        return self.project(out)
-
-
-def extract_elements(sequence, indices):
-    B, L, D = sequence.shape
-    K = indices.shape[1]
-
-    # Expand indices to [B, K, D]
-    expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)
-
-    # Gather the elements
-    extracted_elements = torch.gather(sequence, 1, expanded_indices)
-
-    return extracted_elements
-
-
-class SpanMarker(nn.Module):
-
-    def __init__(self, hidden_size, max_width, dropout=0.4):
-        super().__init__()
-
-        self.max_width = max_width
-
-        self.project_start = nn.Sequential(
-            nn.Linear(hidden_size, hidden_size * 2, bias=True),
-            nn.ReLU(),
-            nn.Dropout(dropout),
-            nn.Linear(hidden_size * 2, hidden_size, bias=True),
-        )
-
-        self.project_end = nn.Sequential(
-            nn.Linear(hidden_size, hidden_size * 2, bias=True),
-            nn.ReLU(),
-            nn.Dropout(dropout),
-            nn.Linear(hidden_size * 2, hidden_size, bias=True),
-        )
-
-        self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
-
-    def forward(self, h, span_idx):
-        # h of shape [B, L, D]
-        # query_seg of shape [D, max_width]
-
-        B, L, D = h.size()
-
-        # project start and end
-        start_rep = self.project_start(h)
-        end_rep = self.project_end(h)
-
-        start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
-        end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
-
-        # concat start and end
-        cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
-
-        # project
-        cat = self.out_project(cat)
-
-        # reshape
-        return cat.view(B, L, self.max_width, D)
-
-
-class ConvShareV2(nn.Module):
-    def __init__(self, hidden_size, max_width):
-        super().__init__()
-
-        self.max_width = max_width
-
-        self.conv_weigth = nn.Parameter(
-            torch.randn(hidden_size, hidden_size, max_width)
-        )
-
-        nn.init.xavier_normal_(self.conv_weigth)
-
-    def forward(self, x, *args):
-        span_reps = []
-
-        x = torch.einsum('bld->bdl', x)
-
-        for i in range(self.max_width):
-            pad = i
-            x_i = F.pad(x, (0, pad), "constant", 0)
-            conv_w = self.conv_weigth[:, :, :i + 1]
-            out_i = F.conv1d(x_i, conv_w)
-            span_reps.append(out_i.transpose(-1, -2))
-
-        out = torch.stack(span_reps, dim=-2)
-
-        return out
-
-
-class SpanRepLayer(nn.Module):
-    """
-    Various span representation approaches
-    """
-
-    def __init__(self, hidden_size, max_width, span_mode, **kwargs):
-        super().__init__()
-
-        if span_mode == 'marker':
-            self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
-        elif span_mode == 'query':
-            self.span_rep_layer = SpanQuery(
-                hidden_size, max_width, trainable=True)
-        elif span_mode == 'mlp':
-            self.span_rep_layer = SpanMLP(hidden_size, max_width)
-        elif span_mode == 'cat':
-            self.span_rep_layer = SpanCAT(hidden_size, max_width)
-        elif span_mode == 'conv_conv':
-            self.span_rep_layer = SpanConv(
-                hidden_size, max_width, span_mode='conv_conv')
-        elif span_mode == 'conv_max':
-            self.span_rep_layer = SpanConv(
-                hidden_size, max_width, span_mode='conv_max')
-        elif span_mode == 'conv_mean':
-            self.span_rep_layer = SpanConv(
-                hidden_size, max_width, span_mode='conv_mean')
-        elif span_mode == 'conv_sum':
-            self.span_rep_layer = SpanConv(
-                hidden_size, max_width, span_mode='conv_sum')
-        elif span_mode == 'conv_share':
-            self.span_rep_layer = ConvShare(hidden_size, max_width)
-        else:
-            raise ValueError(f'Unknown span mode {span_mode}')
-
-    def forward(self, x, *args):
-
-        return self.span_rep_layer(x, *args)
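
Note (not part of the diff): a standalone sketch exercising the deleted `SpanRepLayer` in `marker` mode. Tensor sizes are illustrative, and the span-index construction (every start position with widths up to `max_width`, ends clamped to the last token) is only one plausible way to feed it.

```python
# Standalone sketch: shapes and the span enumeration below are illustrative.
import torch
from modules.span_rep import SpanRepLayer  # module path as it existed in this repo

B, L, D, max_width = 2, 10, 768, 12
h = torch.randn(B, L, D)  # token representations

# One (start, end) pair per position and width, flattened to L * max_width spans.
spans = [[s, min(s + w, L - 1)] for s in range(L) for w in range(max_width)]
span_idx = torch.tensor(spans).unsqueeze(0).repeat(B, 1, 1)  # [B, L * max_width, 2]

layer = SpanRepLayer(hidden_size=D, max_width=max_width, span_mode="marker")
span_reps = layer(h, span_idx)
print(span_reps.shape)  # torch.Size([2, 10, 12, 768]) -> [B, L, max_width, D]
```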
GLiNER/modules/token_rep.py DELETED
@@ -1,54 +0,0 @@
-from typing import List
-
-import torch
-from flair.data import Sentence
-from flair.embeddings import TransformerWordEmbeddings
-from torch import nn
-from torch.nn.utils.rnn import pad_sequence
-
-
-# flair.cache_root = '/gpfswork/rech/pds/upa43yu/.cache'
-
-
-class TokenRepLayer(nn.Module):
-    def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
-                 hidden_size: int = 768,
-                 add_tokens=["[SEP]", "[ENT]"]
-                 ):
-        super().__init__()
-
-        self.bert_layer = TransformerWordEmbeddings(
-            model_name,
-            fine_tune=fine_tune,
-            subtoken_pooling=subtoken_pooling,
-            allow_long_sentences=True
-        )
-
-        # add tokens to vocabulary
-        self.bert_layer.tokenizer.add_tokens(add_tokens)
-
-        # resize token embeddings
-        self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))
-
-        bert_hidden_size = self.bert_layer.embedding_length
-
-        if hidden_size != bert_hidden_size:
-            self.projection = nn.Linear(bert_hidden_size, hidden_size)
-
-    def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
-        token_embeddings = self.compute_word_embedding(tokens)
-
-        if hasattr(self, "projection"):
-            token_embeddings = self.projection(token_embeddings)
-
-        B = len(lengths)
-        max_length = lengths.max()
-        mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to(
-            token_embeddings.device).long()
-        return {"embeddings": token_embeddings, "mask": mask}
-
-    def compute_word_embedding(self, tokens):
-        sentences = [Sentence(i) for i in tokens]
-        self.bert_layer.embed(sentences)
-        token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
-        return token_embeddings
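
Note (not part of the diff): a small sketch of the deleted `TokenRepLayer` embedding two pre-tokenized sentences; the example sentences are made up, and running it downloads `bert-base-cased` through flair.

```python
# Standalone sketch: example sentences are arbitrary.
import torch
from modules.token_rep import TokenRepLayer  # module path as it existed in this repo

rep = TokenRepLayer(model_name="bert-base-cased", fine_tune=False, hidden_size=768)

tokens = [["Paris", "is", "in", "France"], ["GLiNER", "tags", "entities"]]
lengths = torch.tensor([len(seq) for seq in tokens])

out = rep(tokens, lengths)
print(out["embeddings"].shape)  # [2, 4, 768], padded to the longest sentence
print(out["mask"])              # 1 for real tokens, 0 for padding positions
```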
GLiNER/requirements.txt DELETED
@@ -1,6 +0,0 @@
-torch
-transformers
-huggingface_hub
-flair
-seqeval
-tqdm
GLiNER/save_load.py DELETED
@@ -1,20 +0,0 @@
-import torch
-from model import GLiNER
-
-
-def save_model(current_model, path):
-    config = current_model.config
-    dict_save = {"model_weights": current_model.state_dict(), "config": config}
-    torch.save(dict_save, path)
-
-
-def load_model(path, model_name=None, device=None):
-    dict_load = torch.load(path, map_location=torch.device('cpu'))
-    config = dict_load["config"]
-
-    if model_name is not None:
-        config.model_name = model_name
-
-    loaded_model = GLiNER(config)
-    loaded_model.load_state_dict(dict_load["model_weights"])
-    return loaded_model.to(device) if device is not None else loaded_model
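
Note (not part of the diff): a round-trip sketch for the deleted save/load helpers, assuming `model` is an already-constructed GLiNER instance from the repo's `model.py`; the checkpoint path is a placeholder.

```python
# Round-trip sketch: `model` and the path are assumed/hypothetical.
from save_load import save_model, load_model

save_model(model, "logs/model_final")                      # stores state_dict + config together
reloaded = load_model("logs/model_final", device="cuda")   # rebuilds GLiNER(config), loads weights
```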
        GLiNER/train.py
    DELETED
    
    | 
         @@ -1,131 +0,0 @@ 
-import argparse
-import os
-
-import torch
-import yaml
-from tqdm import tqdm
-from transformers import get_cosine_schedule_with_warmup
-
-# from model_nested import NerFilteredSemiCRF
-from model import GLiNER
-from modules.run_evaluation import get_for_all_path, sample_train_data
-from save_load import save_model, load_model
-import json
-
-
-# train function
-def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
-          train_batch_size=8, device='cuda'):
-    model.train()
-
-    # initialize data loaders
-    train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)
-
-    pbar = tqdm(range(num_steps))
-
-    if warmup_ratio < 1:
-        num_warmup_steps = int(num_steps * warmup_ratio)
-    else:
-        num_warmup_steps = int(warmup_ratio)
-
-    scheduler = get_cosine_schedule_with_warmup(
-        optimizer,
-        num_warmup_steps=num_warmup_steps,
-        num_training_steps=num_steps
-    )
-
-    iter_train_loader = iter(train_loader)
-
-    for step in pbar:
-        try:
-            x = next(iter_train_loader)
-        except StopIteration:
-            iter_train_loader = iter(train_loader)
-            x = next(iter_train_loader)
-
-        for k, v in x.items():
-            if isinstance(v, torch.Tensor):
-                x[k] = v.to(device)
-
-        try:
-            loss = model(x)  # Forward pass
-        except:
-            continue
-
-        # check if loss is nan
-        if torch.isnan(loss):
-            continue
-
-        loss.backward()  # Compute gradients
-        optimizer.step()  # Update parameters
-        scheduler.step()  # Update learning rate schedule
-        optimizer.zero_grad()  # Reset gradients
-
-        description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
-
-        if (step + 1) % eval_every == 0:
-            current_path = os.path.join(log_dir, f'model_{step + 1}')
-            save_model(model, current_path)
-            #val_data_dir =  "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
-            #get_for_all_path(model, step, log_dir, val_data_dir)  # you can remove this comment if you want to evaluate the model
-
-            model.train()
-
-        pbar.set_description(description)
-
-
-def create_parser():
-    parser = argparse.ArgumentParser(description="Span-based NER")
-    parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
-    parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
-    return parser
-
-
-def load_config_as_namespace(config_file):
-    with open(config_file, 'r') as f:
-        config_dict = yaml.safe_load(f)
-    return argparse.Namespace(**config_dict)
-
-
-if __name__ == "__main__":
-    # parse args
-    parser = create_parser()
-    args = parser.parse_args()
-
-    # load config
-    config = load_config_as_namespace(args.config)
-
-    config.log_dir = args.log_dir
-
-    try:
-        with open(config.train_data, 'r') as f:
-            data = json.load(f)
-    except:
-        data = sample_train_data(config.train_data, 10000)
-
-    if config.prev_path != "none":
-        model = load_model(config.prev_path)
-        model.config = config
-    else:
-        model = GLiNER(config)
-
-    if torch.cuda.is_available():
-        model = model.cuda()
-
-    lr_encoder = float(config.lr_encoder)
-    lr_others = float(config.lr_others)
-
-    optimizer = torch.optim.AdamW([
-        # encoder
-        {'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
-        {'params': model.rnn.parameters(), 'lr': lr_others},
-        # projection layers
-        {'params': model.span_rep_layer.parameters(), 'lr': lr_others},
-        {'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
-    ])
-
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
-          log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
-          device=device)
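One detail of the removed `train()` that is easy to misread: `warmup_ratio` is interpreted as a fraction of `num_steps` when it is below 1, and as an absolute number of warmup steps otherwise. A self-contained illustration with made-up values:

```python
# Reproduces only the warmup_ratio branch of the removed train(); the numbers are made up.
num_steps = 30000

for warmup_ratio in (0.1, 1000):
    if warmup_ratio < 1:
        num_warmup_steps = int(num_steps * warmup_ratio)  # fraction of total steps -> 3000
    else:
        num_warmup_steps = int(warmup_ratio)              # absolute step count -> 1000
    print(f"warmup_ratio={warmup_ratio!r} -> num_warmup_steps={num_warmup_steps}")
```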
    	
        app.py
    CHANGED
    
@@ -1,8 +1,5 @@
 from typing import Dict, Union
-import sys
-
-sys.path.extend(["./GLiNER"])
-from GLiNER.model import GLiNER
+from gliner import GLiNER
 import gradio as gr
 
 model = GLiNER.from_pretrained("urchade/gliner_base")
@@ -110,20 +107,18 @@ with gr.Blocks(title="GLiNER-base") as demo:
         gr.Markdown(
             """
             ## Installation
-            To use this model, you must download the GLiNER repository and install its dependencies:
+            To use this model, you must install the GLiNER Python library:
             ```
-            !git clone https://github.com/urchade/GLiNER.git
-            %cd GLiNER
-            !pip install -r requirements.txt
+            !pip install gliner
             ```
 
             ## Usage
-            Once you've downloaded the GLiNER repository, you can import the GLiNER class from the `model` file. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
+            Once you've downloaded the GLiNER library, you can import the GLiNER class. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
             """
         )
         gr.Code(
             '''
-from model import GLiNER
+from gliner import GLiNER
 
 model = GLiNER.from_pretrained("urchade/gliner_base")
 
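After this commit the Space depends only on the `gliner` package from PyPI, as the updated Markdown and `gr.Code` snippet show. A minimal end-to-end sketch of that usage; the example sentence and label set are illustrative and not part of the commit:

```python
# pip install gliner
from gliner import GLiNER

# The same checkpoint the Space loads.
model = GLiNER.from_pretrained("urchade/gliner_base")

# Illustrative input; any text and entity labels can be supplied.
text = "Marie Curie received the Nobel Prize in Physics in 1903."
labels = ["person", "award", "date"]

entities = model.predict_entities(text, labels)
for entity in entities:
    print(entity["text"], "=>", entity["label"])
```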
    	
        requirements.txt
    CHANGED
    
@@ -1,6 +1 @@
-
-transformers
-huggingface_hub
-flair
-seqeval
-tqdm
+gliner