| import logging |
| import os |
| import torch |
| from torchvision import transforms |
| import numpy as np |
| import random |
| import cv2 |
| from PIL import Image |
|
|
|
|
def path_to_image(path, size=(1024, 1024), color_type=["rgb", "gray"][0]):
    """Load an image file with OpenCV and return it as a PIL image.

    Args:
        path: Location of the image file, handed to ``cv2.imread``.
        size: Target size forwarded unchanged to ``cv2.resize``; any falsy
            value (e.g. ``None``) skips the resize step.
        color_type: ``"rgb"`` or ``"gray"`` (case-insensitive).

    Returns:
        A PIL image in mode ``"RGB"`` or ``"L"``. For any other
        ``color_type`` a hint is printed and ``None`` is returned.
    """
    mode = color_type.lower()
    if mode not in ("rgb", "gray"):
        print("Select the color_type to return, either to RGB or gray image.")
        return None

    want_rgb = mode == "rgb"
    if want_rgb:
        pixels = cv2.imread(path)
    else:
        pixels = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # NOTE(review): the result of imread is not validated here; an
    # unreadable path presumably fails inside the calls below -- confirm.

    if size:
        pixels = cv2.resize(pixels, size, interpolation=cv2.INTER_LINEAR)

    if want_rgb:
        # Reorder channels from BGR to RGB before handing the array to PIL.
        rgb_pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb_pixels).convert("RGB")
    return Image.fromarray(pixels).convert("L")
|
|
|
|
def check_state_dict(state_dict, unwanted_prefixes=("_orig_mod.", "module.")):
    """Strip known wrapper prefixes from the keys of a state dict, in place.

    Each key is compared against ``unwanted_prefixes`` in order; the first
    prefix that matches is removed and the remaining prefixes are not tried
    for that key, so at most one prefix is stripped per key.

    Args:
        state_dict: Mapping to rewrite. It is mutated and also returned.
            Renamed entries are re-inserted, so they move to the end of the
            mapping's iteration order.
        unwanted_prefixes: Iterable of string prefixes to remove. The default
            is an immutable tuple so it cannot be altered between calls.

    Returns:
        The same ``state_dict`` object with the renamed keys.
    """
    # Iterate over a snapshot of the keys: the mapping is modified in the loop.
    for key in list(state_dict):
        for prefix in unwanted_prefixes:
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
                break
    return state_dict
|
|
|
|
def generate_smoothed_gt(gts, epsilon=0.001):
    """Apply label smoothing to ground-truth values.

    Computes ``(1 - epsilon) * gts + epsilon / 2``, so an input of 0 becomes
    ``epsilon / 2`` and an input of 1 becomes ``1 - epsilon / 2``.

    Args:
        gts: Ground-truth values; anything supporting scalar multiplication
            and addition. Presumably values lie in [0, 1] -- confirm with
            callers.
        epsilon: Smoothing strength. Defaults to 0.001, the value that was
            previously hard-coded.

    Returns:
        The smoothed values, of whatever type the arithmetic on ``gts`` yields.
    """
    return (1 - epsilon) * gts + epsilon / 2
|
|
|
|
class Logger:
    """Small wrapper that logs INFO messages to a file and to the console.

    Every instance uses the process-wide ``logging`` logger named
    ``"BiRefNet"`` and attaches its own file handler and stream handler to it.
    """

    def __init__(self, path="log.txt"):
        """Create the handlers and attach them to the shared logger.

        Args:
            path: Log file location; opened in ``"w"`` mode, so an existing
                file is truncated.
        """
        self.logger = logging.getLogger("BiRefNet")
        self.file_handler = logging.FileHandler(path, "w")
        self.stdout_handler = logging.StreamHandler()
        # File handler is attached first, then the console handler.
        for handler in (self.file_handler, self.stdout_handler):
            handler.setFormatter(
                logging.Formatter("%(asctime)s %(levelname)s %(message)s")
            )
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        # Keep records from also reaching handlers of ancestor loggers.
        self.logger.propagate = False

    def info(self, txt):
        """Log ``txt`` at INFO level."""
        self.logger.info(txt)

    def close(self):
        """Detach this instance's handlers from the shared logger and close them.

        Detaching matters because the underlying logger is shared by name:
        handlers left attached after closing would still receive every record
        logged by later ``Logger`` instances.
        """
        for handler in (self.file_handler, self.stdout_handler):
            self.logger.removeHandler(handler)
            handler.close()
|
|
|
|
class AverageMeter:
    """Track the most recent value and the running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0.0  # last value passed to update()
        self.avg = 0.0  # sum / count
        self.sum = 0.0  # weighted sum of all values seen
        self.count = 0.0  # total weight seen

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average.

        Args:
            val: New value.
            n: Weight of the value; defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update has n=0) would otherwise
        # raise ZeroDivisionError; report an average of 0.0 in that case.
        self.avg = self.sum / self.count if self.count else 0.0
|
|
|
|
def save_checkpoint(state, path, filename="latest.pth"):
    """Serialize ``state`` with ``torch.save`` into ``path/filename``.

    Args:
        state: Object to serialize.
        path: Directory part of the destination.
        filename: File name within ``path``; defaults to ``"latest.pth"``.
    """
    destination = os.path.join(path, filename)
    torch.save(state, destination)
|
|
|
|
def save_tensor_img(tenor_im, path):
    """Convert a tensor to a PIL image and write it to ``path``.

    The tensor is moved to the CPU and cloned, then ``squeeze(0)`` is applied
    (presumably to drop a leading batch dimension of size 1 -- confirm with
    callers) before the conversion with torchvision's ``ToPILImage``.

    Args:
        tenor_im: Image tensor to save.
        path: Destination passed to the PIL image's ``save`` method.
    """
    to_pil = transforms.ToPILImage()
    # Work on a CPU copy so the caller's tensor is left untouched.
    cpu_copy = tenor_im.cpu().clone().squeeze(0)
    to_pil(cpu_copy).save(path)
|
|
|
|
def set_seed(seed):
    """Seed the PyTorch, NumPy and ``random`` generators with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.

    Args:
        seed: Seed value handed to every generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
|
|