#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 5 15:30:06 2023

@author: peter
"""

import torch
import qarac.models.QaracEncoderModel
import qarac.models.QaracDecoderModel
class QaracTrainerModel(torch.nn.Module):

    def __init__(self, base_model_path, tokenizer):
        """
        Sets up the Trainer model

        Parameters
        ----------
        base_model_path : str
            Path or name of the base RoBERTa model used to initialise the
            encoders and the decoder.
        tokenizer : transformers.RobertaTokenizer
            Tokenizer for the decoder

        Returns
        -------
        None.

        """
        super(QaracTrainerModel, self).__init__()
        self.question_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
        self.answer_encoder = qarac.models.QaracEncoderModel.QaracEncoderModel(base_model_path)
        # Reuse the answer encoder's configuration for the decoder, flagged as a decoder.
        config = self.answer_encoder.config
        config.is_decoder = True
        self.decoder = qarac.models.QaracDecoderModel.QaracDecoderModel(base_model_path,
                                                                        config,
                                                                        tokenizer)
        self.cosine = torch.nn.CosineSimilarity(dim=1, eps=1.0e-12)
    def forward(self,
                all_text,
                offset_text,
                question,
                answer,
                proposition0,
                proposition1,
                conclusion_offset,
                statement0,
                statement1):
        """
        Generates training objectives from data

        Parameters
        ----------
        all_text : torch.Tensor
            Tokenized text for the encode-decode objective
        offset_text : torch.Tensor
            As above, prefixed with <s>
        question : torch.Tensor
            Tokenized question for the question answering objective
        answer : torch.Tensor
            Tokenized answer for the question answering objective
        proposition0 : torch.Tensor
            Tokenized proposition for the reasoning objective
        proposition1 : torch.Tensor
            Tokenized proposition for the reasoning objective
        conclusion_offset : torch.Tensor
            Tokenized conclusion for the reasoning objective, prefixed with <s>
        statement0 : torch.Tensor
            Tokenized statement for the consistency objective
        statement1 : torch.Tensor
            Tokenized statement for the consistency objective

        Returns
        -------
        encode_decode : transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text for the encode-decode task
        question_answering : torch.Tensor
            Difference between the encoded question and the encoded answer
        reasoning : transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
            Predicted text for the reasoning objective
        consistency : torch.Tensor
            Cosine similarity of the vectorized statements

        """
        # Encode-decode objective: decode the answer encoder's vector back into the text.
        encode_decode = self.decoder((self.answer_encoder(all_text),
                                      offset_text))
        # Question answering objective: difference between the encoded question
        # and the encoded answer.
        question_answering = self.question_encoder(question) - self.answer_encoder(answer)
        # Reasoning objective: the sum of the encoded propositions is decoded
        # against the conclusion.
        reasoning = self.decoder((self.answer_encoder(proposition0)
                                  + self.answer_encoder(proposition1),
                                  conclusion_offset))
        # Consistency objective: cosine similarity of the encoded statements.
        s0 = self.answer_encoder(statement0)
        s1 = self.answer_encoder(statement1)
        consistency = self.cosine(s0, s1)
        return (encode_decode, question_answering, reasoning, consistency)
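
# Usage sketch (not part of the original training pipeline): shows how the
# combined model could be instantiated and run on a toy batch. The
# "roberta-base" checkpoint name and the single-sentence batch are illustrative
# assumptions; a real training run supplies its own base model path and
# tokenized batches for each objective.
if __name__ == "__main__":
    import transformers

    tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")
    trainer = QaracTrainerModel("roberta-base", tokenizer)
    # A single tokenized sentence stands in for every objective's input here.
    batch = tokenizer(["The cat sat on the mat."], return_tensors="pt")["input_ids"]
    encode_decode, question_answering, reasoning, consistency = trainer(
        batch, batch, batch, batch, batch, batch, batch, batch, batch)
    print(question_answering.shape, consistency)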