Shaltiel commited on
Commit
1969465
โ€ข
1 Parent(s): 88c6751

Upload BertForJointParsing.py

Browse files
Files changed (1) hide show
  1. BertForJointParsing.py +20 -26
BertForJointParsing.py CHANGED
@@ -1,5 +1,5 @@
1
  from dataclasses import dataclass
2
- import math
3
  from operator import itemgetter
4
  import torch
5
  from torch import nn
@@ -187,25 +187,6 @@ class BertForJointParsing(BertPreTrainedModel):
187
  )
188
 
189
  def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
190
- """
191
- Predicts various linguistic features using the DictaBERT model.
192
-
193
- This function takes a sentence or a list of sentences in Hebrew and applies the BERT model to predict multiple linguistic attributes simultaneously. These include syntax, named entity recognition (NER), morphological analysis, lexical information, and text segmentation.
194
-
195
- Parameters:
196
- sentences (Union[str, List[str]]): A single sentence or a list of sentences in Hebrew.
197
- tokenizer (BertTokenizerFast): The tokenizer used for preprocessing the input sentences.
198
- padding (str, optional): The strategy for padding sentences. Defaults to 'longest'.
199
- truncation (bool, optional): Flag to enable or disable truncation. Defaults to True.
200
- compute_syntax_mst (bool, optional): If True, computes the maximum spanning tree for syntax prediction. Defaults to True.
201
- per_token_ner (bool, optional): If True, performs NER for each token. Defaults to False.
202
- output_style (Literal['json', 'ud', 'iahlt_ud'], optional): The format of the output. Choices are 'json', 'ud' (Universal Dependencies), or 'iahlt_ud' (UD in the style of IAHLT). Defaults to 'json'.
203
-
204
- Returns:
205
- Depending on the output_style chosen, returns the linguistic analysis in the specified format.
206
-
207
- The function is integral for comprehensive linguistic analysis in applications involving Hebrew text, catering to a variety of NLP tasks.
208
- """
209
  is_single_sentence = isinstance(sentences, str)
210
  if is_single_sentence:
211
  sentences = [sentences]
@@ -315,11 +296,10 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
315
  def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
316
  input_ids = inputs['input_ids']
317
 
318
- predictions = torch.argmax(logits, dim=-1)
319
  batch_ret = []
320
  for batch_idx in range(len(sentences)):
321
- ret = []
322
- batch_ret.append(ret)
323
  for tok_idx in range(input_ids.shape[1]):
324
  token_id = input_ids[batch_idx, tok_idx]
325
  # ignore cls, sep, pad
@@ -328,9 +308,23 @@ def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
328
  token = tokenizer._convert_id_to_token(token_id)
329
  # wordpieces should just be appended to the previous word
330
  if token.startswith('##'):
331
- ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
332
  continue
333
- ret.append((token, tokenizer._convert_id_to_token(predictions[batch_idx, tok_idx])))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  return batch_ret
335
 
336
  ud_prefixes_to_pos = {
@@ -437,7 +431,7 @@ def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
437
  suf_feats = word['morph']['suffix_feats']
438
  suf = ud_suffix_to_htb_str.get(f"Gender={suf_feats.get('Gender', 'Fem,Masc')}|Number={suf_feats.get('Number', 'Sing')}|Person={suf_feats.get('Person', '3')}", "_ื”ื•ื")
439
  # for HTB, if the function is poss, then add a shel pointing to the next word
440
- if func == 'nmod:poss':
441
  intermediate_output.append(dict(word='_ืฉืœ_', lex='ืฉืœ', pos='ADP', dep=len(intermediate_output) + 2, func='case', feats='_', absolute_dep=True))
442
  # add the main suffix in
443
  intermediate_output.append(dict(word=suf, lex='ื”ื•ื', pos='PRON', dep=dep, func=func, feats='|'.join(f'{k}={v}' for k,v in word['morph']['suffix_feats'].items())))
 
1
  from dataclasses import dataclass
2
+ import re
3
  from operator import itemgetter
4
  import torch
5
  from torch import nn
 
187
  )
188
 
189
  def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  is_single_sentence = isinstance(sentences, str)
191
  if is_single_sentence:
192
  sentences = [sentences]
 
296
  def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
297
  input_ids = inputs['input_ids']
298
 
299
+ predictions = torch.argsort(logits, dim=-1, descending=True)[..., :3]
300
  batch_ret = []
301
  for batch_idx in range(len(sentences)):
302
+ intermediate_ret = []
 
303
  for tok_idx in range(input_ids.shape[1]):
304
  token_id = input_ids[batch_idx, tok_idx]
305
  # ignore cls, sep, pad
 
308
  token = tokenizer._convert_id_to_token(token_id)
309
  # wordpieces should just be appended to the previous word
310
  if token.startswith('##'):
311
+ intermediate_ret[-1] = (intermediate_ret[-1][0] + token[2:], intermediate_ret[-1][1])
312
  continue
313
+ intermediate_ret.append((token, tokenizer.convert_ids_to_tokens(predictions[batch_idx, tok_idx])))
314
+
315
+ # build the final output taking into account valid letters
316
+ ret = []
317
+ batch_ret.append(ret)
318
+ for (token, lexemes) in intermediate_ret:
319
+ # must overlap on at least 2 non ืื”ื•ื™ letters
320
+ possible_lets = set(c for c in token if c not in 'ืื”ื•ื™')
321
+ final_lex = '[BLANK]'
322
+ for lex in lexemes:
323
+ if sum(c in possible_lets for c in lex) >= min([2, len(possible_lets), len([c for c in lex if c not in 'ืื”ื•ื™'])]):
324
+ final_lex = lex
325
+ break
326
+ ret.append((token, final_lex))
327
+
328
  return batch_ret
329
 
330
  ud_prefixes_to_pos = {
 
431
  suf_feats = word['morph']['suffix_feats']
432
  suf = ud_suffix_to_htb_str.get(f"Gender={suf_feats.get('Gender', 'Fem,Masc')}|Number={suf_feats.get('Number', 'Sing')}|Person={suf_feats.get('Person', '3')}", "_ื”ื•ื")
433
  # for HTB, if the function is poss, then add a shel pointing to the next word
434
+ if func == 'nmod:poss' and s_lex != 'ืฉืœ':
435
  intermediate_output.append(dict(word='_ืฉืœ_', lex='ืฉืœ', pos='ADP', dep=len(intermediate_output) + 2, func='case', feats='_', absolute_dep=True))
436
  # add the main suffix in
437
  intermediate_output.append(dict(word=suf, lex='ื”ื•ื', pos='PRON', dep=dep, func=func, feats='|'.join(f'{k}={v}' for k,v in word['morph']['suffix_feats'].items())))