#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 5 10:01:39 2023

@author: peter
"""
from collections.abc import Mapping

import transformers

import qarac.models.layers.GlobalAttentionPoolingHead
class QaracEncoderModel(transformers.PreTrainedModel):
    """Encoder that maps tokenised text to a single document vector.

    Wraps a pretrained RoBERTa encoder and pools its final hidden
    states with a global attention pooling head.
    """

    def __init__(self, path):
        """
        Create the encoder model.

        Parameters
        ----------
        path : str
            Path or model id of the pretrained RoBERTa checkpoint to
            load (also used to load the matching config).

        Returns
        -------
        None.
        """
        config = transformers.PretrainedConfig.from_pretrained(path)
        super(QaracEncoderModel, self).__init__(config)
        self.encoder = transformers.RobertaModel.from_pretrained(path)
        self.head = qarac.models.layers.GlobalAttentionPoolingHead.GlobalAttentionPoolingHead(config)

    def forward(self, input_ids,
                attention_mask=None):
        """
        Vectorize a tokenised text.

        Parameters
        ----------
        input_ids : torch.Tensor or Mapping
            Tokenised text to encode. May also be a mapping (e.g. a
            tokenizer's ``BatchEncoding``) holding ``'input_ids'`` and
            optionally ``'attention_mask'``.
        attention_mask : torch.Tensor, optional
            Attention mask. When ``input_ids`` is a mapping and no
            explicit mask is given, the mask is taken from the mapping
            if present.

        Returns
        -------
        torch.Tensor
            Vector representing the document.
        """
        # Accept a tokenizer dict/BatchEncoding as the sole argument.
        # Guard on the mapping type: the previous membership test
        # (`'attention_mask' in input_ids`) raised on plain tensors,
        # and a dict without that key (or with an explicit mask
        # argument) was forwarded to the encoder unmodified.
        if isinstance(input_ids, Mapping):
            if attention_mask is None:
                attention_mask = input_ids.get('attention_mask')
            input_ids = input_ids['input_ids']
        return self.head(self.encoder(input_ids,
                                      attention_mask).last_hidden_state,
                         attention_mask)