
MIRNet low-light image enhancement

MIRNet-based low-light image enhancer specialized in restoring dark images taken at events (concerts, parties, clubs, etc.).

Project source code and further documentation

Documentation on pre-training, fine-tuning, the model architecture and usage, along with all source code used for building the model and running inference, can be found in the project's GitHub repository.
This page currently hosts the PyTorch model weights and the model definition; a Hugging Face pipeline will be implemented in the future.

Using the model

To use the model, the model folder (available for download from this repository or from GitHub) must be present in your project folder.
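If you prefer to fetch only that folder from the Hub rather than cloning the whole repository, a minimal sketch using huggingface_hub's snapshot_download is shown below. It assumes the model definition lives under model/ at the root of this repository, which is what the import path model.MIRNet.model suggests; the helper name fetch_model_code is hypothetical.

```python
from pathlib import Path

from huggingface_hub import snapshot_download


def fetch_model_code(local_dir: str = ".") -> Path:
    """Hypothetical helper: download only the model/ folder of the Hub repo.

    Assumes the model definition is stored under model/ in the repository,
    as implied by `from model.MIRNet.model import MIRNet`.
    """
    snapshot_download(
        repo_id="dblasko/mirnet-low-light-img-enhancement",
        allow_patterns=["model/*"],  # skip the weights and all other files
        local_dir=local_dir,         # write real files here, not only to the cache
    )
    return Path(local_dir) / "model"
```

Calling fetch_model_code() from your project root should create ./model/, after which the import in the snippet below should resolve.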

Then, the following code downloads the model weights from the Hugging Face Hub and loads them into PyTorch for downstream use:

import torch
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download
from model.MIRNet.model import MIRNet  # requires the model folder in your project

device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)

# Download the model weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="dblasko/mirnet-low-light-img-enhancement", filename="mirnet_finetuned.pth"
)

# Load the model
model = MIRNet().to(device)
model.load_state_dict(torch.load(model_path, map_location=device)["model_state_dict"])

# Use the model, for example for inference on an image
model.eval()
with torch.no_grad():
    img = Image.open("image_path.png").convert("RGB")
    img_tensor = T.Compose(
        [
            T.Resize(400),  # Adjust image resizing depending on hardware
            T.ToTensor(),
            T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ]
    )(img).unsqueeze(0)
    img_tensor = img_tensor.to(device)

    # Crop height and width down to the nearest multiple of 8,
    # the input size the model expects
    if img_tensor.shape[2] % 8 != 0:
        img_tensor = img_tensor[:, :, : -(img_tensor.shape[2] % 8), :]
    if img_tensor.shape[3] % 8 != 0:
        img_tensor = img_tensor[:, :, :, : -(img_tensor.shape[3] % 8)]

    output = model(img_tensor)
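The cropping step in the snippet above discards up to 7 rows and columns of the image. If you would rather keep the full frame, one alternative (a swapped-in technique, not part of the original project, and not checked against MIRNet's behaviour near image borders) is to reflect-pad the input up to a multiple of 8, run the model, and crop the output back to the original size. The function name run_padded is hypothetical.

```python
import torch
import torch.nn.functional as F


def run_padded(model, img_tensor: torch.Tensor, multiple: int = 8) -> torch.Tensor:
    """Hypothetical helper: reflect-pad H and W to a multiple, run model, crop back."""
    _, _, h, w = img_tensor.shape
    # Number of rows/columns needed to reach the next multiple (0 if already aligned)
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions;
    # reflect mode requires each pad to be smaller than that dimension
    padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode="reflect")
    with torch.no_grad():
        output = model(padded)
    # Remove the padded rows/columns so the output matches the input size
    return output[:, :, :h, :w]
```

With this helper, the crop and the final model call above would be replaced by output = run_padded(model, img_tensor).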