|
|
|
|
|
""" PyTorch CodeT5+ matching models. |
|
The implementation is based on transformers.models.t5.modeling_t5: it subclasses T5ForConditionalGeneration and adds a projection layer (`proj`) and a two-way linear head (`itm_head`).
|
""" |
|
|
|
from typing import Optional, Tuple, Union |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from transformers import T5ForConditionalGeneration |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutput, |
|
) |
|
from configuration_codet5p_matching import CodeT5pMatchingConfig |
|
|
|
|
|
class CodeT5pMatchingModel(T5ForConditionalGeneration):
    """T5 encoder-decoder with two extra linear layers attached.

    In addition to every module inherited from ``T5ForConditionalGeneration``,
    the constructor registers:

    * ``proj``     -- ``nn.Linear(config.d_model, config.embed_dim)``
    * ``itm_head`` -- ``nn.Linear(config.d_model, 2)`` (two output logits)
    """

    # Configuration type associated with this model class.
    config_class = CodeT5pMatchingConfig

    # Checkpoint keys that are allowed to be absent when loading weights.
    # NOTE(review): ``authorized_missing_keys`` is the attribute name used by
    # transformers < 4.0; later releases read ``_keys_to_ignore_on_load_missing``
    # instead. Confirm which transformers version this file targets.
    authorized_missing_keys = [
        r"encoder.embed_tokens.weight",
    ]

    def __init__(self, config: CodeT5pMatchingConfig):
        """Build the T5 backbone, then attach the projection and ITM layers.

        Args:
            config: model configuration; ``d_model`` sizes the inputs of both
                new layers and ``embed_dim`` sizes the output of ``proj``.
        """
        # Let the parent construct all of the standard T5 modules first.
        super().__init__(config)

        hidden_size = config.d_model

        # Linear map from d_model-dimensional inputs to embed_dim outputs.
        self.proj = nn.Linear(in_features=hidden_size, out_features=config.embed_dim)
        # Linear map from d_model-dimensional inputs to 2 logits.
        self.itm_head = nn.Linear(in_features=hidden_size, out_features=2)
|
|