diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..3b9dd9ff1752c7cf62b5ac7b362c599001b0308d 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+image/low_haze_rain_00469_01_lq.png filter=lfs diff=lfs merge=lfs -text
+image/low_haze_snow_00337_01_lq.png filter=lfs diff=lfs merge=lfs -text
+img_file/abstract.jpg filter=lfs diff=lfs merge=lfs -text
+img_file/OneRestore_poster.png filter=lfs diff=lfs merge=lfs -text
+img_file/pipeline.jpg filter=lfs diff=lfs merge=lfs -text
+img_file/real.jpg filter=lfs diff=lfs merge=lfs -text
+output/low_haze_rain_00469_01_lq.png filter=lfs diff=lfs merge=lfs -text
+output/low_haze_snow_00337_01_lq.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a773941b82f878d333825567ca496768e8251730
--- /dev/null
+++ b/README.md
@@ -0,0 +1,298 @@
+
+
+
+
+
+ # [ECCV 2024] OneRestore: A Universal Restoration Framework for Composite Degradation
+
+
+
+
+[![ArXiv](https://img.shields.io/badge/OneRestore-ArXiv-red.svg)](https://arxiv.org/abs/2407.04621)
+[![Paper](https://img.shields.io/badge/OneRestore-Paper-purple.svg)](https://arxiv.org/abs/2407.04621)
+[![Web](https://img.shields.io/badge/OneRestore-Web-blue.svg)](https://gy65896.github.io/projects/ECCV2024_OneRestore/index.html)
+[![Poster](https://img.shields.io/badge/OneRestore-Poster-green.svg)](https://github.com/gy65896/OneRestore/blob/main/img_file/OneRestore_poster.png)
+[![Video](https://img.shields.io/badge/OneRestore-Video-orange.svg)](https://www.youtube.com/watch?v=AFr5tZdPlZ4)
+
+[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fgy65896%2FOneRestore&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)](https://hits.seeyoufarm.com)
+[![Hugging Face Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue)](https://huggingface.co/spaces/gy65896/OneRestore)
+[![Closed Issues](https://img.shields.io/github/issues-closed/gy65896/OneRestore)](https://github.com/gy65896/OneRestore/issues?q=is%3Aissue+is%3Aclosed)
+[![Open Issues](https://img.shields.io/github/issues/gy65896/OneRestore)](https://github.com/gy65896/OneRestore/issues)
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/onerestore-a-universal-restoration-framework/low-light-image-enhancement-on-lol)](https://paperswithcode.com/sota/low-light-image-enhancement-on-lol?p=onerestore-a-universal-restoration-framework)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/onerestore-a-universal-restoration-framework/image-dehazing-on-sots-outdoor)](https://paperswithcode.com/sota/image-dehazing-on-sots-outdoor?p=onerestore-a-universal-restoration-framework)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/onerestore-a-universal-restoration-framework/rain-removal-on-did-mdn)](https://paperswithcode.com/sota/rain-removal-on-did-mdn?p=onerestore-a-universal-restoration-framework)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/onerestore-a-universal-restoration-framework/snow-removal-on-snow100k)](https://paperswithcode.com/sota/snow-removal-on-snow100k?p=onerestore-a-universal-restoration-framework)
+
+
+
+
+
+
+---
+>**OneRestore: A Universal Restoration Framework for Composite Degradation**
+[Yu Guo](https://scholar.google.com/citations?user=klYz-acAAAAJ&hl=zh-CN)† , [Yuan Gao](https://scholar.google.com.hk/citations?user=4JpRnU4AAAAJ&hl=zh-CN)† , [Yuxu Lu](https://scholar.google.com.hk/citations?user=XXge2_0AAAAJ&hl=zh-CN), [Huilin Zhu](https://scholar.google.com.hk/citations?hl=zh-CN&user=fluPrxcAAAAJ), [Ryan Wen Liu](http://mipc.whut.edu.cn/index.html)* , [Shengfeng He](http://www.shengfenghe.com/)*
+(† Co-first Author, * Corresponding Author)
+>European Conference on Computer Vision
+
+> **Abstract:** *In real-world scenarios, image impairments often manifest as composite degradations, presenting a complex interplay of elements such as low light, haze, rain, and snow. Despite this reality, existing restoration methods typically target isolated degradation types, thereby falling short in environments where multiple degrading factors coexist. To bridge this gap, our study proposes a versatile imaging model that consolidates four physical corruption paradigms to accurately represent complex, composite degradation scenarios. In this context, we propose OneRestore, a novel transformer-based framework designed for adaptive, controllable scene restoration. The proposed framework leverages a unique cross-attention mechanism, merging degraded scene descriptors with image features, allowing for nuanced restoration. Our model allows versatile input scene descriptors, ranging from manual text embeddings to automatic extractions based on visual attributes. Our methodology is further enhanced through a composite degradation restoration loss, using extra degraded images as negative samples to fortify model constraints. Comparative results on synthetic and real-world datasets demonstrate OneRestore as a superior solution, significantly advancing the state-of-the-art in addressing complex, composite degradations.*
+---
+
+## News 🚀
+* **2024.09.07**: [Hugging Face Demo](https://huggingface.co/spaces/gy65896/OneRestore) is released.
+* **2024.09.05**: Video and poster are released.
+* **2024.09.04**: Code for data synthesis is released.
+* **2024.07.27**: Code for multiple GPUs training is released.
+* **2024.07.20**: [New Website](https://gy65896.github.io/projects/ECCV2024_OneRestore) has been created.
+* **2024.07.10**: [Paper](https://arxiv.org/abs/2407.04621) is released on ArXiv.
+* **2024.07.07**: Code and Dataset are released.
+* **2024.07.02**: OneRestore is accepted by [ECCV2024](https://eccv.ecva.net/).
+
+## Network Architecture
+
+
+
+
+
+
+## Quick Start
+
+### Install
+
+- python 3.7
+- cuda 11.7
+
+```
+# git clone this repository
+git clone https://github.com/gy65896/OneRestore.git
+cd OneRestore
+
+# create new anaconda env
+conda create -n onerestore python=3.7
+conda activate onerestore
+
+# download ckpts
+# put embedder_model.tar and onerestore_cdd-11.tar in the ckpts folder
+
+# install pytorch (taking cuda 11.7 and torch 1.13 as an example)
+pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
+
+# install other packages
+pip install -r requirements.txt
+pip install gensim
+```
+
+### Pretrained Models
+
+Please download our pre-trained models and put them in `./ckpts`.
+
+| Model | Description
+| :--- | :----------
+|[embedder_model.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpnhSQFIoD9msXWOA?e=aUpHOT) | Text/Visual Embedder trained on our CDD-11.
+|[onerestore_cdd-11.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpmWkGBku6oj33efg?e=7yUGfN) | OneRestore trained on our CDD-11.
+|[onerestore_real.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpi-iJOyN6OSYqiaA?e=QFfMeL) | OneRestore trained on our CDD-11 for Real Scenes.
+|[onerestore_lol.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpkSoVB1j-wYHFpHg?e=0gR9pn) | OneRestore trained on LOL (low light enhancement benchmark).
+|[onerestore_reside_ots.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpjGh8KjfM_QIJzEw?e=zabGTw) | OneRestore trained on RESIDE-OTS (image dehazing benchmark).
+|[onerestore_rain1200.tar](https://1drv.ms/u/s!As3rCDROnrbLgqplAFHv6B348jarGA?e=GuduMT) | OneRestore trained on Rain1200 (image deraining benchmark).
+|[onerestore_snow100k.tar](https://1drv.ms/u/s!As3rCDROnrbLgqphsWWxLZN_7JFJDQ?e=pqezzo) | OneRestore trained on Snow100k-L (image desnowing benchmark).
+
+### Inference
+
+We provide two samples in `./image` for quick inference:
+
+```
+python test.py --embedder-model-path ./ckpts/embedder_model.tar --restore-model-path ./ckpts/onerestore_cdd-11.tar --input ./image/ --output ./output/ --concat
+```
+
+You can also provide a prompt to perform controllable restoration. For example:
+
+```
+python test.py --embedder-model-path ./ckpts/embedder_model.tar --restore-model-path ./ckpts/onerestore_cdd-11.tar --prompt low_haze --input ./image/ --output ./output/ --concat
+```
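+
+If you prefer to call the models from Python instead of the command line, the snippet below mirrors the logic in `app.py` (a minimal sketch; the checkpoint paths and the example prompt are illustrative):
+
+```
+import torch
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+from utils.utils import load_restore_ckpt, load_embedder_ckpt
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name="./ckpts/embedder_model.tar")
+restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name="./ckpts/onerestore_cdd-11.tar")
+
+img = Image.open("./image/low_haze_rain_00469_01_lq.png").convert("RGB")
+lq = torch.Tensor((np.array(img) / 255).transpose(2, 0, 1)).unsqueeze(0).to(device)
+
+# Option 1: let the visual embedder estimate the degradation from a 224x224 thumbnail
+thumb = transforms.Compose([transforms.Resize([224, 224]), transforms.ToTensor()])(img).unsqueeze(0).to(device)
+embedding, _, [text] = embedder(thumb, 'image_encoder')
+# Option 2: set the scene descriptor manually, e.g. "low_haze"
+# embedding, _, [text] = embedder(["low_haze"], 'text_encoder')
+
+with torch.no_grad():
+    restored = restorer(lq, embedding)  # (1, 3, H, W), values in [0, 1]
+```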
+
+## Training
+
+### Prepare Dataset
+
+We provide the download link of our Composite Degradation Dataset with 11 types of degradation ([CDD-11](https://1drv.ms/f/s!As3rCDROnrbLgqpezG4sao-u9ddDhw?e=A0REHx)).
+
+Prepare the train and test datasets as follows:
+
+```
+./data/
+|--train
+| |--clear
+| | |--000001.png
+| | |--000002.png
+| |--low
+| |--haze
+| |--rain
+| |--snow
+| |--low_haze
+| |--low_rain
+| |--low_snow
+| |--haze_rain
+| |--haze_snow
+| |--low_haze_rain
+| |--low_haze_snow
+|--test
+```
+### Train Model
+
+**1. Train Text/Visual Embedder by**
+
+```
+python train_Embedder.py --train-dir ./data/CDD-11_train --test-dir ./data/CDD-11_test --check-dir ./ckpts --batch 256 --num-workers 0 --epoch 200 --lr 1e-4 --lr-decay 50
+```
+
+**2. Remove the optimizer weights in the Embedder model file by**
+
+```
+python remove_optim.py --type Embedder --input-file ./ckpts/embedder_model.tar --output-file ./ckpts/embedder_model.tar
+```
+
+**3. Generate the `dataset.h5` file for training OneRestore by**
+
+```
+python makedataset.py --train-path ./data/CDD-11_train --data-name dataset.h5 --patch-size 256 --stride 200
+```
+
+**4. Train OneRestore model by**
+
+- **Single GPU**
+
+```
+python train_OneRestore_single-gpu.py --embedder-model-path ./ckpts/embedder_model.tar --save-model-path ./ckpts --train-input ./dataset.h5 --test-input ./data/CDD-11_test --output ./result/ --epoch 120 --bs 4 --lr 1e-4 --adjust-lr 30 --num-works 4
+```
+
+- **Multiple GPUs**
+
+Assuming you train the OneRestore model on 4 GPUs (e.g., 0, 1, 2, and 3), use the following command. Note that `--nproc_per_node` must equal the number of GPUs.
+
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train_OneRestore_multi-gpu.py --embedder-model-path ./ckpts/embedder_model.tar --save-model-path ./ckpts --train-input ./dataset.h5 --test-input ./data/CDD-11_test --output ./result/ --epoch 120 --bs 4 --lr 1e-4 --adjust-lr 30 --num-works 4
+```
+
+**5. Remove the optimizer weights in the OneRestore model file by**
+
+```
+python remove_optim.py --type OneRestore --input-file ./ckpts/onerestore_model.tar --output-file ./ckpts/onerestore_model.tar
+```
+
+### Customize your own composite degradation dataset
+
+**1. Prepare raw data**
+
+ - Collect your own clear images.
+ - Generate the depth map based on [MegaDepth](https://github.com/zhengqili/MegaDepth).
+ - Generate the light map based on [LIME](https://github.com/estija/LIME).
+ - Generate the rain mask database based on [RainStreakGen](https://github.com/liruoteng/RainStreakGen?tab=readme-ov-file).
+ - Download the snow mask database from [Snow100k](https://sites.google.com/view/yunfuliu/desnownet).
+
+A generated example is as follows:
+
+| Clear Image | Depth Map | Light Map | Rain Mask | Snow Mask
+| :--- | :---| :---| :--- | :---
+| | | | | |
+
+(Note: The rain and snow masks do not require strict alignment with the image.)
+
+ - Prepare the dataset as follows:
+
+```
+./syn_data/
+|--data
+| |--clear
+| | |--000001.png
+| | |--000002.png
+| |--depth_map
+| | |--000001.png
+| | |--000002.png
+| |--light_map
+| | |--000001.png
+| | |--000002.png
+| |--rain_mask
+| | |--aaaaaa.png
+| | |--bbbbbb.png
+| |--snow_mask
+| | |--cccccc.png
+| | |--dddddd.png
+|--out
+```
+
+**2. Generate composite degradation images**
+
+ - low+haze+rain
+
+```
+python syn_data.py --hq-file ./data/clear/ --light-file ./data/light_map/ --depth-file ./data/depth_map/ --rain-file ./data/rain_mask/ --snow-file ./data/snow_mask/ --out-file ./out/ --low --haze --rain
+```
+
+ - low+haze+snow
+
+```
+python syn_data.py --hq-file ./data/clear/ --light-file ./data/light_map/ --depth-file ./data/depth_map/ --rain-file ./data/rain_mask/ --snow-file ./data/snow_mask/ --out-file ./out/ --low --haze --snow
+```
+(Note: The degradation types can be customized according to specific needs.)
+
+| Clear Image | low+haze+rain | low+haze+snow
+| :--- | :--- | :---
+| | | |
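+
+For reference, the synthesis composes the degradations with standard physical models: low light attenuates the image with the light map, haze follows the atmospheric scattering model driven by the depth map, and rain/snow are blended in from the masks. The sketch below only illustrates the idea; the parameter names and values are placeholders, and the exact formulation is implemented in `syn_data.py`:
+
+```
+import numpy as np
+
+def synthesize(clear, depth, light, rain=None, snow=None, beta=1.5, A=0.9, gamma=2.0):
+    """All inputs are float arrays in [0, 1] with shape (H, W, 3) or (H, W, 1)."""
+    img = clear * np.power(light, gamma)     # low light: darken with the illumination map
+    t = np.exp(-beta * depth)                # haze: transmission from the depth map
+    img = img * t + A * (1.0 - t)            # atmospheric scattering: I = J*t + A*(1 - t)
+    if rain is not None:
+        img = np.clip(img + rain, 0.0, 1.0)  # rain streaks added on top
+    if snow is not None:
+        img = snow + img * (1.0 - snow)      # snow blended in by its mask
+    return np.clip(img, 0.0, 1.0)
+```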
+
+## Performance
+
+### CDD-11
+
+| Types | Methods | Venue & Year | PSNR ↑ | SSIM ↑ | #Params |
+|-------------------|-----------------------------------------------|--------------|----------|----------|------------|
+| Input | [Input](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuNlQAAAAABf9KaFodlfC8H-K_MNiriFw?e=SiOrWU) | | 16.00 | 0.6008 | - |
+| One-to-One | [MIRNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuMlQAAAAABBzDLjLu69noXflImQ2V9ng?e=4wohVK) | ECCV2020 | 25.97 | 0.8474 | 31.79M |
+| One-to-One | [MPRNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuLlQAAAAAB_iz3hjLHZDMi-RyxHKgDDg?e=SwSQML) | CVPR2021 | 25.47 | 0.8555 | 15.74M |
+| One-to-One | [MIRNetv2](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuQlQAAAAAB2miyepdTE3qdy4z2-LM4pg?e=moXVAR) | TPAMI2022 | 25.37 | 0.8335 | 5.86M |
+| One-to-One | [Restormer](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuPlQAAAAABE86t03kpAVm_TZDIBPKolw?e=vHAR7A) | CVPR2022 | 26.99 | 0.8646 | 26.13M |
+| One-to-One | [DGUNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuOlQAAAAABZkHj8tMamqaGhQ0w4VwFrg?e=lfDUlx) | CVPR2022 | 26.92 | 0.8559 | 17.33M |
+| One-to-One | [NAFNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/EWm9jiJiZLlLgq1trYO67EsB42LrjGpepvpS4oLqKnj8xg?e=5Efa4W) | ECCV2022 | 24.13 | 0.7964 | 17.11M |
+| One-to-One | [SRUDC](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuWlQAAAAABf9RNAUZH_xL6wF4aODWKqA?e=h4EqVN) | ICCV2023 | 27.64 | 0.8600 | 6.80M |
+| One-to-One | [Fourmer](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuXlQAAAAABQKrbA47G8kMD2cf7Chq5EQ?e=vOiWV0) | ICML2023 | 23.44 | 0.7885 | 0.55M |
+| One-to-One | [OKNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuVlQAAAAABSMzfS1xEOxLeuvw8HsGyMw?e=jRmf9t) | AAAI2024 | 26.33 | 0.8605 | 4.72M |
+| One-to-Many | [AirNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMualQAAAAABYJ96PX0fipkP93zRXN_NVw?e=sXFOl8) | CVPR2022 | 23.75 | 0.8140 | 8.93M |
+| One-to-Many | [TransWeather](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuZlQAAAAABoBiLjwJ8L2kl6rGQO5PeJA?e=msprhI) | CVPR2022 | 23.13 | 0.7810 | 21.90M |
+| One-to-Many | [WeatherDiff](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuYlQAAAAABxdWbznZA1CQ0Bh1JH_ze-A?e=LEkcZw) | TPAMI2023 | 22.49 | 0.7985 | 82.96M |
+| One-to-Many | [PromptIR](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMublQAAAAAB9aGo3QK-WlKkL5ItITW9Hg?e=wXrJf1) | NIPS2023 | 25.90 | 0.8499 | 38.45M |
+| One-to-Many | [WGWSNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMudlQAAAAABi3HUMldxdoLHgDcUNoWMPw?e=z0qjAH) | CVPR2023 | 26.96 | 0.8626 | 25.76M |
+| One-to-Composite | [OneRestore](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuclQAAAAABSmNvDBKR1u5rDtqQnZ8X7A?e=OcnrjY) | ECCV2024 | 28.47 | 0.8784 | 5.98M |
+| One-to-Composite | [OneRestore† ](https://1drv.ms/u/c/cbb69e4e3408ebcd/EVM43y_W_WxAjrZqZdK9sfoBk1vpSzKilG0m7T-3i3la-A?e=dbNsD3) | ECCV2024 | 28.72 | 0.8821 | 5.98M |
+
+The [metric calculation code](https://github.com/gy65896/OneRestore/blob/main/img_file/cal_psnr_ssim.py) and [numerical results](https://github.com/gy65896/OneRestore/blob/main/img_file/metrics_CDD-11_psnr_ssim.xlsx) can be downloaded here.
+
+
+
+
+
+
+### Real Scene
+
+
+
+
+
+
+### Controllability
+
+
+
+
+
+
+
+## Citation
+
+```
+@inproceedings{guo2024onerestore,
+ title={OneRestore: A Universal Restoration Framework for Composite Degradation},
+ author={Guo, Yu and Gao, Yuan and Lu, Yuxu and Liu, Ryan Wen and He, Shengfeng},
+ booktitle={European Conference on Computer Vision},
+ year={2024}
+}
+```
+
+#### If you have any questions, please get in touch with me (guoyu65896@gmail.com).
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8525e801c9900a2730ef24d84d7352ae9dc3562
--- /dev/null
+++ b/app.py
@@ -0,0 +1,89 @@
+
+import torch
+import gradio as gr
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+from utils.utils import load_restore_ckpt, load_embedder_ckpt
+import os
+from gradio_imageslider import ImageSlider
+
+# Select the device: use the GPU if available, otherwise fall back to CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+embedder_model_path = "ckpts/embedder_model.tar" # Update with actual path to embedder checkpoint
+restorer_model_path = "ckpts/onerestore_cdd-11.tar" # Update with actual path to restorer checkpoint
+
+# Load the embedder and restorer checkpoints onto the selected device
+embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=embedder_model_path)
+restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=restorer_model_path)
+
+# Define image preprocessing and postprocessing
+transform_resize = transforms.Compose([
+ transforms.Resize([224,224]),
+ transforms.ToTensor()
+ ])
+
+
+def postprocess_image(tensor):
+ image = tensor.squeeze(0).cpu().detach().numpy()
+    image = image * 255 # Model output is in [0, 1]; rescale to [0, 255]
+ image = np.clip(image, 0, 255).astype("uint8") # Clip values to [0, 255]
+ return Image.fromarray(image.transpose(1, 2, 0)) # Reorder to (H, W, C)
+
+# Define the enhancement function
+def enhance_image(image, degradation_type=None):
+ # Preprocess the image
+    input_tensor = torch.Tensor((np.array(image)/255).transpose(2, 0, 1)).unsqueeze(0).to(device)
+    lq_em = transform_resize(image).unsqueeze(0).to(device)
+
+ # Generate embedding
+ if degradation_type == "auto" or degradation_type is None:
+ text_embedding, _, [text] = embedder(lq_em, 'image_encoder')
+ else:
+ text_embedding, _, [text] = embedder([degradation_type], 'text_encoder')
+
+ # Model inference
+ with torch.no_grad():
+ enhanced_tensor = restorer(input_tensor, text_embedding)
+
+ # Postprocess the output
+ return (image, postprocess_image(enhanced_tensor)), text
+
+# Define the Gradio interface
+def inference(image, degradation_type=None):
+ return enhance_image(image, degradation_type)
+
+#### Image / prompt examples
+examples = [
+ ['image/low_haze_rain_00469_01_lq.png'],
+ ['image/low_haze_snow_00337_01_lq.png'],
+ ]
+
+
+
+# Create the Gradio app interface using updated API
+interface = gr.Interface(
+ fn=inference,
+ inputs=[
+ gr.Image(type="pil", value="image/low_haze_rain_00469_01_lq.png"), # Image input
+ gr.Dropdown(['auto', 'low', 'haze', 'rain', 'snow',\
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\
+ 'haze_snow', 'low_haze_rain', 'low_haze_snow'], label="Degradation Type", value="auto") # Manual or auto degradation
+ ],
+ outputs=[
+ ImageSlider(label="Restored Image",
+ type="pil",
+ show_download_button=True,
+                    ), # Before/after slider showing the input and the restored image
+ gr.Textbox(label="Degradation Type") # Display the estimated degradation type
+ ],
+ title="Image Restoration with OneRestore",
+    description="Upload an image and restore it with the OneRestore model. You can let the model estimate the degradation type automatically or set it manually.",
+ examples=examples,
+)
+
+# Launch the app
+if __name__ == "__main__":
+ interface.launch()
diff --git a/ckpts/ckpts_file.txt b/ckpts/ckpts_file.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/data_file.txt b/data/data_file.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/image/low_haze_rain_00469_01_lq.png b/image/low_haze_rain_00469_01_lq.png
new file mode 100644
index 0000000000000000000000000000000000000000..c51701d44d84d31bb162726daf9bbb36b61fe88c
--- /dev/null
+++ b/image/low_haze_rain_00469_01_lq.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac5c71a539806d961d33b98e39c04c70be1a01b27a457d00493be4132b7facdf
+size 1643711
diff --git a/image/low_haze_snow_00337_01_lq.png b/image/low_haze_snow_00337_01_lq.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab962dce85dee948adfe259972f82440af82b222
--- /dev/null
+++ b/image/low_haze_snow_00337_01_lq.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b89f728f4b9498d7fcd15ab79d6a46ed76eb490a6e9971e7c2cab071b8f8cc20
+size 1497520
diff --git a/img_file/OneRestore_poster.png b/img_file/OneRestore_poster.png
new file mode 100644
index 0000000000000000000000000000000000000000..171bdebb80b27922407a95a06d4e87ecc96c617f
--- /dev/null
+++ b/img_file/OneRestore_poster.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86ee7b33d4e6b3024b12d60eb420a58b4f3b1cccb40f0569440a46e93daf816d
+size 11982388
diff --git a/img_file/abstract.jpg b/img_file/abstract.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3ac8ac3b47c9b5856a60355aa1fbaaba2cc2751a
--- /dev/null
+++ b/img_file/abstract.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3e79021de0ed441274a97a4c141c64e730293a95af65c3d963ee5cd3205ace1
+size 1801483
diff --git a/img_file/cal_psnr_ssim.py b/img_file/cal_psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d83a4e68ff77811228253528b21b2fcad0de521
--- /dev/null
+++ b/img_file/cal_psnr_ssim.py
@@ -0,0 +1,96 @@
+import os
+import cv2
+import numpy as np
+import random
+from skimage.metrics import peak_signal_noise_ratio as compare_psnr
+from skimage.metrics import mean_squared_error as compare_mse
+from skimage.metrics import structural_similarity as compare_ssim
+# Modified function to add progress display using tqdm for better progress tracking
+from tqdm import tqdm
+import pandas as pd
+# Updated function with progress display for PSNR and SSIM calculation
+def calculate_psnr_ssim_with_progress(clear_folder, methods, degradation_types, win_size=7):
+ # Get list of all clear images
+ img_list = [img for img in os.listdir(clear_folder) if img.endswith('.png')]
+
+ # Initialize matrices to store mean PSNR and SSIM values
+ psnr_matrix = np.zeros((len(methods), len(degradation_types)))
+ ssim_matrix = np.zeros((len(methods), len(degradation_types)))
+
+ # Total number of tasks for progress tracking
+ total_tasks = len(methods) * len(degradation_types) * len(img_list)
+ print(len(methods), len(degradation_types), len(img_list))
+
+ # Create a progress bar
+ with tqdm(total=total_tasks, desc="Processing Images", unit="task") as pbar:
+ # Loop over methods
+ for k, method in enumerate(methods):
+ print(f"Processing method: {method}")
+
+ # Loop over degradation types
+ for j, degradation_type in enumerate(degradation_types):
+ psnr_values = []
+ ssim_values = []
+
+ # Loop over each image in the clear folder
+ for img_name in img_list:
+ clear_img_path = os.path.join(clear_folder, img_name)
+ degraded_img_path = f'./{method}/{degradation_type}/{img_name}'
+
+                    # Read the clear and degraded images
+                    clear_img = cv2.imread(clear_img_path)
+                    degraded_img = cv2.imread(degraded_img_path)
+
+                    # Ensure the images were read correctly before normalizing and comparing
+                    if clear_img is not None and degraded_img is not None:
+                        clear_img = clear_img / 255.0
+                        degraded_img = degraded_img / 255.0
+
+                        # Compute PSNR and SSIM between clear and degraded image
+                        psnr_value = compare_psnr(clear_img, degraded_img, data_range=1.0)
+
+ # Compute SSIM with specified window size and for multichannel images
+ ssim_value = compare_ssim(clear_img, degraded_img, multichannel=True,
+ win_size=min(win_size, clear_img.shape[0], clear_img.shape[1]),
+ channel_axis=-1, data_range=1.0)
+
+ # Store values
+ psnr_values.append(psnr_value)
+ ssim_values.append(ssim_value)
+
+ # Update progress bar after processing each image
+ pbar.update(1)
+
+ # Calculate mean PSNR and SSIM for the current method and degradation type
+ if psnr_values:
+ psnr_matrix[k, j] = np.mean(psnr_values)
+ if ssim_values:
+ ssim_matrix[k, j] = np.mean(ssim_values)
+
+ return psnr_matrix, ssim_matrix
+
+def save_matrices_to_excel(psnr_matrix, ssim_matrix, methods, degradation_types, output_file='metrics.xlsx'):
+ # Create DataFrames for PSNR and SSIM matrices
+ psnr_df = pd.DataFrame(psnr_matrix, index=methods, columns=degradation_types)
+ ssim_df = pd.DataFrame(ssim_matrix, index=methods, columns=degradation_types)
+
+ # Create a writer to write both DataFrames to the same Excel file
+ with pd.ExcelWriter(output_file) as writer:
+ psnr_df.to_excel(writer, sheet_name='PSNR')
+ ssim_df.to_excel(writer, sheet_name='SSIM')
+
+ print(f'Matrices saved to {output_file}')
+
+# Define the parameters
+clear_folder = './00_gt'
+methods = ['01_input', '02_MIRNet', '03_MPRNet', '04_MIRNetv2', '05_Restormer',
+ '06_DGUNet', '07_NAFNet', '08_SRUDC', '09_Fourmer', '10_OKNet', '11_AirNet',
+ '12_TransWeather', '13_WeatherDiff', '14_PromptIR', '15_WGWSNet', '16_OneRestore_visual', '17_OneRestore']
+degradation_types = ['low', 'haze', 'rain', 'snow', 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow']
+
+# Calculate PSNR and SSIM across all methods and degradation types, then save the matrices to Excel.
+# Make sure the folder paths defined above match your environment before running.
+
+
+psnr_matrix, ssim_matrix = calculate_psnr_ssim_with_progress(clear_folder, methods, degradation_types)
+save_matrices_to_excel(psnr_matrix, ssim_matrix, methods, degradation_types)
+
+
+
diff --git a/img_file/clear_img.jpg b/img_file/clear_img.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b7283c9a2fe7bc762dee7cf384318669cb1499f
Binary files /dev/null and b/img_file/clear_img.jpg differ
diff --git a/img_file/control1.jpg b/img_file/control1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4332b57f0b7baf6bcf13c3dbf59f92120ea37c08
Binary files /dev/null and b/img_file/control1.jpg differ
diff --git a/img_file/control2.jpg b/img_file/control2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5111316661ba7070ce2628858045db70ab0f8d2d
Binary files /dev/null and b/img_file/control2.jpg differ
diff --git a/img_file/depth_map.jpg b/img_file/depth_map.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..492799086071fc996ad383a89258c1d68ee95bd2
Binary files /dev/null and b/img_file/depth_map.jpg differ
diff --git a/img_file/l+h+r.jpg b/img_file/l+h+r.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..19d19334afb6983539ee52669fc207130d37eb21
Binary files /dev/null and b/img_file/l+h+r.jpg differ
diff --git a/img_file/l+h+s.jpg b/img_file/l+h+s.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0733c80319bc8cf002ac7bdba77add414daf0257
Binary files /dev/null and b/img_file/l+h+s.jpg differ
diff --git a/img_file/light_map.jpg b/img_file/light_map.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ef410e765fba2b8f403d37cd84b2fe6ba722d615
Binary files /dev/null and b/img_file/light_map.jpg differ
diff --git a/img_file/logo_onerestore.png b/img_file/logo_onerestore.png
new file mode 100644
index 0000000000000000000000000000000000000000..6203260fd896d39ef4eb1931cedf6e6770c3afea
Binary files /dev/null and b/img_file/logo_onerestore.png differ
diff --git a/img_file/metric.png b/img_file/metric.png
new file mode 100644
index 0000000000000000000000000000000000000000..72b6df13be7dbef57f29730dce7bbf5b8f3a259a
Binary files /dev/null and b/img_file/metric.png differ
diff --git a/img_file/metrics_CDD-11_psnr_ssim.xlsx b/img_file/metrics_CDD-11_psnr_ssim.xlsx
new file mode 100644
index 0000000000000000000000000000000000000000..627f539263886ae2306088a3613a6e2d154706f2
Binary files /dev/null and b/img_file/metrics_CDD-11_psnr_ssim.xlsx differ
diff --git a/img_file/pipeline.jpg b/img_file/pipeline.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b5e6a1d37a026712a6d32b285406f5777affc45d
--- /dev/null
+++ b/img_file/pipeline.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:80600ecac7ff326ef3d322ea4db08d27edf3595befdba088deb783ceb260afa3
+size 2722330
diff --git a/img_file/rain_mask.jpg b/img_file/rain_mask.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6bd8872862540001b98997dd018e93a42b9c18d5
Binary files /dev/null and b/img_file/rain_mask.jpg differ
diff --git a/img_file/real.jpg b/img_file/real.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8336b3bc87452a10eb833ff4f65cb6cb25faf0da
--- /dev/null
+++ b/img_file/real.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bdd33e434a8370527d4c8d00603ce7b2b3d981eb5db58c38fe3ec32ed6a73c2d
+size 1158318
diff --git a/img_file/snow_mask.png b/img_file/snow_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..33592b84efc59216714fe434bad1f8034043d400
Binary files /dev/null and b/img_file/snow_mask.png differ
diff --git a/img_file/syn.jpg b/img_file/syn.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0ae63e238b2041149a774e16178f140eb3ddeaa9
Binary files /dev/null and b/img_file/syn.jpg differ
diff --git a/makedataset.py b/makedataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..96f0b8c838b0fb143241c997d5c813e374c5a28d
--- /dev/null
+++ b/makedataset.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Feb 12 20:00:46 2020
+
+@author: Administrator
+"""
+
+import os
+import os.path
+import random
+import numpy as np
+import cv2
+import h5py
+import torch
+import torch.utils.data as udata
+import argparse
+from PIL import Image
+class Dataset(udata.Dataset):
+ r"""Implements torch.utils.data.Dataset
+ """
+ def __init__(self, file, trainrgb=True,trainsyn = True, shuffle=False):
+ super(Dataset, self).__init__()
+ self.trainrgb = trainrgb
+ self.trainsyn = trainsyn
+ self.train_haze = file
+
+ h5f = h5py.File(self.train_haze, 'r')
+
+ self.keys = list(h5f.keys())
+ if shuffle:
+ random.shuffle(self.keys)
+ h5f.close()
+
+ def __len__(self):
+ return len(self.keys)
+
+ def __getitem__(self, index):
+
+ h5f = h5py.File(self.train_haze, 'r')
+
+ key = self.keys[index]
+ data = np.array(h5f[key])
+ h5f.close()
+ return torch.Tensor(data)
+
+def data_augmentation(clear, mode):
+ r"""Performs dat augmentation of the input image
+
+ Args:
+ image: a cv2 (OpenCV) image
+ mode: int. Choice of transformation to apply to the image
+ 0 - no transformation
+ 1 - flip up and down
+            2 - rotate counterclockwise 90 degree
+ 3 - rotate 90 degree and flip up and down
+ 4 - rotate 180 degree
+ 5 - rotate 180 degree and flip
+ 6 - rotate 270 degree
+ 7 - rotate 270 degree and flip
+ """
+ clear = np.transpose(clear, (2, 3, 0, 1))
+ if mode == 0:
+ # original
+ clear = clear
+ elif mode == 1:
+ # flip up and down
+ clear = np.flipud(clear)
+ elif mode == 2:
+        # rotate counterclockwise 90 degree
+ clear = np.rot90(clear)
+ elif mode == 3:
+ # rotate 90 degree and flip up and down
+ clear = np.rot90(clear)
+ clear = np.flipud(clear)
+ elif mode == 4:
+ # rotate 180 degree
+ clear = np.rot90(clear, k=2)
+ elif mode == 5:
+ # rotate 180 degree and flip
+ clear = np.rot90(clear, k=2)
+ clear = np.flipud(clear)
+ elif mode == 6:
+ # rotate 270 degree
+ clear = np.rot90(clear, k=3)
+ elif mode == 7:
+ # rotate 270 degree and flip
+ clear = np.rot90(clear, k=3)
+ clear = np.flipud(clear)
+ else:
+ raise Exception('Invalid choice of image transformation')
+ return np.transpose(clear, (2, 3, 0, 1))
+
+def img_to_patches(img,win,stride,Syn=True):
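+    # Slides a win x win window with the given stride over the stacked images (the clear image
+    # and all degraded versions share the same crop), applies a random augmentation to each
+    # patch, and returns an array of shape (types, channels, win, win, num_patches).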
+ typ, chl, raw, col = img.shape
+ chl = int(chl)
+ num_raw = np.ceil((raw-win)/stride+1).astype(np.uint8)
+ num_col = np.ceil((col-win)/stride+1).astype(np.uint8)
+ count = 0
+ total_process = int(num_col)*int(num_raw)
+ img_patches = np.zeros([typ, chl, win, win, total_process])
+ if Syn:
+ for i in range(num_raw):
+ for j in range(num_col):
+ if stride * i + win <= raw and stride * j + win <=col:
+ img_patches[:,:,:,:,count] = img[:, :, stride*i : stride*i + win, stride*j : stride*j + win]
+ elif stride * i + win > raw and stride * j + win<=col:
+ img_patches[:,:,:,:,count] = img[:, :,raw-win : raw,stride * j : stride * j + win]
+ elif stride * i + win <= raw and stride*j + win>col:
+ img_patches[:,:,:,:,count] = img[:, :,stride*i : stride*i + win, col-win : col]
+ else:
+ img_patches[:,:,:,:,count] = img[:, :,raw-win : raw,col-win : col]
+                    img_patches[:,:,:,:,count] = data_augmentation(img_patches[:, :, :, :, count], np.random.randint(0, 8))
+ count +=1
+ return img_patches
+
+def read_img(img):
+ return np.array(Image.open(img))/255.
+
+def Train_data(args):
+ file_list = os.listdir(f'{args.train_path}/{args.gt_name}')
+
+ with h5py.File(args.data_name, 'w') as h5f:
+ count = 0
+ for i in range(len(file_list)):
+ print(file_list[i])
+ img_list = []
+
+ img_list.append(read_img(f'{args.train_path}/{args.gt_name}/{file_list[i]}'))
+ for j in args.degradation_name:
+ img_list.append(read_img(f'{args.train_path}/{j}/{file_list[i]}'))
+
+ img = np.stack(img_list,0)
+ img = img_to_patches(img.transpose(0, 3, 1, 2), args.patch_size, args.stride)
+
+ for nx in range(img.shape[4]):
+ data = img[:,:,:,:,nx]
+ print(count, data.shape)
+ h5f.create_dataset(str(count), data=data)
+ count += 1
+ h5f.close()
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description = "Building the training patch database")
+ parser.add_argument("--patch-size", type = int, default=256, help="Patch size")
+ parser.add_argument("--stride", type = int, default=200, help="Size of stride")
+
+ parser.add_argument("--train-path", type = str, default='./data/CDD-11_train', help="Train path")
+ parser.add_argument("--data-name", type = str, default='dataset.h5', help="Data name")
+
+ parser.add_argument("--gt-name", type = str, default='clear', help="HQ name")
+ parser.add_argument("--degradation-name", type = list, default=['low','haze','rain','snow',\
+ 'low_haze','low_rain','low_snow','haze_rain','haze_snow','low_haze_rain','low_haze_snow'], help="LQ name")
+
+ args = parser.parse_args()
+
+ Train_data(args)
\ No newline at end of file
diff --git a/model/Embedder.py b/model/Embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0d53309e2adcc0b9eb40b196a1c8f0ab3d4c6a0
--- /dev/null
+++ b/model/Embedder.py
@@ -0,0 +1,238 @@
+import numpy as np
+import torch, torchvision
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from utils.utils_word_embedding import initialize_wordembedding_matrix
+
+class Backbone(nn.Module):
+ def __init__(self, backbone='resnet18'):
+ super(Backbone, self).__init__()
+
+ if backbone == 'resnet18':
+ resnet = torchvision.models.resnet.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
+ elif backbone == 'resnet50':
+            resnet = torchvision.models.resnet.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
+ elif backbone == 'resnet101':
+            resnet = torchvision.models.resnet.resnet101(weights=torchvision.models.ResNet101_Weights.IMAGENET1K_V1)
+
+ self.block0 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
+ )
+ self.block1 = resnet.layer1
+ self.block2 = resnet.layer2
+ self.block3 = resnet.layer3
+ self.block4 = resnet.layer4
+
+ def forward(self, x, returned=[4]):
+ blocks = [self.block0(x)]
+
+ blocks.append(self.block1(blocks[-1]))
+ blocks.append(self.block2(blocks[-1]))
+ blocks.append(self.block3(blocks[-1]))
+ blocks.append(self.block4(blocks[-1]))
+
+ out = [blocks[i] for i in returned]
+ return out
+
+class CosineClassifier(nn.Module):
+ def __init__(self, temp=0.05):
+ super(CosineClassifier, self).__init__()
+ self.temp = temp
+
+ def forward(self, img, concept, scale=True):
+ """
+ img: (bs, emb_dim)
+ concept: (n_class, emb_dim)
+ """
+ img_norm = F.normalize(img, dim=-1)
+ concept_norm = F.normalize(concept, dim=-1)
+ pred = torch.matmul(img_norm, concept_norm.transpose(0, 1))
+ if scale:
+ pred = pred / self.temp
+ return pred
+
+class Embedder(nn.Module):
+ """
+ Text and Visual Embedding Model.
+ """
+ def __init__(self,
+ type_name,
+ feat_dim = 512,
+ mid_dim = 1024,
+ out_dim = 324,
+ drop_rate = 0.35,
+ cosine_cls_temp = 0.05,
+ wordembs = 'glove',
+ extractor_name = 'resnet18'):
+ super(Embedder, self).__init__()
+
+ mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+ self.type_name = type_name
+ self.feat_dim = feat_dim
+ self.mid_dim = mid_dim
+ self.out_dim = out_dim
+ self.drop_rate = drop_rate
+ self.cosine_cls_temp = cosine_cls_temp
+ self.wordembs = wordembs
+ self.extractor_name = extractor_name
+ self.transform = transforms.Normalize(mean, std)
+
+ self._setup_word_embedding()
+ self._setup_image_embedding()
+
+ def _setup_image_embedding(self):
+ # image embedding
+ self.feat_extractor = Backbone(self.extractor_name)
+
+ img_emb_modules = [
+ nn.Conv2d(self.feat_dim, self.mid_dim, kernel_size=1, bias=False),
+ nn.BatchNorm2d(self.mid_dim),
+ nn.ReLU()
+ ]
+ if self.drop_rate > 0:
+ img_emb_modules += [nn.Dropout2d(self.drop_rate)]
+ self.img_embedder = nn.Sequential(*img_emb_modules)
+
+ self.img_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ self.img_final = nn.Linear(self.mid_dim, self.out_dim)
+
+ self.classifier = CosineClassifier(temp=self.cosine_cls_temp)
+
+ def _setup_word_embedding(self):
+
+ self.type2idx = {self.type_name[i]: i for i in range(len(self.type_name))}
+ self.num_type = len(self.type_name)
+ train_type = [self.type2idx[type_i] for type_i in self.type_name]
+ self.train_type = torch.LongTensor(train_type).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ wordemb, self.word_dim = \
+ initialize_wordembedding_matrix(self.wordembs, self.type_name)
+
+ self.embedder = nn.Embedding(self.num_type, self.word_dim)
+ self.embedder.weight.data.copy_(wordemb)
+
+ self.mlp = nn.Sequential(
+ nn.Linear(self.word_dim, self.out_dim),
+ nn.ReLU(True)
+ )
+
+ def train_forward(self, batch):
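+        # batch = (scene-type indices, images); the image embedding is classified against the
+        # word embeddings of all degradation types with the cosine classifier, returning the
+        # cross-entropy loss and the type-classification accuracy.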
+
+ scene, img = batch[0], self.transform(batch[1])
+ bs = img.shape[0]
+
+ # word embedding
+ scene_emb = self.embedder(self.train_type)
+ scene_weight = self.mlp(scene_emb)
+
+ #image embedding
+ img = self.feat_extractor(img)[0]
+ img = self.img_embedder(img)
+ img = self.img_avg_pool(img).squeeze(3).squeeze(2)
+ img = self.img_final(img)
+
+ pred = self.classifier(img, scene_weight)
+ label_loss = F.cross_entropy(pred, scene)
+ pred = torch.max(pred, dim=1)[1]
+ type_pred = self.train_type[pred]
+ correct_type = (type_pred == scene)
+ out = {
+ 'loss_total': label_loss,
+ 'acc_type': torch.div(correct_type.sum(),float(bs)),
+ }
+
+ return out
+
+ def image_encoder_forward(self, batch):
+ img = self.transform(batch)
+
+ # word embedding
+ scene_emb = self.embedder(self.train_type)
+ scene_weight = self.mlp(scene_emb)
+
+ #image embedding
+ img = self.feat_extractor(img)[0]
+ bs, _, h, w = img.shape
+ img = self.img_embedder(img)
+ img = self.img_avg_pool(img).squeeze(3).squeeze(2)
+ img = self.img_final(img)
+
+ pred = self.classifier(img, scene_weight)
+ pred = torch.max(pred, dim=1)[1]
+
+ out_embedding = torch.zeros((bs,self.out_dim)).to("cuda" if torch.cuda.is_available() else "cpu")
+ for i in range(bs):
+ out_embedding[i,:] = scene_weight[pred[i],:]
+ num_type = self.train_type[pred]
+ text_type = [self.type_name[num_type[i]] for i in range(bs)]
+
+ return out_embedding, num_type, text_type
+
+ def text_encoder_forward(self, text):
+
+ bs = len(text)
+
+ # word embedding
+ scene_emb = self.embedder(self.train_type)
+ scene_weight = self.mlp(scene_emb)
+
+ num_type = torch.zeros((bs)).to("cuda" if torch.cuda.is_available() else "cpu")
+ for i in range(bs):
+ num_type[i] = self.type2idx[text[i]]
+
+ out_embedding = torch.zeros((bs,self.out_dim)).to("cuda" if torch.cuda.is_available() else "cpu")
+ for i in range(bs):
+ out_embedding[i,:] = scene_weight[int(num_type[i]),:]
+ text_type = text
+
+ return out_embedding, num_type, text_type
+
+ def text_idx_encoder_forward(self, idx):
+
+ bs = idx.shape[0]
+
+ # word embedding
+ scene_emb = self.embedder(self.train_type)
+ scene_weight = self.mlp(scene_emb)
+
+ num_type = idx
+
+ out_embedding = torch.zeros((bs,self.out_dim)).to("cuda" if torch.cuda.is_available() else "cpu")
+ for i in range(bs):
+ out_embedding[i,:] = scene_weight[int(num_type[i]),:]
+
+ return out_embedding
+
+ def contrast_loss_forward(self, batch):
+
+ img = self.transform(batch)
+
+ #image embedding
+ img = self.feat_extractor(img)[0]
+ img = self.img_embedder(img)
+ img = self.img_avg_pool(img).squeeze(3).squeeze(2)
+ img = self.img_final(img)
+
+ return img
+
+ def forward(self, x, type = 'image_encoder'):
+
+ if type == 'train':
+ out = self.train_forward(x)
+
+ elif type == 'image_encoder':
+ with torch.no_grad():
+ out = self.image_encoder_forward(x)
+
+ elif type == 'text_encoder':
+ out = self.text_encoder_forward(x)
+
+ elif type == 'text_idx_encoder':
+ out = self.text_idx_encoder_forward(x)
+
+ elif type == 'visual_embed':
+ x = F.interpolate(x,size=(224,224),mode='bilinear')
+ out = self.contrast_loss_forward(x)
+
+ return out
\ No newline at end of file
diff --git a/model/OneRestore.py b/model/OneRestore.py
new file mode 100644
index 0000000000000000000000000000000000000000..aea7f6cccefcc2f09209b2ba4124b19f2e68296d
--- /dev/null
+++ b/model/OneRestore.py
@@ -0,0 +1,314 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sun Jun 20 16:14:37 2021
+
+@author: Administrator
+"""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from torchvision import transforms
+import torch, math
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+import numbers
+
+from thop import profile
+import numpy as np
+import time
+from torchvision import transforms
+
+
+class OneRestore(nn.Module):
+ def __init__(self, channel = 32):
+ super(OneRestore,self).__init__()
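+        # map inputs from [0, 1] to [-1, 1] for the network and back to [0, 1] at the output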
+ self.norm = lambda x: (x-0.5)/0.5
+ self.denorm = lambda x: (x+1)/2
+ self.in_conv = nn.Conv2d(3,channel,kernel_size=1,stride=1,padding=0,bias=False)
+ self.encoder = encoder(channel)
+ self.middle = backbone(channel)
+ self.decoder = decoder(channel)
+ self.out_conv = nn.Conv2d(channel,3,kernel_size=1,stride=1,padding=0,bias=False)
+
+ def forward(self,x,embedding):
+ x_in = self.in_conv(self.norm(x))
+ x_l, x_m, x_s, x_ss = self.encoder(x_in, embedding)
+ x_mid = self.middle(x_ss, embedding)
+ x_out = self.decoder(x_mid, x_ss, x_s, x_m, x_l, embedding)
+ out = self.out_conv(x_out) + x
+ return self.denorm(out)
+
+class encoder(nn.Module):
+ def __init__(self,channel):
+ super(encoder,self).__init__()
+
+ self.el = ResidualBlock(channel)#16
+ self.em = ResidualBlock(channel*2)#32
+ self.es = ResidualBlock(channel*4)#64
+ self.ess = ResidualBlock(channel*8)#128
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
+ self.conv_eltem = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#16 32
+ self.conv_emtes = nn.Conv2d(2*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#32 64
+ self.conv_estess = nn.Conv2d(4*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 128
+ self.conv_esstesss = nn.Conv2d(8*channel,16*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 256
+
+ def forward(self,x,embedding):
+
+ elout = self.el(x, embedding)#16
+ x_emin = self.conv_eltem(self.maxpool(elout))#32
+ emout = self.em(x_emin, embedding)
+ x_esin = self.conv_emtes(self.maxpool(emout))
+ esout = self.es(x_esin, embedding)
+ x_esin = self.conv_estess(self.maxpool(esout))
+ essout = self.ess(x_esin, embedding)#128
+
+ return elout, emout, esout, essout#,esssout
+
+class backbone(nn.Module):
+ def __init__(self,channel):
+ super(backbone,self).__init__()
+
+ self.s1 = ResidualBlock(channel*8)#128
+ self.s2 = ResidualBlock(channel*8)#128
+
+ def forward(self,x,embedding):
+
+ share1 = self.s1(x, embedding)
+ share2 = self.s2(share1, embedding)
+
+ return share2
+
+class decoder(nn.Module):
+ def __init__(self,channel):
+ super(decoder,self).__init__()
+
+ self.dss = ResidualBlock(channel*8)#128
+ self.ds = ResidualBlock(channel*4)#64
+ self.dm = ResidualBlock(channel*2)#32
+ self.dl = ResidualBlock(channel)#16
+
+ #self.conv_dssstdss = nn.Conv2d(16*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#256 128
+ self.conv_dsstds = nn.Conv2d(8*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 64
+ self.conv_dstdm = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 32
+ self.conv_dmtdl = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False)#32 16
+
+ def _upsample(self,x,y):
+ _,_,H0,W0 = y.size()
+ return F.interpolate(x,size=(H0,W0),mode='bilinear')
+
+ def forward(self, x, x_ss, x_s, x_m, x_l, embedding):
+
+ dssout = self.dss(x + x_ss, embedding)
+ x_dsin = self.conv_dsstds(self._upsample(dssout, x_s))
+ dsout = self.ds(x_dsin + x_s, embedding)
+ x_dmin = self.conv_dstdm(self._upsample(dsout, x_m))
+ dmout = self.dm(x_dmin + x_m, embedding)
+ x_dlin = self.conv_dmtdl(self._upsample(dmout, x_l))
+ dlout = self.dl(x_dlin + x_l, embedding)
+
+ return dlout
+
+
+class ResidualBlock(nn.Module): # Edge-oriented Residual Convolution Block; the residual connections mitigate the vanishing-gradient problem
+ def __init__(self, channel, norm=False):
+ super(ResidualBlock, self).__init__()
+
+ self.el = TransformerBlock(channel, num_heads=8, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
+
+ def forward(self, x,embedding):
+ return self.el(x,embedding)
+
+def to_3d(x):
+ return rearrange(x, 'b c h w -> b (h w) c')
+
+def to_4d(x, h, w):
+ return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+
+
+class BiasFree_LayerNorm(nn.Module):
+ def __init__(self, normalized_shape):
+ super(BiasFree_LayerNorm, self).__init__()
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ normalized_shape = torch.Size(normalized_shape)
+ assert len(normalized_shape) == 1
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.normalized_shape = normalized_shape
+
+ def forward(self, x):
+ sigma = x.var(-1, keepdim=True, unbiased=False)
+ return x / torch.sqrt(sigma + 1e-5) * self.weight
+
+class WithBias_LayerNorm(nn.Module):
+ def __init__(self, normalized_shape):
+ super(WithBias_LayerNorm, self).__init__()
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ normalized_shape = torch.Size(normalized_shape)
+ assert len(normalized_shape) == 1
+
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.normalized_shape = normalized_shape
+
+ def forward(self, x):
+ mu = x.mean(-1, keepdim=True)
+ sigma = x.var(-1, keepdim=True, unbiased=False)
+ return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
+
+class LayerNorm(nn.Module):
+ def __init__(self, dim, LayerNorm_type):
+ super(LayerNorm, self).__init__()
+ if LayerNorm_type == 'BiasFree':
+ self.body = BiasFree_LayerNorm(dim)
+ else:
+ self.body = WithBias_LayerNorm(dim)
+
+ def forward(self, x):
+ h, w = x.shape[-2:]
+ return to_4d(self.body(to_3d(x)), h, w)
+
+class Cross_Attention(nn.Module):
+ def __init__(self,
+ dim,
+ num_heads,
+ bias,
+ q_dim = 324):
+ super(Cross_Attention, self).__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ sqrt_q_dim = int(math.sqrt(q_dim))
+ self.resize = transforms.Resize([sqrt_q_dim, sqrt_q_dim])
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+ self.q = nn.Linear(q_dim, q_dim, bias=bias)
+
+ self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
+ self.kv_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
+
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+ def forward(self, x, query):
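+        # x: image features (b, c, h, w); query: the 324-dim scene embedding from the Embedder.
+        # The key is resized to 18x18 so its flattened length matches the query length, and
+        # attention is computed across channels, letting the scene descriptor reweight the
+        # image features.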
+ b,c,h,w = x.shape
+
+ q = self.q(query)
+ k, v = self.kv_dwconv(self.kv(x)).chunk(2, dim=1)
+ k = self.resize(k)
+
+ q = repeat(q, 'b l -> b head c l', head=self.num_heads, c=self.dim//self.num_heads)
+ k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+ v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+ q = torch.nn.functional.normalize(q, dim=-1)
+ k = torch.nn.functional.normalize(k, dim=-1)
+
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
+ attn = attn.softmax(dim=-1)
+
+ out = (attn @ v)
+
+ out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+
+ out = self.project_out(out)
+ return out
+
+class Self_Attention(nn.Module):
+ def __init__(self,
+ dim,
+ num_heads,
+ bias):
+ super(Self_Attention, self).__init__()
+ self.num_heads = num_heads
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+ self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
+ self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+ def forward(self, x):
+ b,c,h,w = x.shape
+
+ qkv = self.qkv_dwconv(self.qkv(x))
+ q,k,v = qkv.chunk(3, dim=1)
+
+ q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+ k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+ v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+ q = torch.nn.functional.normalize(q, dim=-1)
+ k = torch.nn.functional.normalize(k, dim=-1)
+
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
+ attn = attn.softmax(dim=-1)
+
+ out = (attn @ v)
+
+ out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+
+ out = self.project_out(out)
+ return out
+
+class FeedForward(nn.Module):
+ def __init__(self,
+ dim,
+ ffn_expansion_factor,
+ bias):
+ super(FeedForward, self).__init__()
+
+ hidden_features = int(dim * ffn_expansion_factor)
+
+ self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
+
+ self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
+ groups=hidden_features * 2, bias=bias)
+
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
+
+ def forward(self, x):
+ x = self.project_in(x)
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
+ x = F.gelu(x1) * x2
+ x = self.project_out(x)
+ return x
+
+class TransformerBlock(nn.Module):
+ def __init__(self,
+ dim,
+ num_heads=8,
+ ffn_expansion_factor=2.66,
+ bias=False,
+ LayerNorm_type='WithBias'):
+ super(TransformerBlock, self).__init__()
+ self.norm1 = LayerNorm(dim, LayerNorm_type)
+ self.cross_attn = Cross_Attention(dim, num_heads, bias)
+ self.norm2 = LayerNorm(dim, LayerNorm_type)
+ self.self_attn = Self_Attention(dim, num_heads, bias)
+ self.norm3 = LayerNorm(dim, LayerNorm_type)
+ self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
+
+ def forward(self, x, query):
+ x = x + self.cross_attn(self.norm1(x),query)
+ x = x + self.self_attn(self.norm2(x))
+ x = x + self.ffn(self.norm3(x))
+ return x
+
+if __name__ == '__main__':
+ net = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
+ # x = torch.Tensor(np.random.random((2,3,256,256))).to("cuda" if torch.cuda.is_available() else "cpu")
+ # query = torch.Tensor(np.random.random((2, 324))).to("cuda" if torch.cuda.is_available() else "cpu")
+ # out = net(x, query)
+ # print(out.shape)
+ input = torch.randn(1, 3, 512, 512).to("cuda" if torch.cuda.is_available() else "cpu")
+ query = torch.Tensor(np.random.random((1, 324))).to("cuda" if torch.cuda.is_available() else "cpu")
+ macs, _ = profile(net, inputs=(input, query))
+ total = sum([param.nelement() for param in net.parameters()])
+ print('Macs = ' + str(macs/1000**3) + 'G')
+ print('Params = ' + str(total/1e6) + 'M')
+
+ from fvcore.nn import FlopCountAnalysis, parameter_count_table
+ flops = FlopCountAnalysis(net, (input, query))
+ print("FLOPs", flops.total()/1000**3)
+
+
diff --git a/model/loss.py b/model/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf3cab077699c71d3248151fd6ae366e711a4fa
--- /dev/null
+++ b/model/loss.py
@@ -0,0 +1,222 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.nn.functional as F
+import cv2 as cv
+import numpy as np
+from matplotlib import pyplot as plt
+from math import exp
+from torchvision import transforms
+from torchvision.models import vgg16
+import torchvision
+'''
+MS-SSIM Loss
+'''
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+ return gauss/gauss.sum()
+
+
+def create_window(window_size, channel=1):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+ return window
+
+
+def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+ if val_range is None:
+ if torch.max(img1) > 128:
+ max_val = 255
+ else:
+ max_val = 1
+
+ if torch.min(img1) < -0.5:
+ min_val = -1
+ else:
+ min_val = 0
+ L = max_val - min_val
+ else:
+ L = val_range
+
+ padd = 0
+ (_, channel, height, width) = img1.size()
+ if window is None:
+ real_size = min(window_size, height, width)
+ window = create_window(real_size, channel=channel).to(img1.device)
+
+ mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
+
+ C1 = (0.01 * L) ** 2
+ C2 = (0.03 * L) ** 2
+
+ v1 = 2.0 * sigma12 + C2
+ v2 = sigma1_sq + sigma2_sq + C2
+ cs = torch.mean(v1 / v2) # contrast sensitivity
+
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+ if size_average:
+ ret = ssim_map.mean()
+ else:
+ ret = ssim_map.mean(1).mean(1).mean(1)
+
+ if full:
+ return ret, cs
+ return ret
+
+
+def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
+ weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(img1.device)
+ levels = weights.size()[0]
+ mssim = []
+ mcs = []
+ for _ in range(levels):
+ sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
+ mssim.append(sim)
+ mcs.append(cs)
+
+ img1 = F.avg_pool2d(img1, (2, 2))
+ img2 = F.avg_pool2d(img2, (2, 2))
+
+ mssim = torch.stack(mssim)
+ mcs = torch.stack(mcs)
+
+ # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
+ if normalize:
+ mssim = (mssim + 1) / 2
+ mcs = (mcs + 1) / 2
+
+ pow1 = mcs ** weights
+ pow2 = mssim ** weights
+ # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
+ output = torch.prod(pow1[:-1] * pow2[-1])
+ return output
+
+
+# Classes to re-use window
+class SSIM(torch.nn.Module):
+ def __init__(self, window_size=11, size_average=True, val_range=None):
+ super(SSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.val_range = val_range
+
+ # Assume 1 channel for SSIM
+ self.channel = 1
+ self.window = create_window(window_size)
+
+ def forward(self, img1, img2):
+ (_, channel, _, _) = img1.size()
+
+ if channel == self.channel and self.window.dtype == img1.dtype:
+ window = self.window
+ else:
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
+ self.window = window
+ self.channel = channel
+
+ return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
+
+class MSSSIM(torch.nn.Module):
+ def __init__(self, window_size=11, size_average=True, channel=3):
+ super(MSSSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.channel = channel
+
+ def forward(self, img1, img2):
+ # TODO: store window between calls if possible
+ return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
+
+class TVLoss(nn.Module):
+ def __init__(self,TVLoss_weight=1):
+ super(TVLoss,self).__init__()
+ self.TVLoss_weight = TVLoss_weight
+
+ def forward(self,x):
+ batch_size = x.size()[0]
+ h_x = x.size()[2]
+ w_x = x.size()[3]
+ count_h = self._tensor_size(x[:,:,1:,:]) #算出总共求了多少次差
+ count_w = self._tensor_size(x[:,:,:,1:])
+ h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
+ # x[:,:,1:,:]-x[:,:,:h_x-1,:] shifts the image by one pixel vertically: one slice starts
+ # at row 1 and the other ends at the second-to-last row, so subtracting them gives the
+ # difference between every pixel and its neighbour in the next row.
+ w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
+ return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
+
+ def _tensor_size(self,t):
+ return t.size()[1]*t.size()[2]*t.size()[3]
+
+class ContrastLoss(nn.Module):
+ def __init__(self):
+ super(ContrastLoss, self).__init__()
+ self.l1 = nn.L1Loss()
+ self.model = vgg16(weights = torchvision.models.VGG16_Weights.DEFAULT)
+ self.model = self.model.features[:16].to("cuda" if torch.cuda.is_available() else "cpu")
+ for param in self.model.parameters():
+ param.requires_grad = False
+ self.layer_name_mapping = {
+ '3': "relu1_2",
+ '8': "relu2_2",
+ '15': "relu3_3"
+ }
+
+ def gen_features(self, x):
+ output = []
+ for name, module in self.model._modules.items():
+ x = module(x)
+ if name in self.layer_name_mapping:
+ output.append(x)
+ return output
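+ # Contrastive regularization on frozen VGG-16 features: the restored output is pulled
+ # toward the positive (clean) image and pushed away from the input and the negative samples.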
+ def forward(self, inp, pos, neg, out):
+ inp_t = inp
+ inp_x0 = self.gen_features(inp_t)
+ pos_t = pos
+ pos_x0 = self.gen_features(pos_t)
+ out_t = out
+ out_x0 = self.gen_features(out_t)
+ neg_t, neg_x0 = [],[]
+ for i in range(neg.shape[1]):
+ neg_i = neg[:,i,:,:]
+ neg_t.append(neg_i)
+ neg_x0_i = self.gen_features(neg_i)
+ neg_x0.append(neg_x0_i)
+ loss = 0
+ for i in range(len(pos_x0)):
+ pos_term = self.l1(out_x0[i], pos_x0[i].detach())
+ inp_term = self.l1(out_x0[i], inp_x0[i].detach())/(len(neg_x0)+1)
+ neg_term = sum(self.l1(out_x0[i], neg_x0[j][i].detach()) for j in range(len(neg_x0)))/(len(neg_x0)+1)
+ loss = loss + pos_term / (inp_term+neg_term+1e-7)
+ return loss / len(pos_x0)
+
+class Total_loss(nn.Module):
+ def __init__(self, args):
+ super(Total_loss, self).__init__()
+ self.con_loss = ContrastLoss()
+ self.weight_sl1, self.weight_msssim, self.weight_drl = args.loss_weight
+
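+ # Total objective: weight_sl1 * SmoothL1 + weight_msssim * (1 - MS-SSIM) + weight_drl * contrastive loss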
+ def forward(self, inp, pos, neg, out):
+ smooth_loss_l1 = F.smooth_l1_loss(out, pos)
+ msssim_loss = 1-msssim(out, pos, normalize=True)
+ c_loss = self.con_loss(inp[0], pos, neg, out)
+
+ total_loss = self.weight_sl1 * smooth_loss_l1 + self.weight_msssim * msssim_loss + self.weight_drl * c_loss
+ return total_loss
\ No newline at end of file
diff --git a/output/low_haze_rain_00469_01_lq.png b/output/low_haze_rain_00469_01_lq.png
new file mode 100644
index 0000000000000000000000000000000000000000..319ed4a3ca7a5f33d4d4f6905c494d85af310f72
--- /dev/null
+++ b/output/low_haze_rain_00469_01_lq.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44f9a861af1f4672c1799e6c3cf20ca2759522dae5f78d9fe8b4540eefb206f6
+size 1439636
diff --git a/output/low_haze_snow_00337_01_lq.png b/output/low_haze_snow_00337_01_lq.png
new file mode 100644
index 0000000000000000000000000000000000000000..c51a4acf9b97d94f35754bf29a9544637fb6b269
--- /dev/null
+++ b/output/low_haze_snow_00337_01_lq.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d007771f4541535683819733631de220abf76a3afbe01e39a69878477d3736b7
+size 1160589
diff --git a/remove_optim.py b/remove_optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c794e059651185564ffb765da818491da4aa61a
--- /dev/null
+++ b/remove_optim.py
@@ -0,0 +1,32 @@
+import torch, argparse
+from model.OneRestore import OneRestore
+from model.Embedder import Embedder
+
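+# Strip the optimizer state (and any 'module.' prefixes left by DataParallel/DDP) from a checkpoint, keeping only the model weights.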
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--type", type=str, default = 'OneRestore')
+parser.add_argument("--input-file", type=str, default = './ckpts/onerestore_cdd-11.tar')
+parser.add_argument("--output-file", type=str, default = './ckpts/onerestore_cdd-11.tar')
+
+args = parser.parse_args()
+
+if args.type == 'OneRestore':
+ restorer = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
+ restorer_info = torch.load(args.input_file, map_location="cuda:0" if torch.cuda.is_available() else "cpu")
+ weights_dict = {}
+ for k, v in restorer_info['state_dict'].items():
+ new_k = k.replace('module.', '') if 'module' in k else k
+ weights_dict[new_k] = v
+ restorer.load_state_dict(weights_dict)
+ torch.save(restorer.state_dict(), args.output_file)
+elif args.type == 'Embedder':
+ combine_type = ['clear', 'low', 'haze', 'rain', 'snow',\
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\
+ 'haze_snow', 'low_haze_rain', 'low_haze_snow']
+ embedder = Embedder(combine_type).to("cuda" if torch.cuda.is_available() else "cpu")
+ embedder_info = torch.load(args.input_file, map_location="cuda:0" if torch.cuda.is_available() else "cpu")
+ embedder.load_state_dict(embedder_info['state_dict'])
+ torch.save(embedder.state_dict(), args.output_file)
+else:
+ print(f'ERROR! Unknown model type: {args.type}')
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..86c904e2e86b620ee157a40d3168432e7ffd516d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+pillow
+numpy
+scikit-image
+pandas
+einops
+thop
+fasttext
+opencv-python
+h5py
+matplotlib
diff --git a/syn_data/data/clear/1.jpg b/syn_data/data/clear/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b7283c9a2fe7bc762dee7cf384318669cb1499f
Binary files /dev/null and b/syn_data/data/clear/1.jpg differ
diff --git a/syn_data/data/depth_map/1.jpg b/syn_data/data/depth_map/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ef410e765fba2b8f403d37cd84b2fe6ba722d615
Binary files /dev/null and b/syn_data/data/depth_map/1.jpg differ
diff --git a/syn_data/data/light_map/1.jpg b/syn_data/data/light_map/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..492799086071fc996ad383a89258c1d68ee95bd2
Binary files /dev/null and b/syn_data/data/light_map/1.jpg differ
diff --git a/syn_data/data/rain_mask/00001.jpg b/syn_data/data/rain_mask/00001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6bd8872862540001b98997dd018e93a42b9c18d5
Binary files /dev/null and b/syn_data/data/rain_mask/00001.jpg differ
diff --git a/syn_data/data/rain_mask/00002.jpg b/syn_data/data/rain_mask/00002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..637c95d406ec642639ad577b84227b00295d8a9d
Binary files /dev/null and b/syn_data/data/rain_mask/00002.jpg differ
diff --git a/syn_data/data/rain_mask/00003.jpg b/syn_data/data/rain_mask/00003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..15a1d2224cd6db54d8022557bff9f2eaa839ed59
Binary files /dev/null and b/syn_data/data/rain_mask/00003.jpg differ
diff --git a/syn_data/data/snow_mask/beautiful_smile_00001.jpg b/syn_data/data/snow_mask/beautiful_smile_00001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..44098e14ab9422df45940c42251254b7a590b1ce
Binary files /dev/null and b/syn_data/data/snow_mask/beautiful_smile_00001.jpg differ
diff --git a/syn_data/data/snow_mask/beautiful_smile_00006.jpg b/syn_data/data/snow_mask/beautiful_smile_00006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0888d9a1a026a0a2781582b4351de6a64d71aa01
Binary files /dev/null and b/syn_data/data/snow_mask/beautiful_smile_00006.jpg differ
diff --git a/syn_data/data/snow_mask/beautiful_smile_00008.jpg b/syn_data/data/snow_mask/beautiful_smile_00008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a8b18f20a6678c366c3aab724efa5f377e7c6956
Binary files /dev/null and b/syn_data/data/snow_mask/beautiful_smile_00008.jpg differ
diff --git a/syn_data/out/1.jpg b/syn_data/out/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..19d19334afb6983539ee52669fc207130d37eb21
Binary files /dev/null and b/syn_data/out/1.jpg differ
diff --git a/syn_data/syn_data.py b/syn_data/syn_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef5b830e34d21e4bd7fc45a088c79850b3e90f2d
--- /dev/null
+++ b/syn_data/syn_data.py
@@ -0,0 +1,86 @@
+import os, argparse, cv2, random
+import numpy as np
+from skimage import exposure
+
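+# Guided filter (He et al.): edge-preserving smoothing of p with I as guidance, q = a*I + b computed from box-filtered local statistics.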
+def guideFilter(I, p, winSize, eps):
+ mean_I = cv2.blur(I, winSize)
+ mean_p = cv2.blur(p, winSize)
+ mean_II = cv2.blur(I * I, winSize)
+ mean_Ip = cv2.blur(I * p, winSize)
+ var_I = mean_II - mean_I * mean_I
+ cov_Ip = mean_Ip - mean_I * mean_p
+ a = cov_Ip / (var_I + eps)
+ b = mean_p - a * mean_I
+ mean_a = cv2.blur(a, winSize)
+ mean_b = cv2.blur(b, winSize)
+ q = mean_a * I + mean_b
+ return q
+
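+# Low-light synthesis: Retinex-style decomposition img = R * L, where the guided-filtered
+# light map is raised to a random gamma in [light_min, light_max] and Gaussian noise is added.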
+def syn_low(img, light, img_gray, light_max=3,
+ light_min=2, noise_max=0.08, noise_min=0.03):
+ light = guideFilter(light, img_gray,(3,3),0.01)[:, :, np.newaxis]
+ n = np.random.uniform(noise_min, noise_max)
+ R = img / (light + 1e-7)
+ L = (light + 1e-7) ** np.random.uniform(light_min, light_max)
+ return np.clip(R * L + np.random.normal(0, n, img.shape), 0, 1)
+
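+# Haze synthesis via the atmospheric scattering model I = J*t + A*(1-t), with transmission t
+# derived from the blurred depth map and a randomly perturbed atmospheric light A.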
+def syn_haze(img, depth, beta_max=2.0, beta_min=1.0, A_max=0.9, A_min=0.6,
+ color_max=0, color_min=0):
+ beta = np.random.rand(1) * (beta_max - beta_min) + beta_min
+ t = np.exp(-np.minimum(1 - cv2.blur(depth,(22,22)),0.7) * beta)
+ A = np.random.rand(1) * (A_max - A_min) + A_min
+ A_random = np.random.rand(3) * (color_max - color_min) + color_min
+ A = A + A_random
+ return np.clip(img * t + A * (1 - t), 0, 1)
+
+def syn_data(hq_file, light_file, depth_file, rain_file, snow_file, out_file,
+ low, haze, rain, snow):
+ file_list = os.listdir(hq_file)
+ rain_list = os.listdir(rain_file)
+ snow_list = os.listdir(snow_file)
+ num_rain = random.sample(range(0,len(rain_list)),len(rain_list))
+ num_snow = random.sample(range(0,len(snow_list)),len(snow_list))
+ for i in range(len(file_list)):
+ img = cv2.imread(hq_file+file_list[i])
+ w, h, _ = img.shape
+ light = cv2.cvtColor(cv2.imread(light_file + file_list[i]), cv2.COLOR_RGB2GRAY) / 255.0
+ depth = cv2.imread(depth_file + file_list[i]) / 255.0
+ rain_mask = cv2.imread(rain_file + rain_list[num_rain[i]]) / 255.0
+ rain_mask = cv2.resize(rain_mask,(h,w))
+ snow_mask = cv2.imread(snow_file + snow_list[num_snow[i]]) / 255.0
+ snow_mask = cv2.resize(snow_mask, (h, w))
+ img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)/ 255.0
+ lq = img.copy()/255.0
+ color_dis = 1
+
+ if low:
+ lq = syn_low(lq, light, img_gray)
+ if rain:
+ lq = lq+rain_mask
+ if snow:
+ lq = lq*(1-snow_mask)+color_dis*snow_mask
+ if haze:
+ lq = syn_haze(lq, depth)
+
+ # out = np.concatenate((lq*255.0,img),1)
+ out = lq*255.0
+ cv2.imwrite(out_file + file_list[i], out)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # load model
+ parser.add_argument("--hq-file", type=str, default = './data/clear/')
+ parser.add_argument("--light-file", type=str, default = './data/light_map/')
+ parser.add_argument("--depth-file", type=str, default = './data/depth_map/')
+ parser.add_argument("--rain-file", type=str, default = './data/rain_mask/')
+ parser.add_argument("--snow-file", type=str, default = './data/snow_mask/')
+ parser.add_argument("--out-file", type=str, default = './out/')
+ parser.add_argument("--low", action='store_true')
+ parser.add_argument("--haze", action='store_true')
+ parser.add_argument("--rain", action='store_true')
+ parser.add_argument("--snow", action='store_true')
+
+ args = parser.parse_args()
+
+ syn_data(args.hq_file, args.light_file, args.depth_file, args.rain_file,
+ args.snow_file, args.out_file, args.low, args.haze, args.rain, args.snow)
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a94eb8bdd5d425d7cacf877fc0b8d36b1c4d9a04
--- /dev/null
+++ b/test.py
@@ -0,0 +1,82 @@
+import os, time, argparse
+from PIL import Image
+import numpy as np
+
+
+import torch
+from torchvision import transforms
+
+from torchvision.utils import save_image as imwrite
+from utils.utils import print_args, load_restore_ckpt, load_embedder_ckpt
+
+transform_resize = transforms.Compose([
+ transforms.Resize([224,224]),
+ transforms.ToTensor()
+ ])
+
+def main(args):
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # inference
+ print('> Model Initialization...')
+
+ embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
+ restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=args.restore_model_path)
+
+ os.makedirs(args.output,exist_ok=True)
+
+ files = os.listdir(args.input)
+ time_record = []
+ for i in files:
+ lq = Image.open(f'{args.input}/{i}')
+
+ with torch.no_grad():
+ lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+ lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ start_time = time.time()
+
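+ # estimate the degradation type with the visual embedder when no prompt is given; otherwise encode the user-supplied text prompt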
+ if args.prompt == None:
+ text_embedding, _, [text] = embedder(lq_em,'image_encoder')
+ print(f'This is {text} degradation estimated by visual embedder.')
+ else:
+ text_embedding, _, [text] = embedder([args.prompt],'text_encoder')
+ print(f'This is {text} degradation generated by input text.')
+
+ out = restorer(lq_re, text_embedding)
+
+ run_time = time.time()-start_time
+ time_record.append(run_time)
+
+ if args.concat:
+ out = torch.cat((lq_re, out), dim=3)
+
+ imwrite(out, f'{args.output}/{i}', range=(0, 1))
+
+ print(f'{i} Running Time: {run_time:.4f}.')
+ print(f'Average time is {np.mean(np.array(time_record)):.4f}.')
+
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser(description = "OneRestore Running")
+
+ # load model
+ parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path')
+ parser.add_argument("--restore-model-path", type=str, default = "./ckpts/onerestore_cdd-11.tar", help = 'restore model path')
+
+ # degradation type: estimated automatically (--prompt omitted, visual embedder) or given manually
+ # (--prompt one of 'clear', 'low', 'haze', 'rain', 'snow', 'low_haze', 'low_rain', 'low_snow',
+ # 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow')
+ parser.add_argument("--prompt", type=str, default = None, help = 'manual degradation prompt (default: automatic estimation)')
+
+ parser.add_argument("--input", type=str, default = "./image/", help = 'image path')
+ parser.add_argument("--output", type=str, default = "./output/", help = 'output path')
+ parser.add_argument("--concat", action='store_true', help = 'output path')
+
+ argspar = parser.parse_args()
+
+ print_args(argspar)
+
+ main(argspar)
diff --git a/train_Embedder.py b/train_Embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff7a89dd9a9bb2285399aaa080333c17b7ad8bd3
--- /dev/null
+++ b/train_Embedder.py
@@ -0,0 +1,104 @@
+import argparse, os, torch, time
+import torch.optim
+
+from utils.utils import load_embedder_ckpt_with_optim, adjust_learning_rate, freeze_text_embedder, AverageMeter
+from utils.utils_data import init_embedding_data
+
+
+
+def train_embedding(cur_epoch, model, optimizer, trainloader, testloader, device, cfg_em):
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.enabled = True
+
+ acc_train_meter = AverageMeter()
+ acc_test_meter = AverageMeter()
+ loss_train_meter = AverageMeter()
+ loss_test_meter = AverageMeter()
+ time_train_meter = AverageMeter()
+ time_test_meter = AverageMeter()
+
+ freeze_text_embedder(model)
+ for k,v in model.named_parameters():
+ print('{}: {}'.format(k, v.requires_grad))
+ for epoch in range(cur_epoch, cfg_em.epoch+1):
+
+ optimizer = adjust_learning_rate(optimizer, epoch-1, cfg_em.lr_decay)
+ lr = optimizer.param_groups[-1]['lr']
+
+ model.train()
+ for idx, batch in enumerate(trainloader):
+ for i in range(len(batch)):
+ batch[i] = batch[i].to("cuda" if torch.cuda.is_available() else "cpu")
+ time_start = time.time()
+ out = model(batch, 'train')
+ loss = out['loss_total']
+ acc = out['acc_type']
+ time_train_meter.update(time.time() - time_start)
+
+ acc_train_meter.update(acc)
+ loss_train_meter.update(loss)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ print(f'Epoch:{epoch}|Iter:{idx+1}/{len(trainloader)}|lr:{lr},'
+ f'Loss: {loss_train_meter.avg:.3f},'
+ f'Acc: {acc_train_meter.avg:.3f},'
+ f'Time: {time_train_meter.avg:.3f},', flush=True)
+
+ model.eval()
+ for idx, batch in enumerate(testloader):
+ for i in range(len(batch)):
+ batch[i] = batch[i].to("cuda" if torch.cuda.is_available() else "cpu")
+
+ time_start = time.time()
+ out = model(batch, 'train')
+ loss = out['loss_total']
+ acc = out['acc_type']
+ time_test_meter.update(time.time() - time_start)
+
+ acc_test_meter.update(acc)
+ loss_test_meter.update(loss)
+ print(f'Epoch:{epoch}|Iter:{idx+1}/{len(testloader)}|lr:{lr},'
+ f'Loss: {loss_test_meter.avg:.3f},'
+ f'Acc: {acc_test_meter.avg:.3f},'
+ f'Time: {time_test_meter.avg:.3f},', flush=True)
+
+ torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()},
+ f'{cfg_em.check_dir}/embedder_model_epoch{epoch}_{acc_train_meter.avg:.3f}_{loss_train_meter.avg:.3f}_{acc_test_meter.avg:.3f}_{loss_test_meter.avg:.3f}.tar')
+ acc_train_meter.reset()
+ acc_test_meter.reset()
+ loss_train_meter.reset()
+ loss_test_meter.reset()
+ time_train_meter.reset()
+ time_test_meter.reset()
+ print('Done!')
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # load model
+ parser.add_argument("--seed", type=int, default = 124)
+ parser.add_argument("--pre_weight", type=str, default = '')
+ parser.add_argument("--lr", type=float, default = 0.0001)
+ parser.add_argument("--type_name", type=list, default = ['clear', 'low', 'haze', 'rain',\
+ 'snow', 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\
+ 'haze_snow', 'low_haze_rain', 'low_haze_snow'])
+ parser.add_argument("--train-dir", type=str, default = './data/CDD-11_train/')
+ parser.add_argument("--test-dir", type=str, default = './data/CDD-11_test/')
+ parser.add_argument("--batch", type=int, default = 128)
+ parser.add_argument("--num-workers", type=int, default = 0)
+ parser.add_argument("--epoch", type=int, default = 200)
+ parser.add_argument("--lr-decay", type=int, default = 50)
+ parser.add_argument("--check-dir", type=str, default = "./ckpts")
+
+ args = parser.parse_args()
+
+ os.makedirs(args.check_dir,exist_ok=True)
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ embedder, optimizer, cur_epoch, device = load_embedder_ckpt_with_optim(device, args)
+ trainloader, testloader = init_embedding_data(args, 'train')
+ train_embedding(cur_epoch, embedder, optimizer, trainloader, testloader, device, args)
diff --git a/train_OneRestore_multi-gpu.py b/train_OneRestore_multi-gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..a37bcb715c71235bd9cb3a4b50f05b719408fd3d
--- /dev/null
+++ b/train_OneRestore_multi-gpu.py
@@ -0,0 +1,153 @@
+import os, time, torch, argparse
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision.utils import save_image as imwrite
+import numpy as np
+from torchvision import transforms
+from makedataset import Dataset
+from utils.utils import print_args, load_restore_ckpt_with_optim, load_embedder_ckpt, adjust_learning_rate, data_process, tensor_metric, load_excel, save_checkpoint
+from model.loss import Total_loss
+from model.Embedder import Embedder
+from model.OneRestore import OneRestore
+from torch.utils.data.distributed import DistributedSampler
+from PIL import Image
+
+torch.distributed.init_process_group(backend="nccl")
+local_rank = torch.distributed.get_rank()
+torch.cuda.set_device(local_rank)
+device = torch.device("cuda", local_rank)
+
+
+transform_resize = transforms.Compose([
+ transforms.Resize([224,224]),
+ transforms.ToTensor()
+ ])
+
+def main(args):
+
+
+ print('> Model Initialization...')
+ embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
+ restorer, optimizer, cur_epoch = load_restore_ckpt_with_optim(device, local_rank=local_rank, freeze_model=False, ckpt_name=args.restore_model_path, lr=args.lr)
+ loss = Total_loss(args)
+
+ print('> Loading dataset...')
+ data = Dataset(args.train_input)
+ dataset = DataLoader(dataset=data, batch_size=args.bs,
+ shuffle=False,
+ num_workers=args.num_works,
+ pin_memory=True,drop_last=False,
+ sampler=DistributedSampler(data,shuffle=True))
+
+ print('> Start training...')
+ start_all = time.time()
+ train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device)
+ end_all = time.time()
+ print('Whole Training Time: ' + str(end_all-start_all) + 's.')
+
+def train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device):
+
+ metric = []
+ for epoch in range(cur_epoch, args.epoch):
+ optimizer = adjust_learning_rate(optimizer, epoch, args.adjust_lr)
+ learnrate = optimizer.param_groups[-1]['lr']
+ restorer.train()
+
+ for i, data in enumerate(dataset,0):
+ pos, inp, neg = data_process(data, args, device)
+
+ text_embedding,_,_ = embedder(inp[1],'text_encoder')
+ out = restorer(inp[0], text_embedding)
+
+ restorer.zero_grad()
+ total_loss = loss(inp, pos, neg, out)
+ total_loss.backward()
+ optimizer.step()
+
+ mse = tensor_metric(pos,out, 'MSE', data_range=1)
+ psnr = tensor_metric(pos,out, 'PSNR', data_range=1)
+ ssim = tensor_metric(pos,out, 'SSIM', data_range=1)
+
+ print("[epoch %d][%d/%d] lr :%f Floss: %.4f MSE: %.4f PSNR: %.4f SSIM: %.4f"%(epoch+1, i+1, \
+ len(dataset), learnrate, total_loss.item(), mse, psnr, ssim))
+
+
+ psnr_t1, ssim_t1, psnr_t2, ssim_t2 = test(args, restorer, embedder, device, epoch)
+ metric.append([psnr_t1, ssim_t1, psnr_t2, ssim_t2])
+ print("[epoch %d] Test images PSNR1: %.4f SSIM1: %.4f"%(epoch+1, psnr_t1,ssim_t1))
+
+ load_excel(metric)
+ save_checkpoint({'epoch': epoch + 1,'state_dict': restorer.state_dict(),'optimizer' : optimizer.state_dict()},\
+ args.save_model_path, epoch+1, psnr_t1,ssim_t1,psnr_t2,ssim_t2)
+
+def test(args, restorer, embedder, device, epoch=-1):
+ combine_type = args.degr_type
+ psnr_1, psnr_2, ssim_1, ssim_2 = 0, 0, 0, 0
+ os.makedirs(args.output,exist_ok=True)
+
+ for i in range(len(combine_type)-1):
+ file_list = os.listdir(f'{args.test_input}/{combine_type[i+1]}/')
+ for j in range(len(file_list)):
+ hq = Image.open(f'{args.test_input}/{combine_type[0]}/{file_list[j]}')
+ lq = Image.open(f'{args.test_input}/{combine_type[i+1]}/{file_list[j]}')
+ restorer.eval()
+ with torch.no_grad():
+ lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+ lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+ hq = torch.Tensor((np.array(hq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ starttime = time.time()
+
+ text_embedding_1,_,text_1 = embedder([combine_type[i+1]],'text_encoder')
+ text_embedding_2,_, text_2 = embedder(lq_em,'image_encoder')
+ out_1 = restorer(lq_re, text_embedding_1)
+ if text_1 != text_2:
+ print(text_1, text_2)
+ out_2 = restorer(lq_re, text_embedding_2)
+ else:
+ out_2 = out_1
+
+ endtime1 = time.time()
+
+ imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output \
+ + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png', range=(0, 1))
+ # if your torchvision version does not support the range argument, replace the line above with
+ # imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output \
+ # + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png')
+ psnr_1 += tensor_metric(hq, out_1, 'PSNR', data_range=1)
+ ssim_1 += tensor_metric(hq, out_1, 'SSIM', data_range=1)
+ psnr_2 += tensor_metric(hq, out_2, 'PSNR', data_range=1)
+ ssim_2 += tensor_metric(hq, out_2, 'SSIM', data_range=1)
+ print('The ' + file_list[j][:-4] + ' Time:' + str(endtime1 - starttime) + 's.')
+
+ return psnr_1 / (len(file_list)*len(combine_type)), ssim_1 / (len(file_list)*len(combine_type)),\
+ psnr_2 / (len(file_list)*len(combine_type)), ssim_2 / (len(file_list)*len(combine_type))
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser(description = "OneRestore Training")
+
+ # load model
+ parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path')
+ parser.add_argument("--restore-model-path", type=str, default = None, help = 'restore model path')
+ parser.add_argument("--save-model-path", type=str, default = "./ckpts/", help = 'restore model path')
+
+ parser.add_argument("--epoch", type=int, default = 300, help = 'epoch number')
+ parser.add_argument("--bs", type=int, default = 4, help = 'batchsize')
+ parser.add_argument("--lr", type=float, default = 1e-4, help = 'learning rate')
+ parser.add_argument("--adjust-lr", type=int, default = 30, help = 'adjust learning rate')
+ parser.add_argument("--num-works", type=int, default = 4, help = 'number works')
+ parser.add_argument("--loss-weight", type=tuple, default = (0.6,0.3,0.1), help = 'loss weights')
+ parser.add_argument("--degr-type", type=list, default = ['clear', 'low', 'haze', 'rain', 'snow',\
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'], help = 'degradation type')
+
+ parser.add_argument("--train-input", type=str, default = "./dataset.h5", help = 'train data')
+ parser.add_argument("--test-input", type=str, default = "./data/CDD-11_test", help = 'test path')
+ parser.add_argument("--output", type=str, default = "./result/", help = 'output path')
+
+ argspar = parser.parse_args()
+
+ print_args(argspar)
+ main(argspar)
diff --git a/train_OneRestore_single-gpu.py b/train_OneRestore_single-gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..6309af76f4135ebca82d5ebca3dea75b5c8cfad5
--- /dev/null
+++ b/train_OneRestore_single-gpu.py
@@ -0,0 +1,140 @@
+import os, time, torch, argparse
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision.utils import save_image as imwrite
+import numpy as np
+from torchvision import transforms
+from makedataset import Dataset
+from utils.utils import print_args, load_restore_ckpt_with_optim, load_embedder_ckpt, adjust_learning_rate, data_process, tensor_metric, load_excel, save_checkpoint
+from model.loss import Total_loss
+
+from PIL import Image
+
+transform_resize = transforms.Compose([
+ transforms.Resize([224,224]),
+ transforms.ToTensor()
+ ])
+
+def main(args):
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ print('> Model Initialization...')
+
+ embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
+ restorer, optimizer, cur_epoch = load_restore_ckpt_with_optim(device, freeze_model=False, ckpt_name=args.restore_model_path, lr=args.lr)
+ loss = Total_loss(args)
+
+ print('> Loading dataset...')
+ data = Dataset(args.train_input)
+ dataset = DataLoader(dataset=data, num_workers=args.num_works, batch_size=args.bs, shuffle=True)
+
+ print('> Start training...')
+ start_all = time.time()
+ train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device)
+ end_all = time.time()
+ print('Whole Training Time: ' + str(end_all-start_all) + 's.')
+
+def train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device):
+
+ metric = []
+ for epoch in range(cur_epoch, args.epoch):
+ optimizer = adjust_learning_rate(optimizer, epoch, args.adjust_lr)
+ learnrate = optimizer.param_groups[-1]['lr']
+ restorer.train()
+
+ for i, data in enumerate(dataset,0):
+ pos, inp, neg = data_process(data, args, device)
+
+ text_embedding,_,_ = embedder(inp[1],'text_encoder')
+ out = restorer(inp[0], text_embedding)
+
+ restorer.zero_grad()
+ total_loss = loss(inp, pos, neg, out)
+ total_loss.backward()
+ optimizer.step()
+
+ mse = tensor_metric(pos,out, 'MSE', data_range=1)
+ psnr = tensor_metric(pos,out, 'PSNR', data_range=1)
+ ssim = tensor_metric(pos,out, 'SSIM', data_range=1)
+
+ print("[epoch %d][%d/%d] lr :%f Floss: %.4f MSE: %.4f PSNR: %.4f SSIM: %.4f"%(epoch+1, i+1, \
+ len(dataset), learnrate, total_loss.item(), mse, psnr, ssim))
+
+
+ psnr_t1, ssim_t1, psnr_t2, ssim_t2 = test(args, restorer, embedder, device, epoch)
+ metric.append([psnr_t1, ssim_t1, psnr_t2, ssim_t2])
+ print("[epoch %d] Test images PSNR1: %.4f SSIM1: %.4f"%(epoch+1, psnr_t1,ssim_t1))
+
+ load_excel(metric)
+ save_checkpoint({'epoch': epoch + 1,'state_dict': restorer.state_dict(),'optimizer' : optimizer.state_dict()},\
+ args.save_model_path, epoch+1, psnr_t1,ssim_t1,psnr_t2,ssim_t2)
+
+def test(args, restorer, embedder, device, epoch=-1):
+ combine_type = args.degr_type
+ psnr_1, psnr_2, ssim_1, ssim_2 = 0, 0, 0, 0
+ os.makedirs(args.output,exist_ok=True)
+
+ for i in range(len(combine_type)-1):
+ file_list = os.listdir(f'{args.test_input}/{combine_type[i+1]}/')
+ for j in range(len(file_list)):
+ hq = Image.open(f'{args.test_input}/{combine_type[0]}/{file_list[j]}')
+ lq = Image.open(f'{args.test_input}/{combine_type[i+1]}/{file_list[j]}')
+ restorer.eval()
+ with torch.no_grad():
+ lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+ lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+ hq = torch.Tensor((np.array(hq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ starttime = time.time()
+
+ text_embedding_1,_,text_1 = embedder([combine_type[i+1]],'text_encoder')
+ text_embedding_2,_, text_2 = embedder(lq_em,'image_encoder')
+ out_1 = restorer(lq_re, text_embedding_1)
+ if text_1 != text_2:
+ print(text_1, text_2)
+ out_2 = restorer(lq_re, text_embedding_2)
+ else:
+ out_2 = out_1
+
+ endtime1 = time.time()
+
+ imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output \
+ + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png', range=(0, 1))
+ psnr_1 += tensor_metric(hq, out_1, 'PSNR', data_range=1)
+ ssim_1 += tensor_metric(hq, out_1, 'SSIM', data_range=1)
+ psnr_2 += tensor_metric(hq, out_2, 'PSNR', data_range=1)
+ ssim_2 += tensor_metric(hq, out_2, 'SSIM', data_range=1)
+ print('The ' + file_list[j][:-4] + ' Time:' + str(endtime1 - starttime) + 's.')
+
+ return psnr_1 / (len(file_list)*len(combine_type)), ssim_1 / (len(file_list)*len(combine_type)),\
+ psnr_2 / (len(file_list)*len(combine_type)), ssim_2 / (len(file_list)*len(combine_type))
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser(description = "OneRestore Training")
+
+ # load model
+ parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path')
+ parser.add_argument("--restore-model-path", type=str, default = None, help = 'restore model path')
+ parser.add_argument("--save-model-path", type=str, default = "./ckpts/", help = 'restore model path')
+
+ parser.add_argument("--epoch", type=int, default = 300, help = 'epoch number')
+ parser.add_argument("--bs", type=int, default = 4, help = 'batchsize')
+ parser.add_argument("--lr", type=float, default = 1e-4, help = 'learning rate')
+ parser.add_argument("--adjust-lr", type=int, default = 30, help = 'adjust learning rate')
+ parser.add_argument("--num-works", type=int, default = 4, help = 'number works')
+ parser.add_argument("--loss-weight", type=tuple, default = (0.6,0.3,0.1), help = 'loss weights')
+ parser.add_argument("--degr-type", type=list, default = ['clear', 'low', 'haze', 'rain', 'snow',\
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'], help = 'degradation type')
+
+ parser.add_argument("--train-input", type=str, default = "./dataset.h5", help = 'train data')
+ parser.add_argument("--test-input", type=str, default = "./data/CDD-11_test", help = 'test path')
+ parser.add_argument("--output", type=str, default = "./result/", help = 'output path')
+
+ argspar = parser.parse_args()
+
+ print_args(argspar)
+ main(argspar)
diff --git a/utils/glove.6B.300d.txt b/utils/glove.6B.300d.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f5f8be23c8b562f31e5c6564fd2b8d6e3691676c
--- /dev/null
+++ b/utils/glove.6B.300d.txt
@@ -0,0 +1,5 @@
+clear -0.081023 -0.29179 0.052021 -0.13324 0.028162 -0.0031446 -0.17156 0.063324 0.16568 -2.1722 -0.14127 0.087891 -0.2298 0.069017 0.21673 0.36556 -0.39979 -0.15506 0.099728 0.202 0.16989 0.14807 0.10938 -0.17141 -0.7258 -0.13189 -0.052768 -0.26383 -0.13189 -0.11408 0.081757 0.14773 -0.24342 0.0076364 -1.0992 0.13661 0.19262 -0.30012 0.031524 0.11439 -0.10854 0.21089 -0.037365 0.23449 0.054638 0.21505 0.023071 0.20918 -0.08606 -0.078589 -0.26945 -0.040802 -0.042601 -0.12093 -0.33614 0.25624 -0.35266 -0.17224 0.31018 0.6426 -0.036072 0.1558 0.26609 0.17298 -0.08158 0.0085636 0.13196 -0.11876 -0.19205 -0.32204 -0.092694 -0.19274 0.0056832 0.17194 0.24011 0.014739 0.091188 0.45903 0.0047753 -0.18136 -0.16434 0.012617 0.42791 0.075318 -0.042848 -0.055952 -0.071895 0.086806 0.078092 0.20169 -0.34189 -0.01975 -0.44579 -0.093254 0.23684 0.098079 -0.0018186 -0.13013 0.054252 -0.68408 0.21378 -0.084742 -0.12383 0.36645 -0.46434 0.56799 0.22341 0.31607 -0.23559 0.033889 0.062509 -0.31468 0.27684 -0.13729 -0.027181 0.17143 -0.35535 0.14426 0.14137 -0.27987 0.051007 0.1689 0.48614 0.43247 -0.31014 -0.2273 -0.17253 0.50221 -0.29023 -0.16833 -0.027586 0.25614 0.096051 0.19145 -0.15576 0.50767 0.0064827 -0.047304 0.47358 -0.029665 -0.095882 0.064574 0.1247 -0.3439 -0.59591 -0.17307 0.30627 0.16351 -0.21709 -0.13142 -0.029781 0.079412 0.36018 -0.068721 0.367 0.26454 0.1306 -0.34602 0.22326 0.22999 0.14122 -0.3084 0.22239 -0.13701 0.24538 0.10902 0.33084 0.052159 -0.54817 0.32921 0.33889 -0.060382 -0.16611 -0.26388 0.13997 -0.15486 -0.05013 -0.089628 -0.0080954 0.13155 -0.019735 0.25758 0.37509 -0.012096 -0.49247 0.13436 -0.21072 -0.13763 0.24047 0.13328 -0.043418 0.0070651 0.30496 -0.11184 0.68017 -0.65417 -0.39198 0.075546 -0.2043 0.041099 0.84586 -0.3361 -0.26385 -0.39417 -0.25468 -0.095349 0.19947 -0.30772 -0.53846 0.18258 -0.091379 -0.27183 0.10918 -0.042102 -0.25614 -0.039694 0.34987 -0.24526 -0.011983 -0.024231 0.62785 -0.16641 0.026109 0.029096 -0.16937 0.25329 -0.12066 0.023087 0.16152 -0.14058 0.044846 0.4533 0.34099 -0.028432 -0.39407 -0.068924 -0.29128 -0.012954 0.048176 -0.090455 -0.0098771 -0.022352 0.091535 -0.084673 -0.43955 -0.25237 0.79719 0.21526 0.0019634 -0.10022 -0.075669 -0.25113 -0.12675 0.12179 0.25892 0.026661 -0.38419 -0.18566 -0.15325 0.44484 -0.088815 0.10119 0.0060884 0.293 -0.415 0.26712 0.033683 -0.42317 0.22025 -0.027351 0.40923 -0.013339 -0.29543 0.37699 -0.019656 -0.082896 -1.5198 0.2961 0.81263 -0.18199 0.59082 0.007938 0.2309 0.23573 0.24941 -0.18754 -0.04029 0.17258 0.1948 0.131 -0.21552 0.016352 0.62256 0.41283 0.40387 -0.062911 -0.093159 -0.078137 -0.30083 -0.035913
+low -0.21751 0.43389 0.149 0.14107 0.2574 -0.12448 0.0047523 0.035596 0.10741 -2.1047 0.17181 -0.15079 -0.044546 -0.090869 -0.43288 -0.13611 -0.0058198 -0.064724 0.23531 -0.36224 -0.21305 -0.075476 0.46786 -0.18465 -0.19746 -0.097471 0.39984 -0.084092 -0.53715 0.27303 -0.087786 0.24297 -0.38444 0.28854 -0.7873 0.089192 -0.26376 -0.16287 0.35911 0.30458 0.24502 0.22553 -0.0031653 0.47358 0.31146 -0.13823 0.075685 -0.10776 0.38329 -0.13762 0.51707 -0.16707 -0.037466 -0.7236 -0.4151 -0.42359 0.14354 0.046639 0.17527 0.48721 0.26708 -0.031042 0.86002 -0.3946 -0.50514 -0.51294 0.58527 0.18819 -0.29543 0.68596 -0.1035 0.22565 0.185 0.058375 0.030999 0.11929 0.12353 0.12873 0.42126 0.14188 -0.050079 -0.2683 0.12126 0.32302 0.27623 0.5414 0.074715 -0.1949 -0.47053 0.02313 0.68686 0.60158 -0.16194 -0.3651 0.41796 -0.22905 0.074734 0.17509 -0.44255 0.3518 -0.40079 -0.28305 0.39133 0.32303 -0.63198 -0.1507 -0.16894 0.17169 0.18894 0.027644 -0.36997 -0.26366 0.36344 -0.049584 0.32724 0.049712 0.051381 -0.058867 -0.2621 -0.50359 -0.21435 -0.25527 0.22161 0.66558 0.2224 0.27607 0.58587 -0.3071 0.24905 0.098802 -0.26459 0.77839 0.014585 0.86936 0.2329 -0.0027986 -0.087016 0.10863 0.18987 0.54552 0.24903 0.059293 0.30362 -0.028582 -0.6569 0.1206 -0.055416 -0.093077 -0.0012132 -0.15009 0.11192 -0.62139 -0.035773 0.1165 0.36541 0.55984 -0.19964 -0.065579 0.097118 -0.1672 0.13677 -0.95276 -0.25994 0.064799 -0.042161 0.12046 0.12391 0.0017478 0.29533 0.40176 0.057528 0.57864 -0.9973 0.13805 -0.30689 0.11015 -0.35402 -0.13434 -0.24479 0.50355 -0.18675 -0.22337 0.29573 0.21612 -0.068496 -0.60643 0.79013 -0.26975 -0.15492 0.70849 0.21372 0.62962 -0.0056421 0.53597 -0.54259 -0.34726 -0.29945 -0.51895 0.28471 -0.14973 0.54188 0.53535 -0.11233 0.19291 -0.24707 0.058424 -0.5473 -0.06426 0.47187 0.11149 0.28313 -0.23876 -0.10552 -0.051705 -0.28853 -0.13702 0.040562 -0.032269 0.10368 -0.29381 0.33416 0.038269 0.029697 -0.48604 -0.26334 0.28942 -0.0093944 0.13942 -0.29043 0.27332 0.16614 -0.028973 -0.32829 -0.034614 -0.0012628 0.062871 -0.000894 0.22467 0.16005 0.23141 -0.19918 0.16465 0.15247 0.29742 -1.0225 0.056188 0.91529 -0.47809 -0.24204 -0.3158 0.21033 -0.13616 0.10777 -0.26815 -0.44804 -0.12696 -0.43468 0.17849 -0.48101 0.026114 0.057368 0.26052 -0.030488 0.051275 -0.36344 0.11878 0.2279 -0.086855 -0.01455 0.070256 -0.16753 0.61449 -0.27428 -0.17901 -0.36261 0.093134 -1.5724 0.47192 -0.52493 -0.27512 -0.37945 0.29588 0.020506 0.08707 0.057053 0.37167 -0.056446 -0.38735 -0.31246 0.028304 -0.058202 0.067263 -0.58761 0.074556 0.49917 0.45134 -0.51433 -0.60996 0.076835 -0.078086
+haze -0.0061289 -0.2702 0.16559 -0.29621 -0.66216 -0.1756 0.46686 1.0362 -0.20692 -0.36097 0.98615 0.32297 -0.55094 -0.36163 -0.27046 0.052225 -0.10079 0.22536 -0.095491 0.17188 0.058372 0.083556 -0.28255 0.12623 -0.0094164 -0.028727 -0.20589 -0.3932 -0.2935 -0.36104 1.0595 0.14423 -0.311 -0.20573 0.11827 -0.0048368 -0.8324 -0.10389 0.34491 0.34006 0.10354 0.11593 0.47379 -0.1042 0.38523 -0.57589 0.027253 -0.44913 -0.52822 -0.44094 0.71219 -0.12278 0.034288 -0.6935 -0.57852 0.33917 0.35018 -0.30193 0.55504 0.085603 -0.21189 -0.51958 -0.17589 -0.13369 0.2976 -0.26048 0.068146 0.62144 0.3416 -0.54399 -0.23937 -0.34802 -0.31469 -0.59554 -0.25011 -0.11644 0.19993 -0.1636 0.24289 -0.0022965 0.3064 -0.26188 0.27166 0.1962 0.37527 -0.22408 0.52979 0.59141 0.035196 0.10632 -0.28318 0.18766 -0.12253 0.41932 -0.64713 0.26068 0.67209 -0.23333 0.030945 -0.15135 0.61662 -0.0025061 -0.58374 0.51866 -0.89244 1.0056 0.15919 0.29183 -0.059984 0.10701 -0.32101 -1.0921 -0.050394 -0.074584 0.56258 -0.5915 0.048547 0.085668 -0.39964 -0.40997 0.093632 -0.22538 -0.83102 -0.051418 -0.31192 0.36056 -0.028854 -0.046907 0.09394 0.012504 0.34555 0.56564 0.48111 0.092143 0.82492 -0.20086 -0.27718 0.9004 0.38921 0.028667 0.78904 0.44698 -0.26892 0.073712 -0.73296 -0.46286 0.53386 0.53514 0.04207 -0.11448 0.27771 0.080703 -0.017482 0.43225 0.047742 -0.095399 -0.063173 -0.36341 0.2948 0.15311 -0.55934 -0.88294 0.62005 -0.23936 0.51953 -0.49463 0.41669 0.61169 -0.20471 -0.0056962 -0.29331 0.46269 0.084808 -0.049355 -0.64697 -0.85777 0.34718 -0.16176 0.14756 -0.65658 -0.54259 -0.13124 -0.88851 0.070637 -0.84926 -0.69345 0.4024 -0.5683 -0.68142 -0.1402 -0.36857 0.36013 -0.49769 -0.17478 0.77214 -0.23962 0.32951 1.0984 -0.00011441 0.9649 -0.13312 0.64326 -0.037091 0.35672 0.025156 0.046782 0.19764 -0.22757 -0.39887 -0.3045 -0.45283 -0.0045182 0.032546 -0.076483 0.72189 -0.038917 1.0621 -0.55688 0.56429 0.11264 0.40465 -0.53146 0.16851 0.69236 -0.24456 0.038704 0.69151 0.16591 -0.43451 0.14115 0.84069 0.29081 -0.31053 -0.6849 -0.27188 -0.32813 0.57882 0.13779 0.36621 -0.45935 0.27899 -0.32315 -0.5743 0.19837 0.0046648 0.18459 0.43369 0.22359 0.16652 -0.081114 -0.54539 -1.0103 -0.14539 0.12021 0.078636 -0.26667 -0.65403 0.4096 0.07257 0.036639 0.21757 0.25738 0.51675 -0.031326 -0.3869 0.012763 -0.45692 0.13828 -0.48614 -0.53757 0.50268 0.47865 -0.049528 -0.032281 -0.4486 0.036258 -0.12295 -0.46811 -0.019014 0.035839 -0.55749 0.018281 -0.88963 -0.024676 -0.19482 -0.19364 0.0069875 0.12679 -0.37379 -0.34094 -0.051568 0.55404 -0.29656 0.26045 0.50872 -0.37399 0.20334 0.70298 -0.3271 -0.24116
+rain -0.52618 -0.54041 -0.89537 -0.35598 -0.74356 -0.66838 0.26326 0.89254 0.14362 -0.34904 0.25866 -0.11143 -0.52035 0.1436 -0.075728 -0.84569 -0.28762 0.049872 0.39234 0.52551 -0.39244 -0.2822 -0.097458 -0.12929 -0.38623 0.17261 0.7574 -0.29868 -0.691 -0.36639 0.63951 0.25255 -0.22299 0.16387 -0.83199 -0.30276 -0.32411 -0.36789 -0.073673 0.54726 0.14785 0.26259 0.086208 -0.033827 0.044403 -0.2135 0.3761 0.33816 -0.36696 -0.2096 0.025934 0.47679 0.23046 -0.44333 -0.65379 0.85762 0.62861 -0.70343 1.1284 0.2497 -0.34459 0.17005 0.27826 0.01167 -0.44087 -0.12649 0.31811 0.073688 -0.17127 -0.023486 0.34294 0.18888 -0.15694 -0.37975 -0.58313 -0.45624 -0.5968 0.09743 -0.50593 -0.64092 0.083647 0.38474 -0.15071 0.55042 -0.68742 0.14893 -0.039046 -0.19582 0.61498 -0.066786 0.63395 -0.4659 0.44123 -0.55136 -0.17711 0.97118 0.26321 -0.035901 -0.11096 -0.11161 0.353 1.026 -0.2605 -0.12231 0.31695 0.35807 0.2526 0.21803 -0.47766 -0.13033 -0.36929 -0.88388 -0.1249 0.27972 0.017521 0.19048 0.38647 -0.10236 0.26691 -0.66637 -0.66046 -0.48598 -0.5029 0.59602 -0.23975 -0.054244 0.71177 0.097479 0.18964 0.60496 -0.2421 1.261 0.5195 0.12978 0.28374 0.1499 -0.073072 -0.064345 0.041775 0.20712 -0.13972 0.021692 -0.45101 -0.077633 -0.58888 -0.0062811 0.50587 0.63067 -0.096216 -0.45549 -0.10162 -0.74026 -0.45125 0.16204 0.34589 0.2203 0.73482 -0.72055 0.019937 0.50934 -0.045864 -1.0167 0.4202 0.29336 0.057842 0.19622 0.71137 0.44455 -0.11329 -0.23249 0.3283 0.6458 -0.032498 0.58903 0.067438 -0.21519 0.24967 -0.047893 -0.12095 0.20468 -0.010392 -0.10827 0.5248 -0.013868 -0.40703 -0.2761 0.61498 -0.12118 -0.70097 -0.76415 -0.37243 0.3 -0.32852 -0.13877 0.23339 -0.58504 0.54768 -0.090521 0.30928 -0.19777 0.68883 0.043808 -0.012833 0.25696 0.017598 -0.11323 -0.76201 0.42972 -0.22032 -0.43818 -0.57085 0.23867 -0.098037 -0.4015 0.27659 -0.51578 -0.28637 -0.37785 0.83469 0.10563 1.1508 -0.67165 0.095388 -0.070545 0.039198 0.17726 0.44885 -0.045378 0.22337 -0.24957 0.93144 -0.16601 -0.095582 -0.60227 0.20068 -0.10264 -0.62696 0.048702 0.34737 -0.10634 -0.35068 0.11719 -0.79712 -0.32956 -0.60446 -0.0049038 -0.3351 -0.060065 -0.3063 -0.15462 -0.099521 -0.1788 0.098109 -0.59477 0.53245 -0.15388 0.063044 -0.47686 0.26712 -0.064799 0.2029 -0.093498 -0.44456 0.4692 -0.13718 0.035772 -0.74958 -0.51603 0.47025 -0.65103 0.027106 0.31463 -0.51519 -0.09912 -0.30605 0.2127 -1.6502 -0.34658 -0.19282 0.036578 -0.33871 0.21323 0.54172 -0.17543 -0.60187 -0.14679 0.20983 -0.084584 0.070885 -0.21752 -0.12642 0.030381 0.075461 0.86541 0.30098 0.22916 0.049217 -0.21204 0.32909 -0.021816
+snow -0.6961 -0.3339 -0.66542 -0.16459 -0.70283 0.053264 0.57508 1.1246 -0.41143 -0.93335 -0.397 -0.13949 -0.21725 0.49383 -0.16481 -0.43673 -0.39998 -0.14702 0.5828 0.73123 -0.16808 0.050093 0.20341 0.093283 -0.18944 -0.0092796 0.0064213 -0.5586 0.079708 0.034177 0.503 -0.084123 -0.15241 0.042398 -0.95865 0.13482 0.10695 0.22212 0.16383 0.081416 -0.61437 0.60299 0.53843 0.33915 -0.060046 -0.12329 0.30417 0.067838 -0.058329 -0.24791 -0.28177 0.32273 -0.12639 -0.40664 -0.42578 0.71366 0.18676 -0.49576 0.56635 0.39411 -0.11876 0.62798 0.50193 -0.38534 -0.32333 -0.29613 -0.1984 0.082042 -0.63666 -0.25177 0.070225 0.23886 -0.35341 -0.30615 -0.7898 -0.014515 -0.096662 0.27064 0.37095 -0.3916 0.15589 0.40176 -0.12316 -0.0069311 -0.17538 0.29317 -0.035662 -0.062503 -0.11821 -0.26708 0.33433 -0.41039 -0.44941 -0.058539 -0.5973 -0.060833 0.014623 0.031391 0.041093 0.21223 0.54304 0.51444 -0.2447 -0.034937 -0.61583 0.24116 0.93612 0.29663 -0.01733 0.39864 -0.399 -0.69927 0.010899 0.044804 0.096444 0.20555 0.37109 0.13219 0.29942 -0.28494 -0.071103 -0.45338 -0.22126 -0.31673 -0.10643 0.040453 -0.15324 0.33191 0.27801 -0.25143 -0.41784 1.1352 0.18709 0.57932 0.14912 0.42731 -0.81353 0.35546 0.10287 -0.10858 0.13692 0.11451 -0.68607 -0.17115 -0.52708 0.28953 0.5147 0.25549 -0.23139 -0.44275 0.42679 -0.41475 0.041182 -0.2664 0.60967 0.03783 0.27371 -0.5267 0.12029 0.5208 0.59519 -1.1315 0.19505 -0.2528 0.34636 0.82065 0.63271 0.091682 0.38433 -0.81108 0.18232 0.19068 -0.13031 0.21336 0.074454 -0.094498 0.47594 -0.31026 -0.11718 0.092891 0.22067 -0.16721 0.71703 0.30143 -0.40609 -0.16231 0.31315 -0.59325 -0.53404 -0.1087 -0.23026 0.36507 0.30648 -0.75576 -0.20767 -0.46966 -0.21035 0.0091924 0.5057 0.45564 0.84145 -0.19412 0.23964 0.85852 0.05229 -0.0011899 -0.29387 0.044187 -0.23886 0.19207 -0.0079459 -0.25773 0.31145 -0.47615 -0.00056431 -0.8941 -0.38667 -0.37907 0.52821 -0.45513 0.53567 0.13216 0.39741 -0.4904 0.24118 -0.11714 0.27007 0.15184 0.42316 -0.39708 0.13827 -0.27638 0.29908 -0.76008 0.061752 -0.4452 -0.5132 0.12124 0.15792 -0.57067 -0.68793 -0.33873 -0.43291 -0.46817 -0.84667 -0.65852 -0.59116 -0.043406 -0.013031 0.11246 -0.35374 0.3923 0.1172 -0.56268 0.83477 -0.34675 0.054568 -0.48494 0.12108 -0.15504 -0.047008 -0.2665 0.024593 0.70123 0.21284 -0.077796 0.050835 0.3865 0.37534 -0.48749 -0.013739 0.57852 -0.90425 -0.0062806 -0.28674 -0.017749 -1.0189 -0.71371 -0.36557 -0.73412 -0.027371 -0.071396 0.64792 -0.057281 -0.2512 0.039567 0.076976 0.34572 0.34606 -0.38323 -0.074011 -0.14153 -0.03109 0.53137 -0.35708 -0.28263 0.098663 0.17693 -0.39297 0.27708
\ No newline at end of file
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..616900d3064e7793e80623094413210a88992b17
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,232 @@
+import numpy as np
+import torch
+import os
+from torch.autograd import Variable
+from skimage.metrics import peak_signal_noise_ratio as compare_psnr
+from skimage.metrics import mean_squared_error as compare_mse
+from skimage.metrics import structural_similarity as compare_ssim
+import pandas as pd
+
+from model.OneRestore import OneRestore
+from model.Embedder import Embedder
+
+def load_embedder_ckpt(device, freeze_model=False, ckpt_name=None,
+ combine_type = ['clear', 'low', 'haze', 'rain', 'snow',\
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\
+ 'haze_snow', 'low_haze_rain', 'low_haze_snow']):
+ if ckpt_name != None:
+ if torch.cuda.is_available():
+ model_info = torch.load(ckpt_name)
+ else:
+ model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
+
+ print('==> loading existing Embedder model:', ckpt_name)
+ model = Embedder(combine_type)
+ model.load_state_dict(model_info)
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+ else:
+ print('==> Initialize Embedder model.')
+ model = Embedder(combine_type)
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+ if freeze_model:
+ freeze(model)
+
+ return model
+
+def load_restore_ckpt(device, freeze_model=False, ckpt_name=None):
+ if ckpt_name != None:
+ if torch.cuda.is_available():
+ model_info = torch.load(ckpt_name)
+ else:
+ model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
+ print('==> loading existing OneRestore model:', ckpt_name)
+ model = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
+ model.load_state_dict(model_info)
+ else:
+ print('==> Initialize OneRestore model.')
+ model = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
+ model = torch.nn.DataParallel(model).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ if freeze_model:
+ freeze(model)
+ total = sum([param.nelement() for param in model.parameters()])
+ print("Number of OneRestore parameter: %.2fM" % (total/1e6))
+
+ return model
+
+def load_restore_ckpt_with_optim(device, local_rank=None, freeze_model=False, ckpt_name=None, lr=None):
+ if ckpt_name != None:
+ if torch.cuda.is_available():
+ model_info = torch.load(ckpt_name)
+ else:
+ model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
+
+ print('==> loading existing OneRestore model:', ckpt_name)
+ model = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr) if lr != None else None
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) if local_rank != None else model
+
+ if local_rank != None:
+ model.load_state_dict(model_info['state_dict'])
+ else:
+ weights_dict = {}
+ for k, v in model_info['state_dict'].items():
+ new_k = k.replace('module.', '') if 'module' in k else k
+ weights_dict[new_k] = v
+ model.load_state_dict(weights_dict)
+ optimizer = torch.optim.Adam(model.parameters())
+ optimizer.load_state_dict(model_info['optimizer'])
+ cur_epoch = model_info['epoch']
+ else:
+ print('==> Initialize OneRestore model.')
+ model = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) if local_rank != None else torch.nn.DataParallel(model)
+ cur_epoch = 0
+
+ if freeze_model:
+ freeze(model)
+ total = sum([param.nelement() for param in model.parameters()])
+ print("Number of OneRestore parameter: %.2fM" % (total/1e6))
+
+ return model, optimizer, cur_epoch
+
+def load_embedder_ckpt_with_optim(device, args, combine_type = ['clear', 'low', 'haze', 'rain', 'snow',\
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow']):
+ print('Init embedder')
+ # seed
+ if args.seed == -1:
+ args.seed = np.random.randint(1, 10000)
+ seed = args.seed
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ print('Training embedder seed:', seed)
+
+ # embedder model
+ embedder = Embedder(combine_type).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ if args.pre_weight == '':
+ optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
+ cur_epoch = 1
+ else:
+ try:
+ if torch.cuda.is_available():
+ embedder_info = torch.load(f'{args.check_dir}/{args.pre_weight}')
+ else:
+ embedder_info = torch.load(f'{args.check_dir}/{args.pre_weight}', map_location=torch.device('cpu'))
+ embedder.load_state_dict(embedder_info['state_dict'])
+ optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
+ optimizer.load_state_dict(embedder_info['optimizer'])
+ cur_epoch = embedder_info['epoch'] + 1
+ except:
+ print('Pre-trained model loading error! Falling back to a newly initialized embedder.')
+ optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
+ cur_epoch = 1
+ return embedder, optimizer, cur_epoch, device
+
+def freeze_text_embedder(m):
+ """Freezes module m.
+ """
+ m.eval()
+ for name, para in m.named_parameters():
+ if name == 'embedder.weight' or name == 'mlp.0.weight' or name == 'mlp.0.bias':
+ print(name)
+ para.requires_grad = False
+ para.grad = None
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+def data_process(data, args, device):
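+ # each sample is expected to stack all degradation versions of one scene: index 0 is the
+ # clear positive, one randomly chosen degraded version becomes the input, and the
+ # remaining versions serve as negatives for the contrastive loss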
+ combine_type = args.degr_type
+ b,n,c,w,h = data.size()
+
+ pos_data = data[:,0,:,:,:]
+
+ inp_data = torch.zeros((b,c,w,h))
+ inp_class = []
+
+ neg_data = torch.zeros((b,n-2,c,w,h))
+
+ index = np.random.randint(1, n, (b))
+ for i in range(b):
+ k = 0
+ for j in range(n):
+ if j == 0:
+ continue
+ elif index[i] == j:
+ inp_class.append(combine_type[index[i]])
+ inp_data[i, :, :, :] = data[i, index[i], :, :,:]
+ else:
+ neg_data[i,k,:,:,:] = data[i, j, :, :,:]
+ k=k+1
+ return pos_data.to("cuda" if torch.cuda.is_available() else "cpu"), [inp_data.to("cuda" if torch.cuda.is_available() else "cpu"), inp_class], neg_data.to("cuda" if torch.cuda.is_available() else "cpu")
+
+def print_args(argspar):
+ print("\nParameter Print")
+ for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()):
+ print('\t{}: {}'.format(p, v))
+ print('\n')
+
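+# Halve the learning rate every lr_update_freq epochs.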
+def adjust_learning_rate(optimizer, epoch, lr_update_freq):
+ if not epoch % lr_update_freq and epoch:
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = param_group['lr'] /2
+ return optimizer
+
+
+def tensor_metric(img, imclean, model, data_range=1):
+
+ img_cpu = img.data.cpu().numpy().astype(np.float32).transpose(0,2,3,1)
+ imgclean = imclean.data.cpu().numpy().astype(np.float32).transpose(0,2,3,1)
+
+ SUM = 0
+ for i in range(img_cpu.shape[0]):
+
+ if model == 'PSNR':
+ SUM += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :],data_range=data_range)
+ elif model == 'MSE':
+ SUM += compare_mse(imgclean[i, :, :, :], img_cpu[i, :, :, :])
+ elif model == 'SSIM':
+ SUM += compare_ssim(imgclean[i, :, :, :], img_cpu[i, :, :, :], data_range=data_range, multichannel = True)
+ # if your scikit-image version does not support the multichannel argument, replace the line above with
+ # SUM += compare_ssim(imgclean[i, :, :, :], img_cpu[i, :, :, :], data_range=data_range, channel_axis=-1)
+ else:
+ print('Unknown metric type!')
+
+ return SUM/img_cpu.shape[0]
+
+def save_checkpoint(stateF, checkpoint, epoch, psnr_t1,ssim_t1,psnr_t2,ssim_t2, filename='model.tar'):
+ torch.save(stateF, checkpoint + 'OneRestore_model_%d_%.4f_%.4f_%.4f_%.4f.tar'%(epoch,psnr_t1,ssim_t1,psnr_t2,ssim_t2))
+
+def load_excel(x):
+ data1 = pd.DataFrame(x)
+
+ writer = pd.ExcelWriter('./metric_result.xlsx')
+ data1.to_excel(writer, 'PSNR-SSIM', float_format='%.5f')
+ # writer.save()
+ writer.close()
+
+def freeze(m):
+ """Freezes module m.
+ """
+ m.eval()
+ for p in m.parameters():
+ p.requires_grad = False
+ p.grad = None
diff --git a/utils/utils_data.py b/utils/utils_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c578f57b48de228c3e79ec131ecdeb290e2c82f
--- /dev/null
+++ b/utils/utils_data.py
@@ -0,0 +1,92 @@
+import torch, os
+from PIL import Image
+import numpy as np
+import torchvision.transforms as transforms
+import torch.utils.data as data
+from einops import rearrange
+
+class ImageLoader:
+ def __init__(self, root):
+ self.img_dir = root
+
+ def __call__(self, img):
+ file = f'{self.img_dir}/{img}'
+ img = Image.open(file).convert('RGB')
+ return img
+
+def imagenet_transform(phase):
+
+ if phase == 'train':
+ transform = transforms.Compose([
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor()
+ ])
+
+ elif phase == 'test':
+ transform = transforms.Compose([
+ transforms.Resize([224,224]),
+ transforms.ToTensor()
+ ])
+
+ return transform
+
+class Dataset_embedding(data.Dataset):
+ def __init__(self, cfg_data, phase='train'):
+
+ self.transform = imagenet_transform(phase)
+ self.type_name = cfg_data.type_name
+ self.type2idx = {self.type_name[i]: i for i in range(len(self.type_name))}
+
+ if phase == 'train':
+ self.loader = ImageLoader(cfg_data.train_dir)
+ name = os.listdir(f'{cfg_data.train_dir}/{self.type_name[0]}')
+ self.data = []
+ for i in range(len(self.type_name)):
+ for j in range(len(name)):
+ self.data.append([self.type_name[i], name[j]])
+ elif phase == 'test':
+ self.loader = ImageLoader(cfg_data.test_dir)
+ name = os.listdir(f'{cfg_data.test_dir}/{self.type_name[0]}')
+ self.data = []
+ for i in range(1, len(self.type_name)):
+ for j in range(len(name)):
+ self.data.append([self.type_name[i], name[j]])
+ print(f'The amount of {phase} data is {len(self.data)}')
+
+ def __getitem__(self, index):
+
+ type_name, image_name = self.data[index]
+ scene = self.type2idx[type_name]
+ image = self.transform(self.loader(f'{type_name}/{image_name}'))
+
+ return (scene, image)
+
+ def __len__(self):
+ return len(self.data)
+
+def init_embedding_data(cfg_em, phase):
+ if phase == 'train':
+ train_dataset = Dataset_embedding(cfg_em, 'train')
+ test_dataset = Dataset_embedding(cfg_em, 'test')
+ train_loader = data.DataLoader(train_dataset,
+ batch_size=cfg_em.batch,
+ shuffle=True,
+ num_workers=cfg_em.num_workers,
+ pin_memory=True)
+ test_loader = data.DataLoader(test_dataset,
+ batch_size=cfg_em.batch,
+ shuffle=False,
+ num_workers=cfg_em.num_workers,
+ pin_memory=True)
+ print(len(train_dataset),len(test_dataset))
+
+ elif phase == 'inference':
+ train_loader = None
+ test_dataset = Dataset_embedding(cfg_em, 'test')
+ test_loader = data.DataLoader(test_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=cfg_em.num_workers,
+ pin_memory=True)
+
+ return train_loader, test_loader
\ No newline at end of file
diff --git a/utils/utils_word_embedding.py b/utils/utils_word_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbf039b2feab048f09f21476e138250577246ac0
--- /dev/null
+++ b/utils/utils_word_embedding.py
@@ -0,0 +1,157 @@
+import torch
+import numpy as np
+import fasttext.util
+from gensim import models
+
+def load_word_embeddings(emb_file, vocab):
+ embeds = {}
+ for line in open(emb_file, 'rb'):
+ line = line.decode().strip().split(' ')
+ wvec = torch.FloatTensor(list(map(float, line[1:])))
+ embeds[line[0]] = wvec
+
+ # legacy mapping kept from the UT-Zappos vocabulary (should account for everything)
+ custom_map = {
+ 'Faux.Fur':'fake_fur', 'Faux.Leather':'fake_leather', 'Full.grain.leather':'thick_leather',
+ 'Hair.Calf':'hair_leather', 'Patent.Leather':'shiny_leather', 'Nubuck':'grainy_leather',
+ 'Boots.Ankle':'ankle_boots', 'Boots.Knee.High':'knee_high_boots', 'Boots.Mid-Calf':'midcalf_boots',
+ 'Shoes.Boat.Shoes':'boat_shoes', 'Shoes.Clogs.and.Mules':'clogs_shoes', 'Shoes.Flats':'flats_shoes',
+ 'Shoes.Heels':'heels', 'Shoes.Loafers':'loafers', 'Shoes.Oxfords':'oxford_shoes',
+ 'Shoes.Sneakers.and.Athletic.Shoes':'sneakers'}
+ custom_map_vaw = {
+ 'selfie': 'photo'
+ }
+
+ E = []
+ for k in vocab:
+ if k in custom_map:
+ print(f'Change {k} to {custom_map[k]}')
+ k = custom_map[k]
+ k = k.lower()
+ if '_' in k:
+ toks = k.split('_')
+ emb_tmp = torch.zeros(300).float()
+ for tok in toks:
+ if tok in custom_map_vaw:
+ tok = custom_map_vaw[tok]
+ emb_tmp += embeds[tok]
+ emb_tmp /= len(toks)
+ E.append(emb_tmp)
+ else:
+ E.append(embeds[k])
+
+ embeds = torch.stack(E)
+ print(f'Loaded embeddings from file {emb_file}', embeds.size())
+
+ return embeds
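+
+# Usage sketch (hypothetical vocab): load_word_embeddings expects a GloVe-style
+# text file, one token followed by its 300 float values per line, e.g.
+#   embs = load_word_embeddings('./utils/glove.6B.300d.txt', ['haze', 'rain', 'snow'])
+#   embs.size()  # torch.Size([3, 300])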
+
+def load_fasttext_embeddings(emb_file, vocab):
+ custom_map = {
+ 'Faux.Fur': 'fake fur',
+ 'Faux.Leather': 'fake leather',
+ 'Full.grain.leather': 'thick leather',
+ 'Hair.Calf': 'hairy leather',
+ 'Patent.Leather': 'shiny leather',
+ 'Boots.Ankle': 'ankle boots',
+ 'Boots.Knee.High': 'kneehigh boots',
+ 'Boots.Mid-Calf': 'midcalf boots',
+ 'Shoes.Boat.Shoes': 'boatshoes',
+ 'Shoes.Clogs.and.Mules': 'clogs shoes',
+ 'Shoes.Flats': 'flats shoes',
+ 'Shoes.Heels': 'heels',
+ 'Shoes.Loafers': 'loafers',
+ 'Shoes.Oxfords': 'oxford shoes',
+ 'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers',
+ 'traffic_light': 'trafficlight',
+ 'trash_can': 'trashcan',
+ 'dry-erase_board' : 'dry_erase_board',
+ 'black_and_white' : 'black_white',
+ 'eiffel_tower' : 'tower'
+ }
+ vocab_lower = [v.lower() for v in vocab]
+ vocab = []
+ for current in vocab_lower:
+ if current in custom_map:
+ vocab.append(custom_map[current])
+ else:
+ vocab.append(current)
+
+ ft = fasttext.load_model(emb_file)  # e.g. DATA_FOLDER + '/fast/cc.en.300.bin'
+ embeds = []
+ for k in vocab:
+ if '_' in k:
+ ks = k.split('_')
+ emb = np.stack([ft.get_word_vector(it) for it in ks]).mean(axis=0)
+ else:
+ emb = ft.get_word_vector(k)
+ embeds.append(emb)
+
+ embeds = torch.Tensor(np.stack(embeds))
+ print('Fasttext Embeddings loaded, total embeddings: {}'.format(embeds.size()))
+ return embeds
+
+def load_word2vec_embeddings(emb_file, vocab):
+ # vocab = [v.lower() for v in vocab]
+
+ model = models.KeyedVectors.load_word2vec_format(emb_file, binary=True)
+ # e.g. DATA_FOLDER + '/w2v/GoogleNews-vectors-negative300.bin'
+
+ custom_map = {
+ 'Faux.Fur': 'fake_fur',
+ 'Faux.Leather': 'fake_leather',
+ 'Full.grain.leather': 'thick_leather',
+ 'Hair.Calf': 'hair_leather',
+ 'Patent.Leather': 'shiny_leather',
+ 'Boots.Ankle': 'ankle_boots',
+ 'Boots.Knee.High': 'knee_high_boots',
+ 'Boots.Mid-Calf': 'midcalf_boots',
+ 'Shoes.Boat.Shoes': 'boat_shoes',
+ 'Shoes.Clogs.and.Mules': 'clogs_shoes',
+ 'Shoes.Flats': 'flats_shoes',
+ 'Shoes.Heels': 'heels',
+ 'Shoes.Loafers': 'loafers',
+ 'Shoes.Oxfords': 'oxford_shoes',
+ 'Shoes.Sneakers.and.Athletic.Shoes': 'sneakers',
+ 'traffic_light': 'traffic_light',
+ 'trash_can': 'trashcan',
+ 'dry-erase_board' : 'dry_erase_board',
+ 'black_and_white' : 'black_white',
+ 'eiffel_tower' : 'tower'
+ }
+
+ embeds = []
+ for k in vocab:
+ if k in custom_map:
+ k = custom_map[k]
+ if '_' in k and k not in model:
+ ks = k.split('_')
+ emb = np.stack([model[it] for it in ks]).mean(axis=0)
+ else:
+ emb = model[k]
+ embeds.append(emb)
+ embeds = torch.Tensor(np.stack(embeds))
+ print('Word2Vec Embeddings loaded, total embeddings: {}'.format(embeds.size()))
+ return embeds
+
+
+
+def initialize_wordembedding_matrix(name, vocab):
+ """
+ Args:
+ - name: '+'-separated word embedding names, e.g. 'glove' or 'glove+word2vec'.
+ - vocab: list of attributes/objects.
+ """
+ wordembs = name.split('+')
+ result = None
+
+ for wordemb in wordembs:
+ if wordemb == 'glove':
+ wordemb_ = load_word_embeddings('./utils/glove.6B.300d.txt', vocab)
+ else:
+ raise NotImplementedError(f'Unsupported word embedding: {wordemb}')
+ if result is None:
+ result = wordemb_
+ else:
+ result = torch.cat((result, wordemb_), dim=1)
+ dim = 300 * len(wordembs)
+ return result, dim
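+
+# Usage sketch (hypothetical vocab): build the text embedding matrix for the
+# scene/degradation vocabulary, e.g.
+#   embeddings, dim = initialize_wordembedding_matrix('glove', ['clear', 'low', 'haze', 'rain', 'snow'])
+#   embeddings.size()  # torch.Size([5, 300]); dim == 300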