"""This lobe enables the integration of huggingface pretrained Llama2 Model model plus the expanding embedding layer for additional PAD tokens .
Transformer from HuggingFace needs to be installed:
https://huggingface.co/transformers/installation.html
Authors
* Pooneh Mousavi 2023
"""
import logging

import torch

from speechbrain.lobes.models.huggingface_transformers.llama2 import LLAMA2

logger = logging.getLogger(__name__)
class LLAMA2_expanded(LLAMA2):
"""This lobe enables the integration of HuggingFace pretrained LLAMA2 model.
Source paper LLAMA2:
https://arxiv.org/abs/2307.09288
Transformer from HuggingFace needs to be installed:
https://huggingface.co/transformers/installation.html
    The model can be fine-tuned. It will automatically download the model from
    HuggingFace or use a local path.
Arguments
---------
source : str
HuggingFace hub name: e.g "meta-llama/Llama-2-7b-chat-hf"
save_path : str
Path (dir) of the downloaded model.
freeze : bool (default: False)
If True, the model is frozen. If False, the model will be trained
        alongside the rest of the pipeline.
Example
-------
>>> model_hub = "meta-llama/Llama-2-7b-chat-hf"
>>> save_path = "savedir"
    >>> model = LLAMA2_expanded(model_hub, save_path)
>>> tokens = torch.tensor([[1, 1]])
>>> attention_mask = torch.tensor([[1, 1]])
>>> outputs = model(tokens, attention_mask)
"""
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Special tokens (e.g. a dedicated PAD token) can be registered with
        # the tokenizer and the model embedding resized accordingly, e.g.:
        # self.add_special_tokens_({"pad_token": "<pad>"})
    def add_special_tokens_(self, attr_to_special_token) -> None:
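        """Add special tokens to the tokenizer and resize the model embedding
        matrix when new tokens were actually added.

        Arguments
        ---------
        attr_to_special_token : dict
            Mapping from special-token attribute names to token strings,
            e.g. {"pad_token": "<pad>"}.
        """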
orig_num_tokens = len(self.tokenizer)
num_added_tokens = self.tokenizer.add_special_tokens(
attr_to_special_token # type: ignore
) # doesn't add if they are already there
if num_added_tokens > 0:
self.model.resize_token_embeddings(
new_num_tokens=orig_num_tokens + num_added_tokens
)
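

# A minimal usage sketch (an addition, not part of the original module): it
# assumes access to the gated "meta-llama/Llama-2-7b-chat-hf" checkpoint and
# enough memory to load it. It mirrors the class docstring example, registering
# a PAD token before running a forward pass.
if __name__ == "__main__":
    model_hub = "meta-llama/Llama-2-7b-chat-hf"
    save_path = "savedir"
    model = LLAMA2_expanded(model_hub, save_path)
    # Add a PAD token; the embedding matrix is resized only if the token is new.
    model.add_special_tokens_({"pad_token": "<pad>"})
    tokens = torch.tensor([[1, 1]])
    attention_mask = torch.tensor([[1, 1]])
    outputs = model(tokens, attention_mask)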