guillermoruiz commited on
Commit
86ff12e
1 Parent(s): 9ce54df

Upload TFBilma

Browse files
Files changed (2) hide show
  1. modeling_bilma.py +6 -3
  2. tf_model.h5 +1 -1
modeling_bilma.py CHANGED
@@ -9,7 +9,7 @@ from typing import Dict
9
  import re
10
  import unicodedata
11
 
12
- from .configuration_bilma import BilmaConfig
13
 
14
  # copied from preprocessing.py
15
  BLANK = ' '
@@ -37,6 +37,7 @@ class TFBilma(TFPreTrainedModel):
37
 
38
  def __init__(self, config):
39
  self.seq_max_length = config.seq_max_length
 
40
  super().__init__(config)
41
  #if config.weights == "spanish":
42
  # my_resources = importlib_resources.files("hf_bilma")
@@ -76,8 +77,10 @@ class TFBilma(TFPreTrainedModel):
76
  #if isinstance(tensor, dict) and len(tensor) == 0:
77
  # return self.model(self.dummy_inputs)
78
  ins = tf.cast(inputs["input_ids"], tf.float32)
79
-
80
- output = {"logits":self.model(ins)}
 
 
81
  return output
82
 
83
 
 
9
  import re
10
  import unicodedata
11
 
12
+ from configuration_bilma import BilmaConfig
13
 
14
  # copied from preprocessing.py
15
  BLANK = ' '
 
37
 
38
  def __init__(self, config):
39
  self.seq_max_length = config.seq_max_length
40
+ self.include_top = config.include_top
41
  super().__init__(config)
42
  #if config.weights == "spanish":
43
  # my_resources = importlib_resources.files("hf_bilma")
 
77
  #if isinstance(tensor, dict) and len(tensor) == 0:
78
  # return self.model(self.dummy_inputs)
79
  ins = tf.cast(inputs["input_ids"], tf.float32)
80
+ if self.include_top:
81
+ output = {"logits":self.model(ins)}
82
+ else:
83
+ output = {"last_hidden_state":self.model(ins)}
84
  return output
85
 
86
 
tf_model.h5 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6d31e357973be9bf86a3676237280b3ffe852ac994efd62d6eb67e06e36cd039
3
  size 156564220
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75330683b2e51a65402cdd6b87de8d51b817f5924bfa2e8ce2c085d15b3b841b
3
  size 156564220