import logging
import math
import os
from contextlib import contextmanager

import timm.models.hub as timm_hub
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import StoppingCriteria, StoppingCriteriaList


def is_dist_avail_and_initialized():
    """Return True if torch.distributed is available and the process group is initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    """Return the rank of the current process, or 0 when not running distributed."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists,
    it is not downloaded again. If distributed, only the main process downloads
    the file, and the other processes wait for the download to finish.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    A context manager that prevents any logging messages triggered during the
    body from being processed.

    :param highest_level: the maximum logging level in use. This only needs to
        be changed if a custom level greater than CRITICAL is defined.
    """
    # two kind-of hacks here:
    # * can't get the highest logging level in effect => delegate to the user
    # * can't get the current module-level override => use an undocumented
    #   (but non-private!) interface
    previous_level = logging.root.manager.disable

    logging.disable(highest_level)

    try:
        yield
    finally:
        logging.disable(previous_level)


class LoRALinear(nn.Linear):
    """nn.Linear augmented with a trainable low-rank (LoRA) update.

    The base projection is computed by nn.Linear; a rank-`lora_r` bottleneck
    (lora_A followed by lora_B) is added on top, scaled by lora_alpha / lora_r.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        **kwargs
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.lora_scaling = self.lora_alpha / self.lora_r

        self.lora_A = nn.Linear(
            in_features, self.lora_r, bias=False, device=device, dtype=dtype
        )
        self.lora_B = nn.Linear(
            self.lora_r, out_features, bias=False, device=device, dtype=dtype
        )

        self.reset_parameters()

    def reset_parameters(self):
        # initialize the base nn.Linear weight/bias as well; without this call
        # the base parameters would be left uninitialized (torch.empty)
        super().reset_parameters()
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        orig_type = x.dtype
        res = super().forward(x)
        # compute the low-rank update in float32, then cast it back to the
        # dtype of the base projection output before adding
        x = x.float()
        lora_update = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.lora_scaling
        res = res + lora_update.to(res.dtype)
        return res.to(orig_type)


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any of the given stop token sequences ends `input_ids`.

    Each entry of `stops` should be a 1-D LongTensor of token ids on the same
    device as the generated `input_ids`.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # `encounters` is accepted for API compatibility but is not used
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            if torch.all(stop == input_ids[:, -len(stop):]).item():
                return True
        return False
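

# Minimal usage sketch (illustrative only, not part of the original module).
# It exercises LoRALinear as a drop-in replacement for nn.Linear and shows how
# StoppingCriteriaSub would typically be passed to HuggingFace `generate` via
# StoppingCriteriaList. The tensor sizes and the stop token id below are
# arbitrary placeholders, and `model` / `input_ids` in the final comment are
# hypothetical names for a causal LM and its tokenized prompt.
if __name__ == "__main__":
    # LoRALinear behaves like nn.Linear but adds a trainable low-rank update.
    layer = LoRALinear(768, 768, lora_r=8, lora_alpha=16, lora_dropout=0.05)
    x = torch.randn(2, 16, 768)
    print(layer(x).shape)  # torch.Size([2, 16, 768])

    # StoppingCriteriaSub halts generation once a stop sequence ends input_ids.
    stop_ids = [torch.tensor([835])]  # hypothetical token ids for a stop string
    stopping = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_ids)])
    # A call like the following would then use it during generation:
    # model.generate(input_ids, stopping_criteria=stopping, max_new_tokens=64)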