# AIBugHunter / statement_t5.py
import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Joint head for statement-level and function-level classification."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.1)
        # One logit per statement.
        self.out_proj = nn.Linear(hidden_dim, 1)
        # GRU pools the statement sequence into one function-level vector.
        self.rnn_pool = nn.GRU(input_size=hidden_dim,
                               hidden_size=hidden_dim,
                               num_layers=1,
                               batch_first=True)
        self.func_dense = nn.Linear(hidden_dim, hidden_dim)
        # Two-way function-level prediction (vulnerable / not vulnerable).
        self.func_out_proj = nn.Linear(hidden_dim, 2)
    def forward(self, hidden):
        # Statement-level head; hidden is (batch, num_statements, hidden_dim).
        x = self.dropout(hidden)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)

        # Function-level head: the GRU's final hidden state summarizes the function.
        out, func_x = self.rnn_pool(hidden)
        func_x = func_x.squeeze(0)  # (1, batch, hidden_dim) -> (batch, hidden_dim)
        func_x = self.dropout(func_x)
        func_x = self.func_dense(func_x)
        func_x = torch.tanh(func_x)
        func_x = self.dropout(func_x)
        func_x = self.func_out_proj(func_x)
        # (batch, num_statements) statement logits, (batch, 2) function logits.
        return x.squeeze(-1), func_x
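
# A minimal shape check for the head, run only when this file is executed
# directly. Assumptions (not from the original file): hidden_dim=768 and a
# batch of 2 functions with 155 statements each, matching max_num_statement
# in StatementT5 below.
if __name__ == "__main__":
    _head = ClassificationHead(hidden_dim=768)
    _rep = torch.zeros(2, 155, 768)  # (batch, num_statements, hidden_dim)
    _stmt_logits, _func_logits = _head(_rep)
    assert _stmt_logits.shape == (2, 155)  # one logit per statement
    assert _func_logits.shape == (2, 2)    # two-way function-level logits
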
class StatementT5(nn.Module):
    def __init__(self, t5, tokenizer, device, hidden_dim=768):
        super().__init__()
        self.max_num_statement = 155
        # Reuse T5's shared token-embedding table; `t5` is expected to be an
        # encoder-only model (its forward takes inputs_embeds/attention_mask).
        self.word_embedding = t5.shared
        # GRU compresses each statement's token embeddings into a single vector.
        self.rnn_statement_embedding = nn.GRU(input_size=hidden_dim,
                                              hidden_size=hidden_dim,
                                              num_layers=1,
                                              batch_first=True)
        self.t5 = t5
        self.tokenizer = tokenizer
        self.device = device
        # CLS head
        self.classifier = ClassificationHead(hidden_dim=hidden_dim)
    def _embed_statements(self, input_ids):
        """Collapse (batch, num_statements, num_tokens) token ids into one embedding per statement."""
        embed = self.word_embedding(input_ids)  # (batch, num_statements, num_tokens, hidden_dim)
        inputs_embeds = torch.empty(embed.shape[0], embed.shape[1], embed.shape[3], device=self.device)
        for i in range(len(embed)):
            statement_of_tokens = embed[i]  # (num_statements, num_tokens, hidden_dim)
            # Keep the GRU's final hidden state, shape (1, num_statements, hidden_dim).
            out, statement_embed = self.rnn_statement_embedding(statement_of_tokens)
            inputs_embeds[i, :, :] = statement_embed
        return inputs_embeds[:, :self.max_num_statement, :]

    def forward(self, input_ids, statement_mask, labels=None, func_labels=None):
        statement_mask = statement_mask[:, :self.max_num_statement]
        inputs_embeds = self._embed_statements(input_ids)
        rep = self.t5(inputs_embeds=inputs_embeds, attention_mask=statement_mask).last_hidden_state
        logits, func_logits = self.classifier(rep)
        if self.training:
            loss_fct = nn.CrossEntropyLoss()
            statement_loss = loss_fct(logits, labels)
            func_loss = loss_fct(func_logits, func_labels)
            return statement_loss, func_loss
        probs = torch.sigmoid(logits)
        func_probs = torch.softmax(func_logits, dim=-1)
        return probs, func_probs
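
# A minimal usage sketch, not the repo's training script. Assumptions (not in
# the original file): the backbone is an encoder-only T5 from Hugging Face
# transformers, the checkpoint name "Salesforce/codet5-base" is a guess, and
# input_ids arrive already split per statement and padded to a fixed
# (num_statements, num_tokens) grid.
if __name__ == "__main__":
    from transformers import RobertaTokenizer, T5EncoderModel

    tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-base")
    t5 = T5EncoderModel.from_pretrained("Salesforce/codet5-base")
    device = torch.device("cpu")
    model = StatementT5(t5, tokenizer, device).to(device).eval()

    batch, num_statements, num_tokens = 1, 155, 20
    input_ids = torch.full((batch, num_statements, num_tokens),
                           tokenizer.pad_token_id, dtype=torch.long)
    statement_mask = torch.ones(batch, num_statements, dtype=torch.long)

    # eval() puts the model on the inference path, so it returns probabilities.
    with torch.no_grad():
        probs, func_probs = model(input_ids, statement_mask)
    print(probs.shape, func_probs.shape)  # torch.Size([1, 155]) torch.Size([1, 2])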