# sin-qac-model / modeling_query_completion.py
# Uploaded to the Hugging Face Hub via huggingface_hub (revision 924e4e0, by lv12).
"""
HuggingFace-compatible model definition for Query Auto-Completion.
"""
from typing import Optional, Tuple, Union

import torch
from transformers import PretrainedConfig, PreTrainedModel

from .model import QueryCompletionModel as BaseQueryCompletionModel
class QueryCompletionConfig(PretrainedConfig):
    """Configuration for the Query Auto-Completion model.

    Stores the hyperparameters consumed by :class:`QueryCompletionModelForHub`
    when it constructs the underlying model.

    Args:
        vocab_size: Size of the character/byte vocabulary.
        embed_dim: Dimensionality of the token embeddings.
        num_filters: Number of convolutional filters.
        filter_sizes: Convolution filter widths; defaults to ``[3, 4, 5]``
            when ``None`` (or empty) is given.
        num_heads: Number of attention heads.
        num_transformer_layers: Number of transformer encoder layers.
        use_pretrained_embeddings: Whether to initialize embeddings from a
            pretrained model.
        pretrained_model_name: Hub id of the pretrained embedding source.
        dropout: Dropout probability.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "query-completion"

    def __init__(
        self,
        vocab_size: int = 384,
        embed_dim: int = 256,
        num_filters: int = 64,
        # Fix: the default is None, so the annotation must be Optional[list],
        # not bare `list`. Mutable-default pitfall is avoided via the None
        # sentinel resolved below.
        filter_sizes: Optional[list] = None,
        num_heads: int = 4,
        num_transformer_layers: int = 2,
        use_pretrained_embeddings: bool = True,
        pretrained_model_name: str = "google/byt5-small",
        dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_filters = num_filters
        # `or` deliberately treats both None and an empty list as "use the
        # default" — preserved from the original behavior.
        self.filter_sizes = filter_sizes or [3, 4, 5]
        self.num_heads = num_heads
        self.num_transformer_layers = num_transformer_layers
        self.use_pretrained_embeddings = use_pretrained_embeddings
        self.pretrained_model_name = pretrained_model_name
        self.dropout = dropout
class QueryCompletionModelForHub(PreTrainedModel):
    """Hugging Face Hub wrapper around the Query Auto-Completion model.

    Adapts :class:`BaseQueryCompletionModel` to the ``transformers``
    ``PreTrainedModel`` interface so it can be loaded/saved via the Hub.
    """

    config_class = QueryCompletionConfig
    base_model_prefix = "query_completion"
    supports_gradient_checkpointing = False

    def __init__(self, config: QueryCompletionConfig):
        super().__init__(config)
        # Collect the settings forwarded to the wrapped model's constructor.
        # NOTE(review): config.filter_sizes and config.dropout are NOT
        # forwarded here — confirm whether BaseQueryCompletionModel should
        # receive them (its signature is defined in ``.model``).
        base_kwargs = {
            "vocab_size": config.vocab_size,
            "embed_dim": config.embed_dim,
            "num_filters": config.num_filters,
            "num_heads": config.num_heads,
            "num_transformer_layers": config.num_transformer_layers,
            "use_pretrained_embeddings": config.use_pretrained_embeddings,
            "pretrained_model_name": config.pretrained_model_name,
        }
        self.model = BaseQueryCompletionModel(**base_kwargs)
        # Standard transformers weight-init / final setup hook.
        self.post_init()

    def forward(
        self,
        prefix_ids: torch.Tensor,
        candidate_ids: torch.Tensor,
        return_dict: bool = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Delegate scoring of candidates against a prefix to the wrapped model.

        NOTE(review): ``return_dict`` is accepted for transformers API
        compatibility but is currently ignored — the wrapped model's raw
        output is returned unchanged.
        """
        output = self.model(prefix_ids, candidate_ids)
        return output