diff --git a/.github/workflows/push-huggingface.yml b/.github/workflows/push-huggingface.yml
new file mode 100644
index 0000000000000000000000000000000000000000..317156dfc9b2d7cbdf2c0ca5ac12d4e3b070108b
--- /dev/null
+++ b/.github/workflows/push-huggingface.yml
@@ -0,0 +1,22 @@
+name: Push to Hugging Face
+
+on:
+  push:
+    branches: [ "master" ]
+
+jobs:
+  push:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Push repository to Hugging Face
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: |
+          git config --global user.email "phuochungus@gmail.com"
+          git config --global user.name "HungNP"
+          git remote add space https://huggingface.co/spaces/phuochungus/PyCIL_Stanford_Car
+          git checkout -b main
+          git reset $(git commit-tree HEAD^{tree} -m "New single commit message")
+          git push --force https://phuochungus:$HF_TOKEN@huggingface.co/spaces/phuochungus/PyCIL_Stanford_Car main
+          git push --force https://phuochungus:$HF_TOKEN@huggingface.co/spaces/DevSecOpAI/PyCIL main
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e746b5623d4ac019731bcf2e3796edf99ea87b15
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+data/
+__pycache__/
+logs/
+.env
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..045b2a92de917713413b8c59bd165d7f10f8fb00
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,30 @@
+FROM python:3.8.5
+
+RUN useradd -m -u 1000 user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+WORKDIR $HOME
+
+RUN apt-get update && apt-get install -y unzip
+
+RUN pip install --no-cache-dir --upgrade pip
+RUN pip install Cython
+RUN pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+
+COPY --chown=user requirements.txt requirements.txt
+
+RUN pip install -r requirements.txt
+
+COPY --chown=user download_dataset.sh download_dataset.sh
+
+RUN chmod +x download_dataset.sh
+
+RUN ./download_dataset.sh
+
+COPY --chown=user . .
+
+RUN chmod +x install_awscli.sh && ./install_awscli.sh
+
+RUN chmod +x entrypoint.sh upload_s3.sh simple_train.sh train_from_working.sh
+
+ENTRYPOINT [ "./entrypoint.sh" ]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..1946f39efccb4fceae1752928617ea8fd99552d2
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,43 @@
+MIT License
+
+Copyright (c) 2020 Changhong Zhong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+MIT License
+
+Copyright (c) 2021 Fu-Yun Wang.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ffe0fb3989e2c1f8c52b9c25c7860703743b373
--- /dev/null
+++ b/README.md
@@ -0,0 +1,248 @@
+---
+title: Pycil
+emoji: 🍳
+colorFrom: red
+colorTo: red
+sdk: docker
+pinned: false
+---
+# PyCIL: A Python Toolbox for Class-Incremental Learning
+
+---
+

+ Introduction • + Methods Reproduced • + Reproduced Results • + How To Use • + License • + Acknowledgments • + Contact +

+ +
+ +
+ +--- + + + +
+ +[![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](https://github.com/yaoyao-liu/class-incremental-learning/blob/master/LICENSE)[![Python](https://img.shields.io/badge/python-3.8-blue.svg?style=flat-square&logo=python&color=3776AB&logoColor=3776AB)](https://www.python.org/) [![PyTorch](https://img.shields.io/badge/pytorch-1.8-%237732a8?style=flat-square&logo=PyTorch&color=EE4C2C)](https://pytorch.org/) [![method](https://img.shields.io/badge/Reproduced-20-success)]() [![CIL](https://img.shields.io/badge/ClassIncrementalLearning-SOTA-success??style=for-the-badge&logo=appveyor)](https://paperswithcode.com/task/incremental-learning) +![visitors](https://visitor-badge.laobi.icu/badge?page_id=LAMDA.PyCIL&left_color=green&right_color=red) + +
+
+Welcome to PyCIL, perhaps the toolbox for class-incremental learning with the **most** implemented methods. This is the code repository for "PyCIL: A Python Toolbox for Class-Incremental Learning" [[paper]](https://arxiv.org/abs/2112.12533) in PyTorch. If you use any content of this repo in your work, please cite the following bib entries:
+
+    @article{zhou2023pycil,
+        author = {Da-Wei Zhou and Fu-Yun Wang and Han-Jia Ye and De-Chuan Zhan},
+        title = {PyCIL: a Python toolbox for class-incremental learning},
+        journal = {SCIENCE CHINA Information Sciences},
+        year = {2023},
+        volume = {66},
+        number = {9},
+        pages = {197101-},
+        doi = {https://doi.org/10.1007/s11432-022-3600-y}
+    }
+
+    @article{zhou2023class,
+        author = {Zhou, Da-Wei and Wang, Qi-Wei and Qi, Zhi-Hong and Ye, Han-Jia and Zhan, De-Chuan and Liu, Ziwei},
+        title = {Deep Class-Incremental Learning: A Survey},
+        journal = {arXiv preprint arXiv:2302.03648},
+        year = {2023}
+    }
+
+
+## What's New
+- [2024-03]🌟 Check out our [latest work](https://arxiv.org/abs/2403.12030) on pre-trained model-based class-incremental learning!
+- [2024-01]🌟 Check out our [latest survey](https://arxiv.org/abs/2401.16386) on pre-trained model-based continual learning!
+- [2023-09]🌟 We have released the [PILOT](https://github.com/sun-hailong/LAMDA-PILOT) toolbox for class-incremental learning with pre-trained models. Have a try!
+- [2023-07]🌟 Add [MEMO](https://openreview.net/forum?id=S07feAlQHgM), [BEEF](https://openreview.net/forum?id=iP77_axu0h3), and [SimpleCIL](https://arxiv.org/abs/2303.07338). State-of-the-art methods of 2023!
+- [2023-05]🌟 Check out our recent work on [class-incremental learning with vision-language models](https://arxiv.org/abs/2305.19270)!
+- [2023-02]🌟 Check out our [rigorous and unified survey](https://arxiv.org/abs/2302.03648) of class-incremental learning, which introduces memory-agnostic measures with holistic evaluations from multiple aspects!
+- [2022-12]🌟 Add FeTrIL, PASS, IL2A, and SSRE.
+- [2022-10]🌟 PyCIL has been published in [SCIENCE CHINA Information Sciences](https://link.springer.com/article/10.1007/s11432-022-3600-y). Check out the [official introduction](https://mp.weixin.qq.com/s/h1qu2LpdvjeHAPLOnG478A)!
+- [2022-08]🌟 Add RMM.
+- [2022-07]🌟 Add [FOSTER](https://arxiv.org/abs/2204.04662). A state-of-the-art method with a single backbone!
+- [2021-12]🌟 **Call For Feedback**: We have added a section introducing awesome works that use PyCIL. If you are using PyCIL to publish your work in top-tier conferences/journals, feel free to [contact us](mailto:zhoudw@lamda.nju.edu.cn) for details!
+
+## Introduction
+
+Traditional machine learning systems are deployed under a closed-world assumption, which requires the entire training set before offline training. However, real-world applications often face incoming new classes, and a model should incorporate them continually. This learning paradigm is called Class-Incremental Learning (CIL). We propose a Python toolbox that implements several key algorithms for class-incremental learning to ease the burden of researchers in the machine learning community. The toolbox contains implementations of a number of founding works of CIL, such as EWC and iCaRL, and also provides current state-of-the-art algorithms that can be used for conducting novel fundamental research. This toolbox, named PyCIL for Python Class-Incremental Learning, is open source under an MIT license.
+
+For more information about incremental learning, you can refer to these reading materials:
+- A brief introduction (in Chinese) to CIL is available [here](https://zhuanlan.zhihu.com/p/490308909).
+- A PyTorch Tutorial to Class-Incremental Learning (with explicit code and detailed explanations) is available [here](https://github.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning).
+
+## Methods Reproduced
+
+- `FineTune`: Baseline method which simply updates parameters on new tasks.
+- `EWC`: Overcoming catastrophic forgetting in neural networks. PNAS2017 [[paper](https://arxiv.org/abs/1612.00796)]
+- `LwF`: Learning without Forgetting. ECCV2016 [[paper](https://arxiv.org/abs/1606.09282)]
+- `Replay`: Baseline method with exemplar replay.
+- `GEM`: Gradient Episodic Memory for Continual Learning. NIPS2017 [[paper](https://arxiv.org/abs/1706.08840)]
+- `iCaRL`: Incremental Classifier and Representation Learning. CVPR2017 [[paper](https://arxiv.org/abs/1611.07725)]
+- `BiC`: Large Scale Incremental Learning. CVPR2019 [[paper](https://arxiv.org/abs/1905.13260)]
+- `WA`: Maintaining Discrimination and Fairness in Class Incremental Learning. CVPR2020 [[paper](https://arxiv.org/abs/1911.07053)]
+- `PODNet`: PODNet: Pooled Outputs Distillation for Small-Tasks Incremental Learning. ECCV2020 [[paper](https://arxiv.org/abs/2004.13513)]
+- `DER`: DER: Dynamically Expandable Representation for Class Incremental Learning. CVPR2021 [[paper](https://arxiv.org/abs/2103.16788)]
+- `PASS`: Prototype Augmentation and Self-Supervision for Incremental Learning. CVPR2021 [[paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhu_Prototype_Augmentation_and_Self-Supervision_for_Incremental_Learning_CVPR_2021_paper.pdf)]
+- `RMM`: RMM: Reinforced Memory Management for Class-Incremental Learning. NeurIPS2021 [[paper](https://proceedings.neurips.cc/paper/2021/hash/1cbcaa5abbb6b70f378a3a03d0c26386-Abstract.html)]
+- `IL2A`: Class-Incremental Learning via Dual Augmentation. NeurIPS2021 [[paper](https://proceedings.neurips.cc/paper/2021/file/77ee3bc58ce560b86c2b59363281e914-Paper.pdf)]
+- `SSRE`: Self-Sustaining Representation Expansion for Non-Exemplar Class-Incremental Learning. CVPR2022 [[paper](https://arxiv.org/abs/2203.06359)]
+- `FeTrIL`: Feature Translation for Exemplar-Free Class-Incremental Learning. WACV2023 [[paper](https://arxiv.org/abs/2211.13131)]
+- `Coil`: Co-Transport for Class-Incremental Learning. ACM MM2021 [[paper](https://arxiv.org/abs/2107.12654)]
+- `FOSTER`: Feature Boosting and Compression for Class-incremental Learning. ECCV 2022 [[paper](https://arxiv.org/abs/2204.04662)]
+- `MEMO`: A Model or 603 Exemplars: Towards Memory-Efficient Class-Incremental Learning. ICLR 2023 Spotlight [[paper](https://openreview.net/forum?id=S07feAlQHgM)]
+- `BEEF`: BEEF: Bi-Compatible Class-Incremental Learning via Energy-Based Expansion and Fusion. ICLR 2023 [[paper](https://openreview.net/forum?id=iP77_axu0h3)]
+- `SimpleCIL`: Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need. arXiv 2023 [[paper](https://arxiv.org/abs/2303.07338)]
+
+> Authors are welcome to contact us about reproducing their methods in our repo. Feel free to merge your algorithm into PyCIL if you are using our codebase!
+
+## Reproduced Results
+
+#### CIFAR-100
+
+ +
+ + +#### ImageNet-100 + +
+ +
+ +#### ImageNet-100 (Top-5 Accuracy) + +
+ +
+
+> More experimental details and results can be found in our [survey](https://arxiv.org/abs/2302.03648).
+
+## How To Use
+
+### Clone
+
+Clone this GitHub repository:
+
+```
+git clone https://github.com/G-U-N/PyCIL.git
+cd PyCIL
+```
+
+### Dependencies
+
+1. [torch 1.8.1](https://github.com/pytorch/pytorch)
+2. [torchvision 0.6.0](https://github.com/pytorch/vision)
+3. [tqdm](https://github.com/tqdm/tqdm)
+4. [numpy](https://github.com/numpy/numpy)
+5. [scipy](https://github.com/scipy/scipy)
+6. [quadprog](https://github.com/quadprog/quadprog)
+7. [POT](https://github.com/PythonOT/POT)
+
+### Run experiment
+
+1. Edit the `[MODEL NAME].json` file for global settings.
+2. Edit the hyperparameters in the corresponding `[MODEL NAME].py` file (e.g., `models/icarl.py`).
+3. Run:
+
+```bash
+python main.py --config=./exps/[MODEL NAME].json
+```
+
+where `[MODEL NAME]` should be chosen from `finetune`, `ewc`, `lwf`, `replay`, `gem`, `icarl`, `bic`, `wa`, `podnet`, `der`, etc.
+
+4. **Hyper-parameters**
+
+When using PyCIL, you can edit the global parameters and algorithm-specific hyper-parameters in the corresponding JSON file (a sample configuration is sketched at the end of this section).
+
+These parameters include:
+
+- **memory-size**: The total number of exemplars kept during the incremental learning process. Assuming there are $K$ classes at the current stage, the model preserves $\left[\frac{\text{memory-size}}{K}\right]$ exemplars per class.
+- **init-cls**: The number of classes in the first incremental stage. Since CIL settings differ in how many classes the first stage contains, our framework lets you choose the size of the initial stage freely.
+- **increment**: The number of classes in each incremental stage $i$, $i$ > 1. By default, every incremental stage contains the same number of classes.
+- **convnet-type**: The backbone network of the incremental model. Following the benchmark setting, `ResNet32` is utilized for `CIFAR100`, and `ResNet18` is used for `ImageNet`.
+- **seed**: The random seed for shuffling the class order. Following the benchmark setting, it is set to 1993 by default.
+
+Other optimization-related parameters, e.g., batch size, number of epochs, learning rate, learning rate decay, weight decay, milestones, and temperature, can be modified in the corresponding Python file.
+
+### Datasets
+
+We have implemented the pre-processing of `CIFAR100`, `imagenet100`, and `imagenet1000`. When training on `CIFAR100`, this framework will automatically download it. When training on `imagenet100/1000`, you should specify the folder of your dataset in `utils/data.py`.
+
+```python
+    def download_data(self):
+        assert 0, "You should specify the folder of your dataset"
+        train_dir = '[DATA-PATH]/train/'
+        test_dir = '[DATA-PATH]/val/'
+```
+[Here](https://drive.google.com/drive/folders/1RBrPGrZzd1bHU5YG8PjdfwpHANZR_lhJ?usp=sharing) is the file list of ImageNet100 (also known as ImageNet-Sub).
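+
+For reference, the sample configuration mentioned above might look like the minimal sketch below. The exact set of keys varies from method to method, so treat these fields and values as illustrative assumptions rather than a shipped config; the `exps/*.json` files in the repo (e.g., `exps/icarl.json`) are the authoritative reference:
+
+```json
+{
+    "prefix": "reproduce",
+    "dataset": "cifar100",
+    "memory_size": 2000,
+    "init_cls": 10,
+    "increment": 10,
+    "model_name": "icarl",
+    "convnet_type": "resnet32",
+    "device": ["0"],
+    "seed": [1993]
+}
+```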
+
+## Awesome Papers using PyCIL
+
+### Our Papers
+- Expandable Subspace Ensemble for Pre-Trained Model-Based Class-Incremental Learning (**CVPR 2024**) [[paper](https://arxiv.org/abs/2403.12030)] [[code](https://github.com/sun-hailong/CVPR24-Ease)]
+
+- Continual Learning with Pre-Trained Models: A Survey (**arXiv 2024**) [[paper](https://arxiv.org/abs/2401.16386)] [[code](https://github.com/sun-hailong/LAMDA-PILOT)]
+
+- Deep Class-Incremental Learning: A Survey (**arXiv 2023**) [[paper](https://arxiv.org/abs/2302.03648)] [[code](https://github.com/zhoudw-zdw/CIL_Survey/)]
+
+- Learning without Forgetting for Vision-Language Models (**arXiv 2023**) [[paper](https://arxiv.org/abs/2305.19270)]
+
+- Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need (**arXiv 2023**) [[paper](https://arxiv.org/abs/2303.07338)] [[code](https://github.com/zhoudw-zdw/RevisitingCIL)]
+
+- PILOT: A Pre-Trained Model-Based Continual Learning Toolbox (**arXiv 2023**) [[paper](https://arxiv.org/abs/2309.07117)] [[code](https://github.com/sun-hailong/LAMDA-PILOT)]
+
+- Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration (**NeurIPS 2023**) [[paper](https://arxiv.org/abs/2312.05229)] [[code](https://github.com/wangkiw/TEEN)]
+
+- BEEF: Bi-Compatible Class-Incremental Learning via Energy-Based Expansion and Fusion (**ICLR 2023**) [[paper](https://openreview.net/forum?id=iP77_axu0h3)] [[code](https://github.com/G-U-N/ICLR23-BEEF/)]
+
+- A Model or 603 Exemplars: Towards Memory-Efficient Class-Incremental Learning (**ICLR 2023**) [[paper](https://arxiv.org/abs/2205.13218)] [[code](https://github.com/wangkiw/ICLR23-MEMO/)]
+
+- Few-Shot Class-Incremental Learning by Sampling Multi-Phase Tasks (**TPAMI 2022**) [[paper](https://arxiv.org/pdf/2203.17030.pdf)] [[code](https://github.com/zhoudw-zdw/TPAMI-Limit)]
+
+- FOSTER: Feature Boosting and Compression for Class-Incremental Learning (**ECCV 2022**) [[paper](https://arxiv.org/abs/2204.04662)] [[code](https://github.com/G-U-N/ECCV22-FOSTER/)]
+
+- Forward Compatible Few-Shot Class-Incremental Learning (**CVPR 2022**) [[paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhou_Forward_Compatible_Few-Shot_Class-Incremental_Learning_CVPR_2022_paper.pdf)] [[code](https://github.com/zhoudw-zdw/CVPR22-Fact)]
+
+- Co-Transport for Class-Incremental Learning (**ACM MM 2021**) [[paper](https://arxiv.org/abs/2107.12654)] [[code](https://github.com/zhoudw-zdw/MM21-Coil)]
+
+### Other Awesome Works
+
+- Towards Realistic Evaluation of Industrial Continual Learning Scenarios with an Emphasis on Energy Consumption and Computational Footprint (**ICCV 2023**) [[paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Chavan_Towards_Realistic_Evaluation_of_Industrial_Continual_Learning_Scenarios_with_an_ICCV_2023_paper.pdf)] [[code](https://github.com/Vivek9Chavan/RECIL)]
+
+- Dynamic Residual Classifier for Class Incremental Learning (**ICCV 2023**) [[paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Chen_Dynamic_Residual_Classifier_for_Class_Incremental_Learning_ICCV_2023_paper.pdf)] [[code](https://github.com/chen-xw/DRC-CIL)]
+
+- S-Prompts Learning with Pre-trained Transformers: An Occam's Razor for Domain Incremental Learning (**NeurIPS 2022**) [[paper](https://openreview.net/forum?id=ZVe_WeMold)] [[code](https://github.com/iamwangyabin/S-Prompts)]
+
+
+## License
+
+Please check the MIT [license](./LICENSE) included in this repository.
+
+## Acknowledgments
+
+We thank the following repositories for providing helpful components/functions used in our work.
+
+- [Continual-Learning-Reproduce](https://github.com/zhchuu/continual-learning-reproduce)
+- [GEM](https://github.com/hursung1/GradientEpisodicMemory)
+- [FACIL](https://github.com/mmasana/FACIL)
+
+The training flow and data configurations are based on Continual-Learning-Reproduce. The original information about that repo is available in the base branch.
+
+
+## Contact
+
+If you have any questions, please feel free to propose new features by opening an issue, or contact the authors: **Da-Wei Zhou** ([zhoudw@lamda.nju.edu.cn](mailto:zhoudw@lamda.nju.edu.cn)) and **Fu-Yun Wang** (wangfuyun@smail.nju.edu.cn). Enjoy the code.
+
+
+## Star History 🚀
+
+[![Star History Chart](https://api.star-history.com/svg?repos=G-U-N/PyCIL&type=Date)](https://star-history.com/#G-U-N/PyCIL&Date)
+
diff --git a/convs/__init__.py b/convs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/convs/cifar_resnet.py b/convs/cifar_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2668dd45ccd05b6a881b544b6008abd8d8b58af
--- /dev/null
+++ b/convs/cifar_resnet.py
@@ -0,0 +1,207 @@
+'''
+Reference:
+https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
+'''
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DownsampleA(nn.Module):
+    def __init__(self, nIn, nOut, stride):
+        super(DownsampleA, self).__init__()
+        assert stride == 2
+        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
+
+    def forward(self, x):
+        x = self.avg(x)
+        return torch.cat((x, x.mul(0)), 1)
+
+
+class DownsampleB(nn.Module):
+    def __init__(self, nIn, nOut, stride):
+        super(DownsampleB, self).__init__()
+        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
+        self.bn = nn.BatchNorm2d(nOut)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class DownsampleC(nn.Module):
+    def __init__(self, nIn, nOut, stride):
+        super(DownsampleC, self).__init__()
+        assert stride != 1 or nIn != nOut
+        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class DownsampleD(nn.Module):
+    def __init__(self, nIn, nOut, stride):
+        super(DownsampleD, self).__init__()
+        assert stride == 2
+        self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
+        self.bn = nn.BatchNorm2d(nOut)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class ResNetBasicblock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(ResNetBasicblock, self).__init__()
+
+        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn_a = nn.BatchNorm2d(planes)
+
+        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn_b = nn.BatchNorm2d(planes)
+
+        self.downsample = downsample
+
+    def forward(self, x):
+        residual = x
+
+        basicblock = self.conv_a(x)
+        basicblock = self.bn_a(basicblock)
+        basicblock = F.relu(basicblock, inplace=True)
+
+        basicblock = self.conv_b(basicblock)
+        basicblock = self.bn_b(basicblock)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        return F.relu(residual + basicblock, inplace=True)
+
+
+class CifarResNet(nn.Module):
+    """
+ ResNet optimized for the Cifar Dataset, as specified in + https://arxiv.org/abs/1512.03385.pdf + """ + + def __init__(self, block, depth, channels=3): + super(CifarResNet, self).__init__() + + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' + layer_blocks = (depth - 2) // 6 + + self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_1 = nn.BatchNorm2d(16) + + self.inplanes = 16 + self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) + self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) + self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) + self.avgpool = nn.AvgPool2d(8) + self.out_dim = 64 * block.expansion + self.fc = nn.Linear(64*block.expansion, 10) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + # m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv_1_3x3(x) # [bs, 16, 32, 32] + x = F.relu(self.bn_1(x), inplace=True) + + x_1 = self.stage_1(x) # [bs, 16, 32, 32] + x_2 = self.stage_2(x_1) # [bs, 32, 16, 16] + x_3 = self.stage_3(x_2) # [bs, 64, 8, 8] + + pooled = self.avgpool(x_3) # [bs, 64, 1, 1] + features = pooled.view(pooled.size(0), -1) # [bs, 64] + + return { + 'fmaps': [x_1, x_2, x_3], + 'features': features + } + + @property + def last_conv(self): + return self.stage_3[-1].conv_b + + +def resnet20mnist(): + """Constructs a ResNet-20 model for MNIST.""" + model = CifarResNet(ResNetBasicblock, 20, 1) + return model + + +def resnet32mnist(): + """Constructs a ResNet-32 model for MNIST.""" + model = CifarResNet(ResNetBasicblock, 32, 1) + return model + + +def resnet20(): + """Constructs a ResNet-20 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 20) + return model + + +def resnet32(): + """Constructs a ResNet-32 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 32) + return model + + +def resnet44(): + """Constructs a ResNet-44 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 44) + return model + + +def resnet56(): + """Constructs a ResNet-56 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 56) + return model + + +def resnet110(): + """Constructs a ResNet-110 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 110) + return model + +# for auc +def resnet14(): + model = CifarResNet(ResNetBasicblock, 14) + return model + +def resnet26(): + model = CifarResNet(ResNetBasicblock, 26) + return model \ No newline at end of file diff --git a/convs/conv_cifar.py b/convs/conv_cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..2c337270b92312b0dc7ed37b2f7c937345f40696 --- /dev/null +++ b/convs/conv_cifar.py @@ -0,0 +1,77 @@ +''' +For MEMO implementations of CIFAR-ConvNet +Reference: 
+https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_cifar.py +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + +# for cifar +def conv_block(in_channels, out_channels): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.MaxPool2d(2) + ) + +class ConvNet2(nn.Module): + def __init__(self, x_dim=3, hid_dim=64, z_dim=64): + super().__init__() + self.out_dim = 64 + self.avgpool = nn.AvgPool2d(8) + self.encoder = nn.Sequential( + conv_block(x_dim, hid_dim), + conv_block(hid_dim, z_dim), + ) + + def forward(self, x): + x = self.encoder(x) + x = self.avgpool(x) + features = x.view(x.shape[0], -1) + return { + "features":features + } + +class GeneralizedConvNet2(nn.Module): + def __init__(self, x_dim=3, hid_dim=64, z_dim=64): + super().__init__() + self.encoder = nn.Sequential( + conv_block(x_dim, hid_dim), + ) + + def forward(self, x): + base_features = self.encoder(x) + return base_features + +class SpecializedConvNet2(nn.Module): + def __init__(self,hid_dim=64,z_dim=64): + super().__init__() + self.feature_dim = 64 + self.avgpool = nn.AvgPool2d(8) + self.AdaptiveBlock = conv_block(hid_dim,z_dim) + + def forward(self,x): + base_features = self.AdaptiveBlock(x) + pooled = self.avgpool(base_features) + features = pooled.view(pooled.size(0),-1) + return features + +def conv2(): + return ConvNet2() + +def get_conv_a2fc(): + basenet = GeneralizedConvNet2() + adaptivenet = SpecializedConvNet2() + return basenet,adaptivenet + +if __name__ == '__main__': + a, b = get_conv_a2fc() + _base = sum(p.numel() for p in a.parameters()) + _adap = sum(p.numel() for p in b.parameters()) + print(f"conv :{_base+_adap}") + + conv2 = conv2() + conv2_sum = sum(p.numel() for p in conv2.parameters()) + print(f"conv2 :{conv2_sum}") \ No newline at end of file diff --git a/convs/conv_imagenet.py b/convs/conv_imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..59793b2184b938250d0a55fff1786d1fc724ed73 --- /dev/null +++ b/convs/conv_imagenet.py @@ -0,0 +1,82 @@ +''' +For MEMO implementations of ImageNet-ConvNet +Reference: +https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py +''' +import torch.nn as nn +import torch + +# for imagenet +def first_block(in_channels, out_channels): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.MaxPool2d(2) + ) + +def conv_block(in_channels, out_channels): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.MaxPool2d(2) + ) + +class ConvNet(nn.Module): + def __init__(self, x_dim=3, hid_dim=128, z_dim=512): + super().__init__() + self.block1 = first_block(x_dim, hid_dim) + self.block2 = conv_block(hid_dim, hid_dim) + self.block3 = conv_block(hid_dim, hid_dim) + self.block4 = conv_block(hid_dim, z_dim) + self.avgpool = nn.AvgPool2d(7) + self.out_dim = 512 + + def forward(self, x): + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + + x = self.avgpool(x) + features = x.view(x.shape[0], -1) + + return { + "features": features + } + +class GeneralizedConvNet(nn.Module): + def __init__(self, x_dim=3, hid_dim=128, z_dim=512): + super().__init__() + self.block1 = first_block(x_dim, hid_dim) + self.block2 = conv_block(hid_dim, hid_dim) + self.block3 = conv_block(hid_dim, hid_dim) + + def forward(self, x): + x = self.block1(x) + x = 
self.block2(x) + x = self.block3(x) + return x + +class SpecializedConvNet(nn.Module): + def __init__(self, hid_dim=128,z_dim=512): + super().__init__() + self.block4 = conv_block(hid_dim, z_dim) + self.avgpool = nn.AvgPool2d(7) + self.feature_dim = 512 + + def forward(self, x): + x = self.block4(x) + x = self.avgpool(x) + features = x.view(x.shape[0], -1) + return features + +def conv4(): + model = ConvNet() + return model + +def conv_a2fc_imagenet(): + _base = GeneralizedConvNet() + _adaptive_net = SpecializedConvNet() + return _base, _adaptive_net \ No newline at end of file diff --git a/convs/linears.py b/convs/linears.py new file mode 100644 index 0000000000000000000000000000000000000000..f2eb0a316b68d7f520a6b1ff41613d3387fd49bc --- /dev/null +++ b/convs/linears.py @@ -0,0 +1,167 @@ +''' +Reference: +https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py +''' +import math +import torch +from torch import nn +from torch.nn import functional as F + + +class SimpleLinear(nn.Module): + ''' + Reference: + https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py + ''' + def __init__(self, in_features, out_features, bias=True): + super(SimpleLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') + nn.init.constant_(self.bias, 0) + + def forward(self, input): + return {'logits': F.linear(input, self.weight, self.bias)} + + +class CosineLinear(nn.Module): + def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True): + super(CosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features * nb_proxy + self.nb_proxy = nb_proxy + self.to_reduce = to_reduce + self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features)) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + else: + self.register_parameter('sigma', None) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. 
/ math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.sigma is not None: + self.sigma.data.fill_(1) + + def forward(self, input): + out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) + + if self.to_reduce: + # Reduce_proxy + out = reduce_proxies(out, self.nb_proxy) + + if self.sigma is not None: + out = self.sigma * out + + return {'logits': out} + + +class SplitCosineLinear(nn.Module): + def __init__(self, in_features, out_features1, out_features2, nb_proxy=1, sigma=True): + super(SplitCosineLinear, self).__init__() + self.in_features = in_features + self.out_features = (out_features1 + out_features2) * nb_proxy + self.nb_proxy = nb_proxy + self.fc1 = CosineLinear(in_features, out_features1, nb_proxy, False, False) + self.fc2 = CosineLinear(in_features, out_features2, nb_proxy, False, False) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + self.sigma.data.fill_(1) + else: + self.register_parameter('sigma', None) + + def forward(self, x): + out1 = self.fc1(x) + out2 = self.fc2(x) + + out = torch.cat((out1['logits'], out2['logits']), dim=1) # concatenate along the channel + + # Reduce_proxy + out = reduce_proxies(out, self.nb_proxy) + + if self.sigma is not None: + out = self.sigma * out + + return { + 'old_scores': reduce_proxies(out1['logits'], self.nb_proxy), + 'new_scores': reduce_proxies(out2['logits'], self.nb_proxy), + 'logits': out + } + + +def reduce_proxies(out, nb_proxy): + if nb_proxy == 1: + return out + bs = out.shape[0] + nb_classes = out.shape[1] / nb_proxy + assert nb_classes.is_integer(), 'Shape error' + nb_classes = int(nb_classes) + + simi_per_class = out.view(bs, nb_classes, nb_proxy) + attentions = F.softmax(simi_per_class, dim=-1) + + return (attentions * simi_per_class).sum(-1) + + +''' +class CosineLinear(nn.Module): + def __init__(self, in_features, out_features, sigma=True): + super(CosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + else: + self.register_parameter('sigma', None) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. 
/ math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.sigma is not None: + self.sigma.data.fill_(1) + + def forward(self, input): + out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) + if self.sigma is not None: + out = self.sigma * out + return {'logits': out} + + +class SplitCosineLinear(nn.Module): + def __init__(self, in_features, out_features1, out_features2, sigma=True): + super(SplitCosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features1 + out_features2 + self.fc1 = CosineLinear(in_features, out_features1, False) + self.fc2 = CosineLinear(in_features, out_features2, False) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + self.sigma.data.fill_(1) + else: + self.register_parameter('sigma', None) + + def forward(self, x): + out1 = self.fc1(x) + out2 = self.fc2(x) + + out = torch.cat((out1['logits'], out2['logits']), dim=1) # concatenate along the channel + if self.sigma is not None: + out = self.sigma * out + + return { + 'old_scores': out1['logits'], + 'new_scores': out2['logits'], + 'logits': out + } +''' diff --git a/convs/memo_cifar_resnet.py b/convs/memo_cifar_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3d585519f60bd2d421d277aed4affcb16e9d4af5 --- /dev/null +++ b/convs/memo_cifar_resnet.py @@ -0,0 +1,164 @@ +''' +For MEMO implementations of CIFAR-ResNet +Reference: +https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py +''' +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class DownsampleA(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleA, self).__init__() + assert stride == 2 + self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) + + def forward(self, x): + x = self.avg(x) + return torch.cat((x, x.mul(0)), 1) + +class ResNetBasicblock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(ResNetBasicblock, self).__init__() + + self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn_a = nn.BatchNorm2d(planes) + + self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_b = nn.BatchNorm2d(planes) + + self.downsample = downsample + + def forward(self, x): + residual = x + + basicblock = self.conv_a(x) + basicblock = self.bn_a(basicblock) + basicblock = F.relu(basicblock, inplace=True) + + basicblock = self.conv_b(basicblock) + basicblock = self.bn_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(x) + + return F.relu(residual + basicblock, inplace=True) + + + +class GeneralizedResNet_cifar(nn.Module): + def __init__(self, block, depth, channels=3): + super(GeneralizedResNet_cifar, self).__init__() + assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' + layer_blocks = (depth - 2) // 6 + self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_1 = nn.BatchNorm2d(16) + + self.inplanes = 16 + self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) + self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) + + self.out_dim = 64 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + # m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv_1_3x3(x) # [bs, 16, 32, 32] + x = F.relu(self.bn_1(x), inplace=True) + + x_1 = self.stage_1(x) # [bs, 16, 32, 32] + x_2 = self.stage_2(x_1) # [bs, 32, 16, 16] + return x_2 + +class SpecializedResNet_cifar(nn.Module): + def __init__(self, block, depth, inplanes=32, feature_dim=64): + super(SpecializedResNet_cifar, self).__init__() + self.inplanes = inplanes + self.feature_dim = feature_dim + layer_blocks = (depth - 2) // 6 + self.final_stage = self._make_layer(block, 64, layer_blocks, 2) + self.avgpool = nn.AvgPool2d(8) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + # m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=2): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, base_feature_map): + final_feature_map = self.final_stage(base_feature_map) + pooled = self.avgpool(final_feature_map) + features = pooled.view(pooled.size(0), -1) #bs x 64 + return features + +#For cifar & MEMO +def get_resnet8_a2fc(): + basenet = GeneralizedResNet_cifar(ResNetBasicblock,8) + adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,8) + return basenet,adaptivenet + +def get_resnet14_a2fc(): + basenet = GeneralizedResNet_cifar(ResNetBasicblock,14) + adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,14) + return basenet,adaptivenet + +def get_resnet20_a2fc(): + basenet = GeneralizedResNet_cifar(ResNetBasicblock,20) + adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,20) + return basenet,adaptivenet + +def get_resnet26_a2fc(): + basenet = GeneralizedResNet_cifar(ResNetBasicblock,26) + adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,26) + return basenet,adaptivenet + +def get_resnet32_a2fc(): + basenet = GeneralizedResNet_cifar(ResNetBasicblock,32) + adaptivenet = SpecializedResNet_cifar(ResNetBasicblock,32) + return basenet,adaptivenet + + diff --git a/convs/memo_resnet.py b/convs/memo_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..507b0bd60e35528b93a5820af373b1a1a06d2aff --- /dev/null +++ b/convs/memo_resnet.py @@ -0,0 +1,322 @@ +''' +For MEMO implementations of ImageNet-ResNet +Reference: +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +''' +import torch +import torch.nn as nn 
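+# Note: newer torchvision releases removed torchvision.models.utils, so the
+# import below falls back to torch.hub, where load_state_dict_from_url now lives.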
+try: + from torchvision.models.utils import load_state_dict_from_url +except: + from torch.hub import load_state_dict_from_url + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = 
self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class GeneralizedResNet_imagenet(nn.Module): + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(GeneralizedResNet_imagenet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, # stride=2 -> stride=1 for cifar + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Removed in _forward_impl for cifar + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.out_dim = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + return nn.Sequential(*layers) + def _forward_impl(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x_1 = self.layer1(x) + x_2 = self.layer2(x_1) + x_3 = self.layer3(x_2) + return x_3 + + def forward(self, x): + return self._forward_impl(x) + +class SpecializedResNet_imagenet(nn.Module): + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(SpecializedResNet_imagenet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.feature_dim = 512 * block.expansion + self.inplanes = 256 * block.expansion + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.out_dim = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self,x): + x_4 = self.layer4(x) # [bs, 512, 4, 4] + pooled = self.avgpool(x_4) # [bs, 512, 1, 1] + features = torch.flatten(pooled, 1) # [bs, 512] + return 
features + +def get_resnet10_imagenet(): + basenet = GeneralizedResNet_imagenet(BasicBlock,[1, 1, 1, 1]) + adaptivenet = SpecializedResNet_imagenet(BasicBlock, [1, 1, 1, 1]) + return basenet,adaptivenet + +def get_resnet18_imagenet(): + basenet = GeneralizedResNet_imagenet(BasicBlock,[2, 2, 2, 2]) + adaptivenet = SpecializedResNet_imagenet(BasicBlock, [2, 2, 2, 2]) + return basenet,adaptivenet + +def get_resnet26_imagenet(): + basenet = GeneralizedResNet_imagenet(Bottleneck,[2, 2, 2, 2]) + adaptivenet = SpecializedResNet_imagenet(Bottleneck, [2, 2, 2, 2]) + return basenet,adaptivenet + +def get_resnet34_imagenet(): + basenet = GeneralizedResNet_imagenet(BasicBlock,[3, 4, 6, 3]) + adaptivenet = SpecializedResNet_imagenet(BasicBlock, [3, 4, 6, 3]) + return basenet,adaptivenet + +def get_resnet50_imagenet(): + basenet = GeneralizedResNet_imagenet(Bottleneck,[3, 4, 6, 3]) + adaptivenet = SpecializedResNet_imagenet(Bottleneck, [3, 4, 6, 3]) + return basenet,adaptivenet + + +if __name__ == '__main__': + model2imagenet = 3*224*224 + + a, b = get_resnet10_imagenet() + _base = sum(p.numel() for p in a.parameters()) + _adap = sum(p.numel() for p in b.parameters()) + print(f"resnet10 #params:{_base+_adap}") + + a, b = get_resnet18_imagenet() + _base = sum(p.numel() for p in a.parameters()) + _adap = sum(p.numel() for p in b.parameters()) + print(f"resnet18 #params:{_base+_adap}") + + a, b = get_resnet26_imagenet() + _base = sum(p.numel() for p in a.parameters()) + _adap = sum(p.numel() for p in b.parameters()) + print(f"resnet26 #params:{_base+_adap}") + + a, b = get_resnet34_imagenet() + _base = sum(p.numel() for p in a.parameters()) + _adap = sum(p.numel() for p in b.parameters()) + print(f"resnet34 #params:{_base+_adap}") + + a, b = get_resnet50_imagenet() + _base = sum(p.numel() for p in a.parameters()) + _adap = sum(p.numel() for p in b.parameters()) + print(f"resnet50 #params:{_base+_adap}") \ No newline at end of file diff --git a/convs/modified_represnet.py b/convs/modified_represnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b451cadcd15b25420dabfc6d662d6f440879c533 --- /dev/null +++ b/convs/modified_represnet.py @@ -0,0 +1,177 @@ +import torch +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo +import torch.nn.functional as F + +__all__ = ['ResNet', 'resnet18_rep', 'resnet34_rep' ] + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=True) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True) + +class conv_block(nn.Module): + + def __init__(self, in_planes, planes, mode, stride=1): + super(conv_block, self).__init__() + self.conv = conv3x3(in_planes, planes, stride) + self.mode = mode + if mode == 'parallel_adapters': + self.adapter = conv1x1(in_planes, planes, stride) + + + def re_init_conv(self): + nn.init.kaiming_normal_(self.adapter.weight, mode='fan_out', nonlinearity='relu') + return + def forward(self, x): + y = self.conv(x) + if self.mode == 'parallel_adapters': + y = y + self.adapter(x) + + return y + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, mode, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv_block(inplanes, planes, mode, stride) + self.norm1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = 
conv_block(planes, planes, mode) + self.norm2 = nn.BatchNorm2d(planes) + self.mode = mode + + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.norm2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=100, args = None): + self.inplanes = 64 + super(ResNet, self).__init__() + assert args is not None + self.mode = args["mode"] + + if 'cifar' in args["dataset"]: + self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True)) + print("use cifar") + elif 'imagenet' in args["dataset"] or 'stanfordcar' in args["dataset"]: + if args["init_cls"] == args["increment"]: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + else: + # Following PODNET implmentation + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.feature = nn.AvgPool2d(4, stride=1) + self.out_dim = 512 + + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=True), + ) + layers = [] + layers.append(block(self.inplanes, planes, self.mode, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, self.mode)) + + return nn.Sequential(*layers) + + def switch(self, mode='normal'): + for name, module in self.named_modules(): + if hasattr(module, 'mode'): + module.mode = mode + def re_init_params(self): + for name, module in self.named_modules(): + if hasattr(module, 're_init_conv'): + module.re_init_conv() + def forward(self, x): + x = self.conv1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + dim = x.size()[-1] + pool = nn.AvgPool2d(dim, stride=1) + x = pool(x) + x = x.view(x.size(0), -1) + return {"features": x} + + +def resnet18_rep(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet18']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + +def resnet34_rep(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet34']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model \ No newline at end of file diff --git a/convs/resnet.py b/convs/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4d205be57c859f3db2843eebbbecb4d8328e7bd1 --- /dev/null +++ b/convs/resnet.py @@ -0,0 +1,395 @@ +''' +Reference: +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +''' +import torch +import torch.nn as nn +try: + from torchvision.models.utils import load_state_dict_from_url +except: + from torch.hub import load_state_dict_from_url + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = 
self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + + + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None,args=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + + assert args is not None, "you should pass args to resnet" + if 'cifar' in args["dataset"]: + if args["model_name"] == "memo": + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True)) + elif 'imagenet' in args["dataset"] or 'stanfordcar' in args['dataset'] or 'general_dataset' in args['dataset']: + if args["init_cls"] == args["increment"]: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, 
layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.out_dim = 512 * block.expansion + # self.fc = nn.Linear(512 * block.expansion, num_classes) # Removed in _forward_impl + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) # [bs, 64, 32, 32] + + x_1 = self.layer1(x) # [bs, 128, 32, 32] + x_2 = self.layer2(x_1) # [bs, 256, 16, 16] + x_3 = self.layer3(x_2) # [bs, 512, 8, 8] + x_4 = self.layer4(x_3) # [bs, 512, 4, 4] + + pooled = self.avgpool(x_4) # [bs, 512, 1, 1] + features = torch.flatten(pooled, 1) # [bs, 512] + # x = self.fc(x) + + return { + 'fmaps': [x_1, x_2, x_3, x_4], + 'features': features + } + + def forward(self, x): + return self._forward_impl(x) + + @property + def last_conv(self): + if hasattr(self.layer4[-1], 'conv3'): + return self.layer4[-1].conv3 + else: + return self.layer4[-1].conv2 + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + +def resnet10(pretrained=False, progress=True, **kwargs): + """ + For MEMO implementations of ResNet-10 + """ + return _resnet('resnet10', BasicBlock, [1, 1, 1, 1], pretrained, progress, + **kwargs) + +def resnet26(pretrained=False, progress=True, **kwargs): + """ + For MEMO implementations of ResNet-26 + """ + return _resnet('resnet26', Bottleneck, [2, 2, 2, 2], pretrained, progress, + **kwargs) + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, 
progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. 
The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
diff --git a/convs/resnet_cbam.py b/convs/resnet_cbam.py
new file mode 100644
index 0000000000000000000000000000000000000000..240c430fb6b103cc9885f479bafaaad18f691cd1
--- /dev/null
+++ b/convs/resnet_cbam.py
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam',
+           'resnet152_cbam']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    "3x3 convolution with padding"
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class ChannelAttention(nn.Module):
+    def __init__(self, in_planes, ratio=16):
+        super(ChannelAttention, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+        # Use the `ratio` argument instead of a hard-coded 16 so the reduction
+        # factor honors the constructor parameter (identical behavior for the
+        # default ratio=16 used throughout this file).
+        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
+        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+        out = avg_out + max_out
+        return self.sigmoid(out)
+
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+
+        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+        padding = 3 if kernel_size == 7 else 1
+
+        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avg_out, max_out], dim=1)
+        x = self.conv1(x)
+        return self.sigmoid(x)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        # NOTE (editor): these attention modules are instantiated but
+        # BasicBlock.forward never applies them; only Bottleneck.forward below
+        # uses CBAM. Left as-is to preserve trained-checkpoint behavior.
+        self.ca = ChannelAttention(planes)
+        self.sa = SpatialAttention()
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck,
self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.ca = ChannelAttention(planes * 4) + self.sa = SpatialAttention() + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + out = self.ca(out) * out + out = self.sa(out) * out + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=100, args=None): + self.inplanes = 64 + super(ResNet, self).__init__() + assert args is not None, "you should pass args to resnet" + if 'cifar' in args["dataset"]: + self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True)) + elif 'imagenet' in args["dataset"] or 'stanfordcar' in args['dataset']: + if args["init_cls"] == args["increment"]: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.feature = nn.AvgPool2d(4, stride=1) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + self.out_dim = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + dim = x.size()[-1] + pool = nn.AvgPool2d(dim, stride=1) + x = pool(x) + x = x.view(x.size(0), -1) + return {"features": x} + +def resnet18_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. 
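+
+    Note (editor's sketch): the pretrained branch below merges the ImageNet
+    checkpoint into the current state dict with ``update``; keys that exist in
+    the checkpoint but not in this model (e.g. ``fc.*``, or ``conv1.*`` once
+    the stem is rebuilt as an nn.Sequential) can then make the strict load
+    fail. A shape-aware filter is a common defensive variant::
+
+        pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
+        now_state_dict = model.state_dict()
+        # Keep only keys that exist in this model with matching shapes.
+        filtered = {k: v for k, v in pretrained_state_dict.items()
+                    if k in now_state_dict and v.shape == now_state_dict[k].shape}
+        now_state_dict.update(filtered)
+        model.load_state_dict(now_state_dict)
+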
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet18']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + +def resnet34_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet34']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + +def resnet50_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet50']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + +def resnet101_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet101']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + +def resnet152_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. 
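+
+    Note (editor's sketch): a minimal shape check for the CBAM blocks defined
+    above; the tensor sizes are illustrative only.
+
+    Example::
+
+        x = torch.randn(2, 256, 8, 8)
+        ca, sa = ChannelAttention(256), SpatialAttention()
+        w_c = ca(x)          # [2, 256, 1, 1], per-channel gates in (0, 1)
+        w_s = sa(x)          # [2, 1, 8, 8],  per-position gates in (0, 1)
+        out = w_c * x        # channel attention first,
+        out = sa(out) * out  # then spatial attention, as in Bottleneck.forward
+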
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet152']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model \ No newline at end of file diff --git a/convs/ucir_cifar_resnet.py b/convs/ucir_cifar_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9e71b742f56925d3a717228a8450c43e59dca39f --- /dev/null +++ b/convs/ucir_cifar_resnet.py @@ -0,0 +1,204 @@ +''' +Reference: +https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py +https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_resnet_cifar.py +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +# from convs.modified_linear import CosineLinear + + +class DownsampleA(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleA, self).__init__() + assert stride == 2 + self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) + + def forward(self, x): + x = self.avg(x) + return torch.cat((x, x.mul(0)), 1) + + +class DownsampleB(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleB, self).__init__() + self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) + self.bn = nn.BatchNorm2d(nOut) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class DownsampleC(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleC, self).__init__() + assert stride != 1 or nIn != nOut + self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) + + def forward(self, x): + x = self.conv(x) + return x + + +class DownsampleD(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleD, self).__init__() + assert stride == 2 + self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False) + self.bn = nn.BatchNorm2d(nOut) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class ResNetBasicblock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, last=False): + super(ResNetBasicblock, self).__init__() + + self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn_a = nn.BatchNorm2d(planes) + + self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_b = nn.BatchNorm2d(planes) + + self.downsample = downsample + self.last = last + + def forward(self, x): + residual = x + + basicblock = self.conv_a(x) + basicblock = self.bn_a(basicblock) + basicblock = F.relu(basicblock, inplace=True) + + basicblock = self.conv_b(basicblock) + basicblock = self.bn_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(x) + + out = residual + basicblock + if not self.last: + out = F.relu(out, inplace=True) + + return out + + +class CifarResNet(nn.Module): + """ + ResNet optimized for the Cifar Dataset, as specified in + https://arxiv.org/abs/1512.03385.pdf + """ + + def __init__(self, block, depth, channels=3): + super(CifarResNet, self).__init__() + + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' + layer_blocks = (depth - 2) // 6 + + self.conv_1_3x3 = 
nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_1 = nn.BatchNorm2d(16) + + self.inplanes = 16 + self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) + self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) + self.stage_3 = self._make_layer(block, 64, layer_blocks, 2, last_phase=True) + self.avgpool = nn.AvgPool2d(8) + self.out_dim = 64 * block.expansion + # self.fc = CosineLinear(64*block.expansion, 10) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, last_phase=False): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = DownsampleB(self.inplanes, planes * block.expansion, stride) # DownsampleA => DownsampleB + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + if last_phase: + for i in range(1, blocks-1): + layers.append(block(self.inplanes, planes)) + layers.append(block(self.inplanes, planes, last=True)) + else: + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv_1_3x3(x) # [bs, 16, 32, 32] + x = F.relu(self.bn_1(x), inplace=True) + + x_1 = self.stage_1(x) # [bs, 16, 32, 32] + x_2 = self.stage_2(x_1) # [bs, 32, 16, 16] + x_3 = self.stage_3(x_2) # [bs, 64, 8, 8] + + pooled = self.avgpool(x_3) # [bs, 64, 1, 1] + features = pooled.view(pooled.size(0), -1) # [bs, 64] + # out = self.fc(vector) + + return { + 'fmaps': [x_1, x_2, x_3], + 'features': features + } + + @property + def last_conv(self): + return self.stage_3[-1].conv_b + + +def resnet20mnist(): + """Constructs a ResNet-20 model for MNIST.""" + model = CifarResNet(ResNetBasicblock, 20, 1) + return model + + +def resnet32mnist(): + """Constructs a ResNet-32 model for MNIST.""" + model = CifarResNet(ResNetBasicblock, 32, 1) + return model + + +def resnet20(): + """Constructs a ResNet-20 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 20) + return model + + +def resnet32(): + """Constructs a ResNet-32 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 32) + return model + + +def resnet44(): + """Constructs a ResNet-44 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 44) + return model + + +def resnet56(): + """Constructs a ResNet-56 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 56) + return model + + +def resnet110(): + """Constructs a ResNet-110 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 110) + return model diff --git a/convs/ucir_resnet.py b/convs/ucir_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..73b4dbb32e6ad4a5b7b8ff16c47efa9f7ea31740 --- /dev/null +++ b/convs/ucir_resnet.py @@ -0,0 +1,299 @@ +''' +Reference: +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +''' +import torch +import torch.nn as nn +try: + from torchvision.models.utils import load_state_dict_from_url +except: + from torch.hub import load_state_dict_from_url + +__all__ = ['resnet50'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 
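+    # Editor's note: these are the stock torchvision checkpoints. The stem
+    # below is rebuilt as an nn.Sequential (keys like "conv1.0.weight"), so a
+    # strict load of these URLs will likely fail key matching; treat
+    # pretrained=True as unsupported here unless the keys are remapped.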
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None, last=False): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + self.last = last + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + if not self.last: + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None, last=False): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.last = last + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + if not self.last: + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, args=None): + super(ResNet, self).__init__() + if norm_layer is 
None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + + assert args is not None, "you should pass args to resnet" + if 'cifar' in args["dataset"]: + self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True)) + elif 'imagenet' in args["dataset"] or 'stanfordcar' in args["dataset"] or 'general_dataset' in args['dataset']: + if args["init_cls"] == args["increment"]: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2], last_phase=True) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.out_dim = 512 * block.expansion + self.fc = nn.Linear(512 * block.expansion, num_classes) # Removed in _forward_impl + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
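+        # (Editor's note: concretely, constant_(bn.weight, 0) zeroes the BN
+        # scale gamma, so at initialization the residual branch outputs zero
+        # and the block's forward reduces to relu(identity).)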
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False, last_phase=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + if last_phase: + for _ in range(1, blocks-1): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer, last=True)) + else: + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) # [bs, 64, 32, 32] + + x_1 = self.layer1(x) # [bs, 128, 32, 32] + x_2 = self.layer2(x_1) # [bs, 256, 16, 16] + x_3 = self.layer3(x_2) # [bs, 512, 8, 8] + x_4 = self.layer4(x_3) # [bs, 512, 4, 4] + + pooled = self.avgpool(x_4) # [bs, 512, 1, 1] + features = torch.flatten(pooled, 1) # [bs, 512] + # x = self.fc(x) + + return { + 'fmaps': [x_1, x_2, x_3, x_4], + 'features': features + } + + def forward(self, x): + return self._forward_impl(x) + + @property + def last_conv(self): + if hasattr(self.layer4[-1], 'conv3'): + return self.layer4[-1].conv3 + else: + return self.layer4[-1].conv2 + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, 
progress, + **kwargs) diff --git a/download_dataset.sh b/download_dataset.sh new file mode 100644 index 0000000000000000000000000000000000000000..be8d19530c2f30e2a49872edc496e488c5c978a1 --- /dev/null +++ b/download_dataset.sh @@ -0,0 +1,8 @@ +#!/bin/sh +kaggle datasets download -d senemanu/stanfordcarsfcs + +unzip -qq stanfordcarsfcs.zip + +rm -rf ./car_data/car_data/train/models + +mv ./car_data/car_data/test ./car_data/car_data/val diff --git a/download_file_from_s3.py b/download_file_from_s3.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c1ce923aa1d1d18f2ef4706ac7f92ed751c34e --- /dev/null +++ b/download_file_from_s3.py @@ -0,0 +1,49 @@ +import os +import boto3 +from botocore.exceptions import NoCredentialsError + + +def download_from_s3(bucket_name, s3_key, local_path, is_directory=False): + """ + Download a file or directory from S3 to a local path. + + :param bucket_name: str. The name of the S3 bucket. + :param s3_key: str. The S3 key (path to the file or directory). + :param local_path: str. The local file path or directory to download to. + :param is_directory: bool. Set to True if s3_key is a directory. + """ + s3 = boto3.client("s3") + + if is_directory: + # Ensure the local directory exists + if not os.path.exists(local_path): + os.makedirs(local_path) + + # List all objects in the specified S3 directory + result = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_key) + print(result) + + if "Contents" in result: + for obj in result["Contents"]: + s3_object_key = obj["Key"] + # Remove the directory prefix to get the relative file path + relative_path = os.path.relpath(s3_object_key, s3_key) + local_file_path = os.path.join(local_path, relative_path) + + # Ensure the local directory for the file exists + local_file_dir = os.path.dirname(local_file_path) + if not os.path.exists(local_file_dir): + os.makedirs(local_file_dir) + + # Download the file + s3.download_file(bucket_name, s3_object_key, local_file_path) + print(f"Downloaded {s3_object_key} to {local_file_path}") + else: + # Download a single file + print(f"Downloaded {s3_key} to {local_path}") + s3.download_file(bucket_name, s3_key, local_path) + + +# Example usage: +# download_from_s3('my-bucket', 'path/to/myfile.txt', 'local/path/to/myfile.txt') +# download_from_s3('my-bucket', 'path/to/mydirectory/', 'local/path/to/mydirectory', is_directory=True) diff --git a/download_s3_path.py b/download_s3_path.py new file mode 100644 index 0000000000000000000000000000000000000000..e103bf7a1cb6a92172027e5135e79eb0b08b819a --- /dev/null +++ b/download_s3_path.py @@ -0,0 +1,58 @@ +import os +import boto3 +from botocore.exceptions import NoCredentialsError, PartialCredentialsError + +def download_s3_folder(bucket_name, s3_folder, local_dir): + # Convert local_dir to an absolute path + local_dir = os.path.abspath(local_dir) + + # Ensure local directory exists + if not os.path.exists(local_dir): + os.makedirs(local_dir, exist_ok=True) + + s3 = boto3.client('s3') + + try: + # List objects within the specified folder + objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder) + if 'Contents' not in objects: + print(f"The folder '{s3_folder}' does not contain any files.") + return + + for obj in objects['Contents']: + # Formulate the local file path + s3_file_path = obj['Key'] + if s3_file_path.endswith('/'): + # Skip directories + continue + + local_file_path = os.path.join(local_dir, os.path.relpath(s3_file_path, s3_folder)) + + # Create local directories if they do not exist + 
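+                # (Editor's note: local_dir is made absolute above, so the
+                # dirname() below is never empty; exist_ok=True keeps repeated
+                # downloads into the same folder from raising.)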
os.makedirs(os.path.dirname(local_file_path), exist_ok=True) + + # Download the file + s3.download_file(bucket_name, s3_file_path, local_file_path) + print(f'Downloaded {s3_file_path} to {local_file_path}') + + except KeyError: + print(f"The folder '{s3_folder}' does not contain any files.") + except NoCredentialsError: + print("Credentials not available.") + except PartialCredentialsError: + print("Incomplete credentials provided.") + except PermissionError as e: + print(f"Permission error: {e}. Please check your directory permissions.") + except Exception as e: + print(f"An error occurred: {e}") + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='Download an S3 folder to a local directory.') + parser.add_argument('-bucket', type=str, required=True, help='The S3 bucket name.') + parser.add_argument('-s3_folder', type=str, required=True, help='The folder path within the S3 bucket.') + parser.add_argument('-local_dir', type=str, required=True, help='The local directory to download the files to.') + args = parser.parse_args() + + download_s3_folder(args.bucket, args.s3_folder, args.local_dir) diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..4fecceffa04b4d5f7024dd628499eac7fdc56875 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,8 @@ +#!/bin/sh +set -e + +chmod +x train.sh install_awscli.sh + +mkdir upload + +python server.py diff --git a/eval.py b/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..a162c70e147cc4ae3d79dc579d69e46de14deb82 --- /dev/null +++ b/eval.py @@ -0,0 +1,133 @@ +import sys +import logging +import copy +import torch +from PIL import Image +import torchvision.transforms as transforms +from utils import factory +from utils.data_manager import DataManager +from torch.utils.data import DataLoader +from utils.toolkit import count_parameters +import os +import numpy as np +import json +import argparse +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') +def _set_device(args): + device_type = args["device"] + gpus = [] + + for device in device_type: + if device == -1: + device = torch.device("cpu") + else: + device = torch.device("cuda:{}".format(device)) + + gpus.append(device) + + args["device"] = gpus + +def get_methods(object, spacing=20): + methodList = [] + for method_name in dir(object): + try: + if callable(getattr(object, method_name)): + methodList.append(str(method_name)) + except Exception: + methodList.append(str(method_name)) + processFunc = (lambda s: ' '.join(s.split())) or (lambda s: s) + for method in methodList: + try: + print(str(method.ljust(spacing)) + ' ' + + processFunc(str(getattr(object, method).__doc__)[0:90])) + except Exception: + print(method.ljust(spacing) + ' ' + ' getattr() failed') + +def load_model(args): + _set_device(args) + model = factory.get_model(args["model_name"], args) + model.load_checkpoint(args["checkpoint"]) + return model + +def evaluate(args): + logs_name = "logs/{}/{}_{}/{}/{}".format(args["model_name"],args["dataset"], args['data'], args['init_cls'], args['increment']) + + if not os.path.exists(logs_name): + os.makedirs(logs_name) + logfilename = "logs/{}/{}_{}/{}/{}/{}_{}_{}".format( + args["model_name"], + args["dataset"], + args['data'], + args['init_cls'], + args["increment"], + args["prefix"], + args["seed"], + args["convnet_type"], + ) + if not os.path.exists(logs_name): + os.makedirs(logs_name) + args['logfilename'] = logs_name + args['csv_name'] = 
"{}_{}_{}".format( + args["prefix"], + args["seed"], + args["convnet_type"], + ) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(filename)s] => %(message)s", + handlers=[ + logging.FileHandler(filename=logfilename + ".log"), + logging.StreamHandler(sys.stdout), + ], + ) + _set_random() + print_args(args) + model = load_model(args) + + data_manager = DataManager( + args["dataset"], + False, + args["seed"], + args["init_cls"], + args["increment"], + path = args["data"] + ) + loader = DataLoader(data_manager.get_dataset(model.class_list, source = "test", mode = "test"), batch_size=args['batch_size'], shuffle=True, num_workers=8) + + cnn_acc, nme_acc = model.eval_task(loader, group = 1, mode = "test") + print(cnn_acc, nme_acc) +def main(): + args = setup_parser().parse_args() + param = load_json(args.config) + args = vars(args) # Converting argparse Namespace to a dict. + args.update(param) # Add parameters from json + evaluate(args) + +def load_json(settings_path): + with open(settings_path) as data_file: + param = json.load(data_file) + + return param + +def _set_random(): + torch.manual_seed(1) + torch.cuda.manual_seed(1) + torch.cuda.manual_seed_all(1) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def setup_parser(): + parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorthms.') + parser.add_argument('--config', type=str, default='./exps/finetune.json', + help='Json file of settings.') + parser.add_argument('-d','--data', type=str, help='Path of the data folder') + parser.add_argument('-c','--checkpoint', type=str, help='Path of checkpoint file if resume training') + return parser + +def print_args(args): + for key, value in args.items(): + logging.info("{}: {}".format(key, value)) +if __name__ == '__main__': + main() + diff --git a/exps/beef.json b/exps/beef.json new file mode 100644 index 0000000000000000000000000000000000000000..c28195c56e11dc54cbb6a4832af30faeb738c5cf --- /dev/null +++ b/exps/beef.json @@ -0,0 +1,28 @@ +{ + "prefix": "fusion-energy-0.01-1.7-fixed", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "beefiso", + "convnet_type": "resnet18", + "device": ["0", "1"], + "seed": [2003], + "logits_alignment": 1.7, + "energy_weight": 0.01, + "is_compress":false, + "reduce_batch_size": false, + "init_epochs": 1, + "init_lr" : 0.1, + "init_weight_decay" : 5e-4, + "expansion_epochs" : 1, + "fusion_epochs" : 1, + "lr" : 0.1, + "batch_size" : 32, + "weight_decay" : 5e-4, + "num_workers" : 8, + "T" : 2 +} \ No newline at end of file diff --git a/exps/bic.json b/exps/bic.json new file mode 100644 index 0000000000000000000000000000000000000000..6510ef908d06c6d55932ebb68e04da33296ff3d5 --- /dev/null +++ b/exps/bic.json @@ -0,0 +1,14 @@ +{ + "prefix": "reproduce", + "dataset": "cifar100", + "memory_size": 2000, + "memory_per_class": 20, + "fixed_memory": false, + "shuffle": true, + "init_cls": 10, + "increment": 10, + "model_name": "bic", + "convnet_type": "resnet32", + "device": ["0","1","2","3"], + "seed": [1993] +} diff --git a/exps/coil.json b/exps/coil.json new file mode 100644 index 0000000000000000000000000000000000000000..98a3d048bda89a74dcd09ddacbe6622f46c7b4fe --- /dev/null +++ b/exps/coil.json @@ -0,0 +1,18 @@ +{ + "prefix": "reproduce", + "dataset": "stanfordcar", + "memory_size": 2000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, 
+ "init_cls": 20, + "increment": 20, + "sinkhorn":0.464, + "calibration_term":1.5, + "norm_term":3.0, + "reg_term":1e-3, + "model_name": "coil", + "convnet_type": "cosine_resnet18", + "device": ["0","1"], + "seed": [2003] +} diff --git a/exps/der.json b/exps/der.json new file mode 100644 index 0000000000000000000000000000000000000000..a3ade0b06b91e7970c5e94200f1cbf8ad5685f7b --- /dev/null +++ b/exps/der.json @@ -0,0 +1,14 @@ +{ + "prefix": "reproduce", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "der", + "convnet_type": "resnet18", + "device": ["0","1"], + "seed": [1993] +} \ No newline at end of file diff --git a/exps/ewc.json b/exps/ewc.json new file mode 100644 index 0000000000000000000000000000000000000000..76d865417f23b15ac62b191c49ca5b6edd0a765f --- /dev/null +++ b/exps/ewc.json @@ -0,0 +1,14 @@ +{ + "prefix": "reproduce", + "dataset": "cifar100", + "memory_size": 2000, + "memory_per_class": 20, + "fixed_memory": false, + "shuffle": true, + "init_cls": 10, + "increment": 10, + "model_name": "ewc", + "convnet_type": "resnet32", + "device": ["0","1","2","3"], + "seed": [1993] +} \ No newline at end of file diff --git a/exps/fetril.json b/exps/fetril.json new file mode 100644 index 0000000000000000000000000000000000000000..b922528da2dd0afb85364acb036fcd9acc697b5d --- /dev/null +++ b/exps/fetril.json @@ -0,0 +1,21 @@ +{ + "prefix": "train", + "dataset": "stanfordcar", + "memory_size": 0, + "shuffle": true, + "init_cls": 40, + "increment": 1, + "model_name": "fetril", + "convnet_type": "resnet18", + "device": ["0"], + "seed": [2003], + "init_epochs": 100, + "init_lr" : 0.1, + "init_weight_decay" : 5e-4, + "epochs" : 80, + "lr" : 0.05, + "batch_size" : 32, + "weight_decay" : 5e-4, + "num_workers" : 8, + "T" : 2 +} \ No newline at end of file diff --git a/exps/finetune.json b/exps/finetune.json new file mode 100644 index 0000000000000000000000000000000000000000..f0c5a5cdc225009dc5dbe818aa3952e9cd519dad --- /dev/null +++ b/exps/finetune.json @@ -0,0 +1,14 @@ +{ + "prefix": "reproduce", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "finetune", + "convnet_type": "resnet18", + "device": ["0"], + "seed": [2003] +} \ No newline at end of file diff --git a/exps/foster.json b/exps/foster.json new file mode 100644 index 0000000000000000000000000000000000000000..dfbcb38ce4cca206ee0fabe68068dd461d91ed83 --- /dev/null +++ b/exps/foster.json @@ -0,0 +1,31 @@ +{ + "prefix": "cil", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "foster", + "convnet_type": "resnet18", + "device": ["0"], + "seed": [2003], + "beta1":0.96, + "beta2":0.97, + "oofc":"ft", + "is_teacher_wa":false, + "is_student_wa":false, + "lambda_okd":1, + "wa_value":1, + "init_epochs": 100, + "init_lr" : 0.1, + "init_weight_decay" : 5e-4, + "boosting_epochs" : 80, + "compression_epochs" : 50, + "lr" : 0.1, + "batch_size" : 32, + "weight_decay" : 5e-4, + "num_workers" : 8, + "T" : 2 +} \ No newline at end of file diff --git a/exps/foster_general.json b/exps/foster_general.json new file mode 100644 index 0000000000000000000000000000000000000000..47eebb6f73bc1c62ccef639c7a43bf4da0828595 --- /dev/null +++ b/exps/foster_general.json @@ -0,0 +1,31 @@ +{ + "prefix": "cil", + 
"dataset": "general_dataset", + "memory_size": 4000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "foster", + "convnet_type": "resnet18", + "device": ["0"], + "seed": [2003], + "beta1":0.96, + "beta2":0.97, + "oofc":"ft", + "is_teacher_wa":false, + "is_student_wa":false, + "lambda_okd":1, + "wa_value":1, + "init_epochs": 100, + "init_lr" : 0.1, + "init_weight_decay" : 5e-4, + "boosting_epochs" : 80, + "compression_epochs" : 50, + "lr" : 0.1, + "batch_size" : 32, + "weight_decay" : 5e-4, + "num_workers" : 8, + "T" : 2 +} \ No newline at end of file diff --git a/exps/gem.json b/exps/gem.json new file mode 100644 index 0000000000000000000000000000000000000000..ddec0a3f41761ff94b17ed6b27711cc8ab7bce07 --- /dev/null +++ b/exps/gem.json @@ -0,0 +1,14 @@ +{ + "prefix": "reproduce", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "gem", + "convnet_type": "resnet18", + "device": [ "0", "1"], + "seed": [2003] +} \ No newline at end of file diff --git a/exps/icarl.json b/exps/icarl.json new file mode 100644 index 0000000000000000000000000000000000000000..2129645841f9dc322646b4b3583475baf02b4853 --- /dev/null +++ b/exps/icarl.json @@ -0,0 +1,15 @@ +{ + "prefix": "reproduce", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "icarl", + "convnet_type": "resnet18", + "device": ["0"], + "seed": [2003] +} + diff --git a/exps/il2a.json b/exps/il2a.json new file mode 100644 index 0000000000000000000000000000000000000000..c644999c7a42b31abf2215ee614398fa89ad3f2a --- /dev/null +++ b/exps/il2a.json @@ -0,0 +1,24 @@ +{ + "prefix": "cil", + "dataset": "stanfordcar", + "memory_size": 0, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "il2a", + "convnet_type": "resnet18_cbam", + "device": ["0", "1"], + "seed": [2003], + "lambda_fkd":10, + "lambda_proto":10, + "temp":0.1, + "epochs" : 1, + "lr" : 0.001, + "batch_size" : 32, + "weight_decay" : 2e-4, + "step_size":45, + "gamma":0.1, + "num_workers" : 8, + "ratio": 2.5, + "T" : 2 +} \ No newline at end of file diff --git a/exps/lwf.json b/exps/lwf.json new file mode 100644 index 0000000000000000000000000000000000000000..0d55f4e0c0065b5e5345ccc98588211b65628b32 --- /dev/null +++ b/exps/lwf.json @@ -0,0 +1,14 @@ +{ + "prefix": "reproduce", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class": 10, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "lwf", + "convnet_type": "resnet18", + "device":["0", "1"], + "seed": [2003] +} diff --git a/exps/memo.json b/exps/memo.json new file mode 100644 index 0000000000000000000000000000000000000000..c8eef1668374c0eb8470151beb5b979efc193d9a --- /dev/null +++ b/exps/memo.json @@ -0,0 +1,33 @@ +{ + "prefix": "benchmark", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class":20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "memo", + "convnet_type": "memo_resnet18", + "train_base": true, + "train_adaptive": true, + "debug": false, + "skip": false, + "device": ["0", "1"], + "seed":[2003], + "scheduler": "steplr", + "init_epoch": 100, + "t_max": null, + "init_lr" : 0.1, + "init_weight_decay" : 5e-4, + "init_lr_decay" : 0.1, + "init_milestones" : [40,60,80], + 
"milestones" : [30,50,70], + "epochs": 80, + "lrate" : 0.1, + "batch_size" : 32, + "weight_decay" : 2e-4, + "lrate_decay" : 0.1, + "alpha_aux" : 1.0, + "backbone" : "models/finetune/reproduce_2003_resnet18_9.pkl" +} \ No newline at end of file diff --git a/exps/pass.json b/exps/pass.json new file mode 100644 index 0000000000000000000000000000000000000000..b82509b11635b23bf6aa86f5db0ee7c0c8ccec29 --- /dev/null +++ b/exps/pass.json @@ -0,0 +1,23 @@ +{ + "prefix": "train", + "dataset": "stanfordcar", + "memory_size": 0, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "pass", + "convnet_type": "resnet18_cbam", + "device": ["0"], + "seed": [2003], + "lambda_fkd":10, + "lambda_proto":10, + "temp":0.1, + "epochs" : 100, + "lr" : 0.001, + "batch_size" : 16, + "weight_decay" : 2e-4, + "step_size":45, + "gamma":0.1, + "num_workers" : 8, + "T" : 2 +} \ No newline at end of file diff --git a/exps/podnet.json b/exps/podnet.json new file mode 100644 index 0000000000000000000000000000000000000000..d33f8c3de36e474d1918b9bf5b0f7f6aefefe4ae --- /dev/null +++ b/exps/podnet.json @@ -0,0 +1,14 @@ +{ + "prefix": "increment", + "dataset": "stanfordcar", + "memory_size": 2000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "podnet", + "convnet_type": "cosine_resnet18", + "device": ["0","1"], + "seed": [2003] +} diff --git a/exps/replay.json b/exps/replay.json new file mode 100644 index 0000000000000000000000000000000000000000..11c8ce967040d3d4072bb821028af16f0ee973ac --- /dev/null +++ b/exps/replay.json @@ -0,0 +1,14 @@ +{ + "prefix": "reproduce", + "dataset": "stanfordcar", + "memory_size": 4000, + "memory_per_class": 20, + "fixed_memory": true, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "replay", + "convnet_type": "resnet18", + "device": ["0"], + "seed": [2003] +} \ No newline at end of file diff --git a/exps/rmm-foster.json b/exps/rmm-foster.json new file mode 100644 index 0000000000000000000000000000000000000000..671d4d48ed32517ee743b578e462d53215c5e5bd --- /dev/null +++ b/exps/rmm-foster.json @@ -0,0 +1,31 @@ +{ + "prefix": "rmm-foster", + "dataset": "stanfordcar", + "memory_size": 2000, + "m_rate_list":[0.3, 0.3, 0.3, 0.4, 0.4, 0.4], + "c_rate_list":[0.0, 0.0, 0.1, 0.1, 0.1, 0.0], + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "rmm-foster", + "convnet_type": "resnet18", + "device": ["0", "1"], + "seed": [2003], + "beta1":0.97, + "beta2":0.97, + "oofc":"ft", + "is_teacher_wa":false, + "is_student_wa":false, + "lambda_okd":1, + "wa_value":1, + "init_epochs": 1, + "init_lr" : 0.1, + "init_weight_decay" : 5e-4, + "boosting_epochs" : 1, + "compression_epochs" : 1, + "lr" : 0.1, + "batch_size" : 32, + "weight_decay" : 5e-4, + "num_workers" : 8, + "T" : 2 +} diff --git a/exps/rmm-icarl.json b/exps/rmm-icarl.json new file mode 100644 index 0000000000000000000000000000000000000000..d117fcbfc4d9d8d9c49e3a448947c5e94a6673a5 --- /dev/null +++ b/exps/rmm-icarl.json @@ -0,0 +1,15 @@ +{ + "prefix": "reproduce", + "dataset": "cifar100", + "m_rate_list":[0.8, 0.8, 0.6, 0.6, 0.6, 0.6], + "c_rate_list":[0.0, 0.0, 0.1, 0.1, 0.1, 0.0], + "memory_size": 2000, + "shuffle": true, + "init_cls": 50, + "increment": 10, + "model_name": "rmm-icarl", + "convnet_type": "resnet32", + "device": ["0"], + "seed": [1993] +} + diff --git a/exps/rmm-pretrain.json b/exps/rmm-pretrain.json new file mode 100644 index 
0000000000000000000000000000000000000000..14b92ceb178ec413b1810dc60627100c7693ad6a --- /dev/null +++ b/exps/rmm-pretrain.json @@ -0,0 +1,10 @@ +{ + "prefix": "pretrain-rmm", + "dataset": "cifar100", + "memory_size": 2000, + "shuffle": true, + "model_name": "rmm-icarl", + "convnet_type": "resnet32", + "device": ["0"], + "seed": [1993] +} diff --git a/exps/simplecil.json b/exps/simplecil.json new file mode 100644 index 0000000000000000000000000000000000000000..b037f432151deaf9e3eff0074d3546c3be80bb17 --- /dev/null +++ b/exps/simplecil.json @@ -0,0 +1,23 @@ +{ + "prefix": "simplecil", + "dataset": "stanfordcar", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 50, + "increment": 20, + "model_name": "simplecil", + "convnet_type": "cosine_resnet18", + "device": ["0"], + "seed": [2003], + "checkpoint": "./models/simplecil/stanfordcar/0/20/simplecil_0.pkl", + "init_epoch": 1, + "init_lr": 0.01, + "batch_size": 32, + "weight_decay": 0.05, + "init_lr_decay": 0.1, + "init_weight_decay": 5e-4, + "min_lr": 0 +} + diff --git a/exps/simplecil_general.json b/exps/simplecil_general.json new file mode 100644 index 0000000000000000000000000000000000000000..96acb01c8cd916e103623457deb112ef197ab4bc --- /dev/null +++ b/exps/simplecil_general.json @@ -0,0 +1,22 @@ +{ + "prefix": "simplecil", + "dataset": "general_dataset", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "simplecil", + "convnet_type": "cosine_resnet18", + "device": [-1], + "seed": [2003], + "init_epoch": 1, + "init_lr": 0.01, + "batch_size": 32, + "weight_decay": 0.05, + "init_lr_decay": 0.1, + "init_weight_decay": 5e-4, + "min_lr": 0 +} + diff --git a/exps/simplecil_resume.json b/exps/simplecil_resume.json new file mode 100644 index 0000000000000000000000000000000000000000..cd18739955d5b9366d860b10328e489417e4c6dd --- /dev/null +++ b/exps/simplecil_resume.json @@ -0,0 +1,24 @@ +{ + "prefix": "simplecil", + "dataset": "general_dataset", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 50, + "increment": 20, + "model_name": "simplecil", + "convnet_type": "cosine_resnet18", + "device": ["0"], + "seed": [2003], + "checkpoint": "./models/simplecil/stanfordcar/50/20/simplecil_0.pkl", + "data": "./car_data/car_data", + "init_epoch": 1, + "init_lr": 0.01, + "batch_size": 32, + "weight_decay": 0.05, + "init_lr_decay": 0.1, + "init_weight_decay": 5e-4, + "min_lr": 0 +} + diff --git a/exps/ssre.json b/exps/ssre.json new file mode 100644 index 0000000000000000000000000000000000000000..d9f6935cf7d8f4dcc23effbaf20b7bfee15772da --- /dev/null +++ b/exps/ssre.json @@ -0,0 +1,25 @@ +{ + "prefix": "ssre", + "dataset": "stanfordcar", + "memory_size": 0, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "ssre", + "convnet_type": "resnet18_rep", + "device": ["0"], + "seed": [2003], + "lambda_fkd":1, + "lambda_proto":10, + "temp":0.1, + "mode": "parallel_adapters", + "epochs" : 1, + "lr" : 0.0001, + "batch_size" : 32, + "weight_decay" : 5e-4, + "step_size":45, + "gamma":0.1, + "threshold": 0.8, + "num_workers" : 8, + "T" : 2 +} \ No newline at end of file diff --git a/exps/wa.json b/exps/wa.json new file mode 100644 index 0000000000000000000000000000000000000000..16e2ca86e9d7700a10a9738bd1dc7a1910fa0a3f --- /dev/null +++ b/exps/wa.json @@ -0,0 +1,14 @@ +{ + "prefix": "reproduce", + "dataset": "cifar100", + "memory_size": 2000, + 
"memory_per_class": 20, + "fixed_memory": false, + "shuffle": true, + "init_cls": 10, + "increment": 10, + "model_name": "wa", + "convnet_type": "resnet32", + "device": ["0","1","2","3"], + "seed": [1993] +} \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..207c4ebf35e1c2670d8c548732c2ebc5a7c46eb2 --- /dev/null +++ b/inference.py @@ -0,0 +1,115 @@ +import sys +import logging +import copy +import torch +from PIL import Image +import torchvision.transforms as transforms +from torchvision.transforms.functional import pil_to_tensor +from utils import factory +from utils.data_manager import DataManager +from utils.toolkit import count_parameters +from utils.data_manager import pil_loader +import os +import numpy as np +import json +import argparse +import imghdr +import time + +def is_image_imghdr(path): + """ + Checks if a path points to a valid image using imghdr. + + Args: + path: The path to the file. + + Returns: + True if the path is a valid image, False otherwise. + """ + if not os.path.isfile(path): + return False + return imghdr.what(path) in ['jpeg', 'png'] + +def _set_device(args): + device_type = args["device"] + gpus = [] + + for device in device_type: + if device == -1: + device = torch.device("cpu") + else: + device = torch.device("cuda:{}".format(device)) + + gpus.append(device) + + args["device"] = gpus + +def get_methods(object, spacing=20): + methodList = [] + for method_name in dir(object): + try: + if callable(getattr(object, method_name)): + methodList.append(str(method_name)) + except Exception: + methodList.append(str(method_name)) + processFunc = (lambda s: ' '.join(s.split())) or (lambda s: s) + for method in methodList: + try: + print(str(method.ljust(spacing)) + ' ' + + processFunc(str(getattr(object, method).__doc__)[0:90])) + except Exception: + print(method.ljust(spacing) + ' ' + ' getattr() failed') + +def load_model(args): + _set_device(args) + model = factory.get_model(args["model_name"], args) + model.load_checkpoint(args["checkpoint"]) + return model +def main(): + args = setup_parser().parse_args() + param = load_json(args.config) + args = vars(args) # Converting argparse Namespace to a dict. 
+ args.update(param) # Add parameters from json + assert args['output'].split(".")[-1] == "json" or os.path.isdir(args['output']) + model = load_model(args) + result = [] + if is_image_imghdr(args['input']): + img = pil_to_tensor(pil_loader(args['input'])) + img = img.unsqueeze(0) + predictions = model.inference(img) + out = {"img": args['input'].split("/")[-1]} + out.update({"predictions": [{"confident": confident, "index": pred, "label": label } for pred, label, confident in zip(predictions[0], predictions[1], predictions[2])]}) + result.append(out) + else: + image_list = filter(lambda x: is_image_imghdr(os.path.join(args['input'], x)), os.listdir(args['input'])) + for image in image_list: + print("Inference on image", image) + img = pil_to_tensor(pil_loader(os.path.join(args['input'], image))) + img = img.unsqueeze(0) + predictions = model.inference(img) + out = {"img": image.split("/")[-1]} + out.update({"predictions": [{"confident": confident, "index": pred, "label": label } for pred, label, confident in zip(predictions[0], predictions[1], predictions[2])]}) + result.append(out) + if args['output'].split(".")[-1] == "json": + with open(args['output'], "w+") as f: + json.dump(result, f, indent=4) + else: + with open(os.path.join(args['output'], "output_model_{}.json".format(time.time())), "w+") as f: + json.dump(result, f, indent=4) +def load_json(settings_path): + with open(settings_path) as data_file: + param = json.load(data_file) + return param + + +def setup_parser(): + parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.') + parser.add_argument('--config', type=str, help='Json file of settings.') + parser.add_argument('--checkpoint', type=str, help="Path to the checkpoint file (.pth format)") + parser.add_argument('--input', type=str, help="Path to the input. This can be a folder or a single image file") + parser.add_argument('--output', type=str, help="Output path to save the predictions") + return parser + +if __name__ == '__main__': + main() + diff --git a/install_awscli.sh b/install_awscli.sh new file mode 100644 index 0000000000000000000000000000000000000000..dc0acc69c82403826c436e30b1caea9258e2322f --- /dev/null +++ b/install_awscli.sh @@ -0,0 +1,7 @@ +#!/bin/sh + +curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" + +unzip awscliv2.zip + +./aws/install diff --git a/load.sh b/load.sh new file mode 100644 index 0000000000000000000000000000000000000000..429b50227fb951e7be17231d8d687ecfc08935d2 --- /dev/null +++ b/load.sh @@ -0,0 +1,5 @@ +#!
/bin/sh +for arg in $@; do + python ./load_model.py --config=$arg; + # Your commands to process each argument here +done \ No newline at end of file diff --git a/load_model.py b/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7d2b4e02b20ec76c1d47a7dfce9fad32ef6ce2aa --- /dev/null +++ b/load_model.py @@ -0,0 +1,73 @@ +import sys +import logging +import copy +import torch +from PIL import Image +import torchvision.transforms as transforms +from utils import factory +from utils.data_manager import DataManager +from utils.toolkit import count_parameters +import os +import numpy as np +import json +import argparse + +def _set_device(args): + device_type = args["device"] + gpus = [] + + for device in device_type: + if device == -1: + device = torch.device("cpu") + else: + device = torch.device("cuda:{}".format(device)) + + gpus.append(device) + + args["device"] = gpus + +def get_methods(object, spacing=20): + methodList = [] + for method_name in dir(object): + try: + if callable(getattr(object, method_name)): + methodList.append(str(method_name)) + except Exception: + methodList.append(str(method_name)) + processFunc = (lambda s: ' '.join(s.split())) or (lambda s: s) + for method in methodList: + try: + print(str(method.ljust(spacing)) + ' ' + + processFunc(str(getattr(object, method).__doc__)[0:90])) + except Exception: + print(method.ljust(spacing) + ' ' + ' getattr() failed') + +def load_model(args): + _set_device(args) + model = factory.get_model(args["model_name"], args) + model.load_checkpoint(args["checkpoint"]) + return model +def main(): + args = setup_parser().parse_args() + param = load_json(args.config) + args = vars(args) # Converting argparse Namespace to a dict. + args.update(param) # Add parameters from json + + load_model(args) +def load_json(settings_path): + with open(settings_path) as data_file: + param = json.load(data_file) + + return param + + +def setup_parser(): + parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorthms.') + parser.add_argument('--config', type=str, default='./exps/finetune.json', + help='Json file of settings.') + + return parser + +if __name__ == '__main__': + main() + diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..358266a524f04de2ab57e27ebf67f925d99a43b5 --- /dev/null +++ b/main.py @@ -0,0 +1,38 @@ +import json +import argparse +from trainer import train +from train_more import train_more + +def main(): + args = setup_parser().parse_args() + param = load_json(args.config) + args = vars(args) # Converting argparse Namespace to a dict. 
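+ # Keys from the JSON config take precedence over the CLI flags merged below. + # Typical invocations (paths are illustrative, taken from the configs in exps/): + # python main.py --config=./exps/icarl.json + # python main.py --config=./exps/simplecil_general.json -d ./car_data/car_data + # python main.py --config=./exps/simplecil_resume.json -d ./car_data/car_data -c ./models/simplecil/stanfordcar/50/20/simplecil_0.pkl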
+ args.update(param) # Add parameters from json + if args['dataset'] != "general_dataset": + train(args) + else: + assert args['data'] is not None + if not args['checkpoint']: + args.pop('checkpoint') + train(args) + else: + train_more(args) + +def load_json(settings_path): + with open(settings_path) as data_file: + param = json.load(data_file) + + return param + + +def setup_parser(): + parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.') + parser.add_argument('--config', type=str, default='./exps/finetune.json', + help='Json file of settings.') + parser.add_argument('-d','--data', nargs='?', type=str, help='Path of the data folder') + parser.add_argument('-c','--checkpoint', nargs='?', type=str, help='Path of the checkpoint file when resuming training') + return parser + + +if __name__ == "__main__": + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..349a06d335b1941ea4d7f9d6e9b9e88555c1d8e9 --- /dev/null +++ b/models/base.py @@ -0,0 +1,421 @@ +import copy +import logging +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader +from utils.toolkit import tensor2numpy, accuracy +from scipy.spatial.distance import cdist +import os + +EPSILON = 1e-8 +batch_size = 64 + + +class BaseLearner(object): + def __init__(self, args): + self.args = args + self._cur_task = -1 + self._known_classes = 0 + self._total_classes = 0 + self.class_list = [] + self._network = None + self._old_network = None + self._data_memory, self._targets_memory = np.array([]), np.array([]) + self.topk = 5 + + self._memory_size = args["memory_size"] + self._memory_per_class = args.get("memory_per_class", None) + self._fixed_memory = args.get("fixed_memory", False) + self._device = args["device"][0] + self._multiple_gpus = args["device"] + + @property + def exemplar_size(self): + assert len(self._data_memory) == len( + self._targets_memory + ), "Exemplar size error."
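+ # The data and target memories always grow in lockstep (see the exemplar construction below), so their common length is the exemplar count.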
+ return len(self._targets_memory) + + @property + def samples_per_class(self): + if self._fixed_memory: + return self._memory_per_class + else: + assert self._total_classes != 0, "Total classes is 0" + return self._memory_size // self._total_classes + + @property + def feature_dim(self): + if isinstance(self._network, nn.DataParallel): + return self._network.module.feature_dim + else: + return self._network.feature_dim + + def build_rehearsal_memory(self, data_manager, per_class, ): + if self._fixed_memory: + self._construct_exemplar_unified(data_manager, per_class) + else: + self._reduce_exemplar(data_manager, per_class) + self._construct_exemplar(data_manager, per_class) + def load_checkpoint(self, filename): + pass; + + def save_checkpoint(self, filename): + self._network.cpu() + save_dict = { + "tasks": self._cur_task, + "model_state_dict": self._network.state_dict(), + } + torch.save(save_dict, "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task)) + + def after_task(self): + pass + + def _evaluate(self, y_pred, y_true, group = 10): + ret = {} + grouped = accuracy(y_pred.T[0], y_true, self._known_classes, increment = group) + ret["grouped"] = grouped + ret["top1"] = grouped["total"] + ret["top{}".format(self.topk)] = np.around( + (y_pred.T == np.tile(y_true, (self.topk, 1))).sum() * 100 / len(y_true), + decimals=2, + ) + + return ret + + def eval_task(self, data=None, save_conf=False, group = 10, mode = "train"): + if data is None: + data = self.test_loader + y_pred, y_true = self._eval_cnn(data, mode = mode) + cnn_accy = self._evaluate(y_pred, y_true, group = group) + + if hasattr(self, "_class_means"): + y_pred, y_true = self._eval_nme(data, self._class_means) + nme_accy = self._evaluate(y_pred, y_true) + else: + nme_accy = None + + if save_conf: + _pred = y_pred.T[0] + _pred_path = os.path.join(self.args['logfilename'], "pred.npy") + _target_path = os.path.join(self.args['logfilename'], "target.npy") + np.save(_pred_path, _pred) + np.save(_target_path, y_true) + + _save_dir = os.path.join(f"./results/{self.args['model_name']}/conf_matrix/{self.args['prefix']}") + os.makedirs(_save_dir, exist_ok=True) + _save_path = os.path.join(_save_dir, f"{self.args['csv_name']}.csv") + with open(_save_path, "a+") as f: + f.write(f"{self.args['model_name']},{_pred_path},{_target_path} \n") + + return cnn_accy, nme_accy + + def incremental_train(self): + pass + + def _train(self): + pass + + def _get_memory(self): + if len(self._data_memory) == 0: + return None + else: + return (self._data_memory, self._targets_memory) + + def _compute_accuracy(self, model, loader): + model.eval() + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = model(inputs)["logits"] + predicts = torch.max(outputs, dim=1)[1] + correct += (predicts.cpu() == targets).sum() + total += len(targets) + + return np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + def _eval_cnn(self, loader, mode = "train"): + self._network.eval() + y_pred, y_true = [], [] + for _, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = self._network(inputs)["logits"] + if self.topk > self._total_classes: + self.topk = self._total_classes + predicts = torch.topk( + outputs, k=self.topk, dim=1, largest=True, sorted=True + )[ + 1 + ] # [bs, topk] + refine_predicts = predicts.cpu().numpy() + if mode == "test": + refine_predicts = self.class_list[refine_predicts] + 
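+ # class_list maps the learner's internal (shuffled) class indices back to the original dataset labels, so "test" mode reports real labels.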
y_pred.append(refine_predicts) + y_true.append(targets.cpu().numpy()) + return np.concatenate(y_pred), np.concatenate(y_true) # [N, topk] + def inference(self, image): + self._network.eval() + self._network.to(self._device) + image = image.to(self._device, dtype=torch.float32) + with torch.no_grad(): + output = self._network(image)["logits"] + if self.topk > self._total_classes: + self.topk = self._total_classes + predict = torch.topk( + output, k=self.topk, dim=1, largest=True, sorted=True + )[1] + confidents = softmax(output.cpu().numpy()) + if self.class_list is not None: + self.class_list = np.array(self.class_list) + predicts = predict.cpu().numpy() + result = self.class_list[predicts].tolist() + #result = predicts.tolist() + result.append([self.label_list[item] for item in result[0]]) + result.append(confidents[0][predicts][0].tolist()) + return result + elif self.data_manager is not None: + return self.data_manager.class_list[predict.cpu().numpy()] + + predicts = predict.cpu().numpy().tolist() # fallback when no class_list/data_manager mapping is available + predicts.append([self.label_list[index] for index in predicts[0]]) + return predicts + + def _eval_nme(self, loader, class_means): + self._network.eval() + vectors, y_true = self._extract_vectors(loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + + dists = cdist(class_means, vectors, "sqeuclidean") # [nb_classes, N] + scores = dists.T # [N, nb_classes], choose the one with the smallest distance + + return np.argsort(scores, axis=1)[:, : self.topk], y_true # [N, topk] + + def _extract_vectors(self, loader): + self._network.eval() + vectors, targets = [], [] + for _, _inputs, _targets in loader: + _targets = _targets.numpy() + if isinstance(self._network, nn.DataParallel): + _vectors = tensor2numpy( + self._network.module.extract_vector(_inputs.to(self._device)) + ) + else: + _vectors = tensor2numpy( + self._network.extract_vector(_inputs.to(self._device)) + ) + + vectors.append(_vectors) + targets.append(_targets) + return np.concatenate(vectors), np.concatenate(targets) + + def _reduce_exemplar(self, data_manager, m): + logging.info("Reducing exemplars...({} per class)".format(m)) + dummy_data, dummy_targets = copy.deepcopy(self._data_memory), copy.deepcopy( + self._targets_memory + ) + self._class_means = np.zeros((self._total_classes, self.feature_dim)) + self._data_memory, self._targets_memory = np.array([]), np.array([]) + + for class_idx in range(self._known_classes): + mask = np.where(dummy_targets == class_idx)[0] + dd, dt = dummy_data[mask][:m], dummy_targets[mask][:m] + self._data_memory = ( + np.concatenate((self._data_memory, dd)) + if len(self._data_memory) != 0 + else dd + ) + self._targets_memory = ( + np.concatenate((self._targets_memory, dt)) + if len(self._targets_memory) != 0 + else dt + ) + + # Exemplar mean + idx_dataset = data_manager.get_dataset( + [], source="train", mode="test", appendent=(dd, dt) + ) + idx_loader = DataLoader( + idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + self._class_means[class_idx, :] = mean + + def _construct_exemplar(self, data_manager, m): + logging.info("Constructing exemplars...({} per class)".format(m)) + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = data_manager.get_dataset( + np.arange(class_idx, class_idx + 1), + source="train", + mode="test", + ret_data=True, + ) + idx_loader
= DataLoader( + idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + + # Select + selected_exemplars = [] + exemplar_vectors = [] # [n, feature_dim] + for k in range(1, m + 1): + S = np.sum( + exemplar_vectors, axis=0 + ) # [feature_dim] sum of selected exemplars vectors + mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors + i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) + selected_exemplars.append( + np.array(data[i]) + ) # New object to avoid passing by inference + exemplar_vectors.append( + np.array(vectors[i]) + ) # New object to avoid passing by inference + + vectors = np.delete( + vectors, i, axis=0 + ) # Remove it to avoid duplicative selection + data = np.delete( + data, i, axis=0 + ) # Remove it to avoid duplicative selection + + # uniques = np.unique(selected_exemplars, axis=0) + # print('Unique elements: {}'.format(len(uniques))) + selected_exemplars = np.array(selected_exemplars) + exemplar_targets = np.full(m, class_idx) + self._data_memory = ( + np.concatenate((self._data_memory, selected_exemplars)) + if len(self._data_memory) != 0 + else selected_exemplars + ) + self._targets_memory = ( + np.concatenate((self._targets_memory, exemplar_targets)) + if len(self._targets_memory) != 0 + else exemplar_targets + ) + + # Exemplar mean + idx_dataset = data_manager.get_dataset( + [], + source="train", + mode="test", + appendent=(selected_exemplars, exemplar_targets), + ) + idx_loader = DataLoader( + idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + self._class_means[class_idx, :] = mean + + def _construct_exemplar_unified(self, data_manager, m): + logging.info( + "Constructing exemplars for new classes...({} per classes)".format(m) + ) + _class_means = np.zeros((self._total_classes, self.feature_dim)) + + # Calculate the means of old classes with newly trained network + for class_idx in range(self._known_classes): + mask = np.where(self._targets_memory == class_idx)[0] + class_data, class_targets = ( + self._data_memory[mask], + self._targets_memory[mask], + ) + + class_dset = data_manager.get_dataset( + [], source="train", mode="test", appendent=(class_data, class_targets) + ) + class_loader = DataLoader( + class_dset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(class_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + _class_means[class_idx, :] = mean + + # Construct exemplars for new classes and calculate the means + for class_idx in range(self._known_classes, self._total_classes): + data, targets, class_dset = data_manager.get_dataset( + np.arange(class_idx, class_idx + 1), + source="train", + mode="test", + ret_data=True, + ) + class_loader = DataLoader( + class_dset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + + vectors, _ = self._extract_vectors(class_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + + # Select + selected_exemplars = [] + exemplar_vectors = [] + for k in range(1, m + 1): + S = np.sum( + 
exemplar_vectors, axis=0 + ) # [feature_dim] sum of selected exemplars vectors + mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors + i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) + + selected_exemplars.append( + np.array(data[i]) + ) # New object to avoid passing by inference + exemplar_vectors.append( + np.array(vectors[i]) + ) # New object to avoid passing by inference + + vectors = np.delete( + vectors, i, axis=0 + ) # Remove it to avoid duplicative selection + data = np.delete( + data, i, axis=0 + ) # Remove it to avoid duplicative selection + + selected_exemplars = np.array(selected_exemplars) + exemplar_targets = np.full(m, class_idx) + self._data_memory = ( + np.concatenate((self._data_memory, selected_exemplars)) + if len(self._data_memory) != 0 + else selected_exemplars + ) + self._targets_memory = ( + np.concatenate((self._targets_memory, exemplar_targets)) + if len(self._targets_memory) != 0 + else exemplar_targets + ) + + # Exemplar mean + exemplar_dset = data_manager.get_dataset( + [], + source="train", + mode="test", + appendent=(selected_exemplars, exemplar_targets), + ) + exemplar_loader = DataLoader( + exemplar_dset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(exemplar_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + _class_means[class_idx, :] = mean + + self._class_means = _class_means +def softmax(x): + """Compute softmax values for each sets of scores in x.""" + e_x = np.exp(x - np.max(x)) + return e_x / (e_x.sum(axis=0) + 1e-7) # only difference diff --git a/models/beef_iso.py b/models/beef_iso.py new file mode 100644 index 0000000000000000000000000000000000000000..1f72ef56e5525c1c6898441efd13626518856cea --- /dev/null +++ b/models/beef_iso.py @@ -0,0 +1,684 @@ +import copy +import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import BEEFISONet +from utils.toolkit import count_parameters, target2onehot, tensor2numpy + + +EPSILON = 1e-8 + + +class BEEFISO(BaseLearner): + def __init__(self, args): + super().__init__(args) + self.args = args + self._network = BEEFISONet(args, False) + self._snet = None + self.logits_alignment = args["logits_alignment"] + self.val_loader = None + self.reduce_batch_size = args["reduce_batch_size"] + self.random = args.get("random",None) + self.imbalance = args.get("imbalance",None) + + def after_task(self): + self._network_module_ptr.update_fc_after() + self._known_classes = self._total_classes + if self.reduce_batch_size: + if self._cur_task == 0: + self.args["batch_size"] = self.args["batch_size"] + else: + self.args["batch_size"] = self.args["batch_size"] * (self._cur_task+1) // (self._cur_task+2) + logging.info("Exemplar size: {}".format(self.exemplar_size)) + + def incremental_train(self, data_manager): + self.data_manager = data_manager + self._cur_task += 1 + if self._cur_task > 1 and self.args["is_compress"]: + self._network = self._snet + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc_before(self._total_classes) + self._network_module_ptr = self._network + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + if self._cur_task > 0: + for 
id in range(self._cur_task): + for p in self._network.convnets[id].parameters(): + p.requires_grad = False + for p in self._network.old_fc.parameters(): + p.requires_grad = False + + + logging.info("All params: {}".format(count_parameters(self._network))) + logging.info( + "Trainable params: {}".format(count_parameters(self._network, True)) + ) + + + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + self.train_loader = DataLoader( + train_dataset, + batch_size=self.args["batch_size"], + shuffle=True, + num_workers=self.args["num_workers"], + pin_memory=True, + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, + batch_size=self.args["batch_size"], + shuffle=False, + num_workers=self.args["num_workers"], + pin_memory=True, + ) + if self._cur_task > 0: + if self.random or self.imbalance: + val_dset = data_manager.get_finetune_dataset(known_classes=self._known_classes, total_classes=self._total_classes, + source="train", mode='train', appendent=self._get_memory(), type="ratio") + else: + _, val_dset = data_manager.get_dataset_with_split(np.arange(self._known_classes, self._total_classes), + source='train', mode='train', + appendent=self._get_memory(), + val_samples_per_class=int( + self.samples_old_class)) + self.val_loader = DataLoader( + val_dset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True) + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader,self.val_loader) + if self.random or self.imbalance: + self.build_rehearsal_memory_imbalance(data_manager,self.samples_per_class) + else: + self.build_rehearsal_memory(data_manager, self.samples_per_class) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def train(self): + self._network_module_ptr.train() + self._network_module_ptr.convnets[-1].train() + if self._cur_task >= 1: + self._network_module_ptr.convnets[0].eval() + + def _train(self, train_loader, test_loader, val_loader=None): + self._network.to(self._device) + if hasattr(self._network, "module"): + self._network_module_ptr = self._network.module + if self._cur_task == 0: + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), + momentum=0.9, + lr=self.args["init_lr"], + weight_decay=self.args["init_weight_decay"], + ) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=self.args["init_epochs"] + ) + self.epochs = self.args["init_epochs"] + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), + lr=self.args["lr"], + momentum=0.9, + weight_decay=self.args["weight_decay"], + ) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=self.args["expansion_epochs"] + ) + + self.epochs = self.args["expansion_epochs"] + self.state = "expansion" + if len(self._multiple_gpus) > 1: + network = self._network.module + else: + network = self._network + for p in network.biases.parameters(): + p.requires_grad = False + self._expansion(train_loader, test_loader, optimizer, scheduler) + + + + for p in self._network_module_ptr.forward_prototypes.parameters(): + p.requires_grad = False + for p in 
self._network_module_ptr.backward_prototypes.parameters(): + p.requires_grad = False + for p in self._network_module_ptr.new_fc.parameters(): + p.requires_grad = False + for p in self._network_module_ptr.convnets[-1].parameters(): + p.requires_grad = False + for p in self._network.biases.parameters(): + p.requires_grad = True + self.state = "fusion" + self.epochs = self.args["fusion_epochs"] + self.per_cls_weights = torch.ones(self._total_classes).to(self._device) + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), + lr=0.05, + momentum=0.9, + weight_decay=self.args["weight_decay"], + ) + for n, p in self._network.named_parameters(): + if p.requires_grad: + print(n) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=self.args["fusion_epochs"] + ) + self._fusion(val_loader,test_loader,optimizer,scheduler) + + def _init_train(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self.epochs)) + for _, epoch in enumerate(prog_bar): + self.train() + losses = 0.0 + losses_en = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True + ), targets.to(self._device, non_blocking=True) + logits = self._network(inputs)["logits"] + loss_en = self.args["energy_weight"] * self.get_energy_loss(inputs,targets,targets) + loss = F.cross_entropy(logits, targets) + loss = loss + loss_en + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + losses_en += loss_en.item() + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.args["init_epochs"], + losses / len(train_loader), + losses_en / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.args["init_epochs"], + losses / len(train_loader), + losses_en / len(train_loader), + train_acc, + ) + + prog_bar.set_description(info) + logging.info(info) + + def _expansion(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self.epochs)) + for _, epoch in enumerate(prog_bar): + self.train() + losses = 0.0 + losses_clf = 0.0 + losses_fe = 0.0 + losses_en = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True + ), targets.to(self._device, non_blocking=True) + targets = targets.float() + outputs = self._network(inputs) + logits,train_logits = ( + outputs["logits"], + outputs["train_logits"] + ) + pseudo_targets = targets.clone() + # NOTE: the comparison operators in the torch.where() conditions were corrupted in this patch; they are reconstructed here so that every old task collapses to one pseudo label and the new classes follow at offset self._cur_task. + for task_id in range(self._cur_task+1): + if task_id == 0: + pseudo_targets = torch.where(targets <= self.data_manager.get_accumulate_tasksize(task_id)-1, task_id, pseudo_targets) + elif task_id == self._cur_task: + pseudo_targets = torch.where(targets-self._known_classes >= 0, targets-self._known_classes+task_id, pseudo_targets) + else: + pseudo_targets = torch.where((targets <= self.data_manager.get_accumulate_tasksize(task_id)-1) & (targets > self.data_manager.get_accumulate_tasksize(task_id-1)-1), task_id, pseudo_targets) + + train_logits[:, list(range(self._cur_task))] /= self.logits_alignment + loss_clf = F.cross_entropy(train_logits.float(), pseudo_targets.long()) # cross_entropy needs integer class indices + loss_fe = torch.tensor(0.).cuda() + loss_en = self.args["energy_weight"]
* self.get_energy_loss(inputs,targets,pseudo_targets) + loss = loss_clf + loss_fe + loss_en + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + losses_fe += loss_fe.item() + losses_clf += loss_clf.item() + losses_en += loss_en.item() + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.epochs, + losses / len(train_loader), + losses_clf / len(train_loader), + losses_fe / len(train_loader), + losses_en / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_en {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.epochs, + losses / len(train_loader), + losses_clf / len(train_loader), + losses_fe / len(train_loader), + losses_en / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) + + def _fusion(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self.epochs)) + for _, epoch in enumerate(prog_bar): + self.train() + # self. + losses = 0.0 + losses_clf = 0.0 + losses_fe = 0.0 + losses_kd = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True + ), targets.to(self._device, non_blocking=True) + outputs = self._network(inputs) + logits,train_logits = ( + outputs["logits"], + outputs["train_logits"] + ) + + loss_clf = F.cross_entropy(logits,targets) + loss_fe = torch.tensor(0.).cuda() + loss_kd = torch.tensor(0.).cuda() + loss = loss_clf + loss_fe + loss_kd + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + losses_fe += loss_fe.item() + losses_clf += loss_clf.item() + losses_kd += ( + self._known_classes / self._total_classes + ) * loss_kd.item() + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.epochs, + losses / len(train_loader), + losses_clf / len(train_loader), + losses_fe / len(train_loader), + losses_kd / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.epochs, + losses / len(train_loader), + losses_clf / len(train_loader), + losses_fe / len(train_loader), + losses_kd / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) + + + @property + def samples_old_class(self): + if self._fixed_memory: + return self._memory_per_class + else: + assert self._total_classes != 0, "Total classes is 0" + return self._memory_size // self._known_classes + + def samples_new_class(self, index): + if self.args["dataset"] == "cifar100": + 
return 500 + else: + return self.data_manager.getlen(index) + + def BKD(self, pred, soft, T): + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + soft = soft * self.per_cls_weights + soft = soft / soft.sum(1)[:, None] + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] + + + def get_energy_loss(self,inputs,targets,pseudo_targets): + inputs = self.sample_q(inputs) + + out = self._network(inputs) + if self._cur_task == 0: + targets = targets + self._total_classes + train_logits, energy_logits = out["logits"], out["energy_logits"] + else: + targets = targets + (self._total_classes - self._known_classes) + self._cur_task + train_logits, energy_logits = out["train_logits"], out["energy_logits"] + + logits = torch.cat([train_logits,energy_logits],dim=1) + + logits[:,pseudo_targets] = 1e-9 + energy_loss = F.cross_entropy(logits,targets) + return energy_loss + + def sample_q(self, replay_buffer, n_steps=3): + """this func takes in replay_buffer now so we have the option to sample from + scratch (i.e. replay_buffer==[]). See test_wrn_ebm.py for example. + """ + self._network_copy = self._network_module_ptr.copy().freeze() + init_sample = replay_buffer + init_sample = torch.rot90(init_sample, 2, (2, 3)) + embedding_k = init_sample.clone().detach().requires_grad_(True) + optimizer_gen = torch.optim.SGD( + [embedding_k], lr=1e-2) + for k in range(1, n_steps + 1): + out = self._network_copy(embedding_k) + if self._cur_task == 0: + energy_logits, train_logits = out["energy_logits"], out["logits"] + else: + energy_logits, train_logits = out["energy_logits"], out["train_logits"] + num_forwards = energy_logits.shape[1] + logits = torch.cat([train_logits,energy_logits],dim=1) + negative_energy = torch.log(torch.sum(torch.softmax(logits,dim=1)[:,-num_forwards:])) + optimizer_gen.zero_grad() + negative_energy.sum().backward() + optimizer_gen.step() + embedding_k.data += 1e-3 * \ + torch.randn_like(embedding_k) + final_samples = embedding_k.detach() + return final_samples + + + def build_rehearsal_memory_imbalance(self, data_manager, per_class): + if self._fixed_memory: + self._construct_exemplar_unified_imbalance(data_manager, per_class,self.random,self.imbalance) + else: + self._reduce_exemplar_imbalance(data_manager, per_class,self.random,self.imbalance) + self._construct_exemplar_imbalance(data_manager, per_class,self.random,self.imbalance) + + + def _reduce_exemplar_imbalance(self, data_manager, m,random,imbalance): + logging.info('Reducing exemplars...({} per class)'.format(m)) + dummy_data, dummy_targets = copy.deepcopy(self._data_memory), copy.deepcopy(self._targets_memory) + self._class_means = np.zeros((self._total_classes, self.feature_dim)) + self._data_memory, self._targets_memory = np.array([]), np.array([]) + + for class_idx in range(self._known_classes): + mask = np.where(dummy_targets == class_idx)[0] + if len(mask) == 0: # len(), not sum(): mask holds indices, so sum() is the wrong emptiness test + continue + if random or imbalance is not None: + dd, dt = dummy_data[mask][:-1], dummy_targets[mask][:-1] + else: + dd, dt = dummy_data[mask][:m], dummy_targets[mask][:m] + self._data_memory = np.concatenate((self._data_memory, dd)) if len(self._data_memory) != 0 else dd + self._targets_memory = np.concatenate((self._targets_memory, dt)) if len(self._targets_memory) != 0 else dt + + # Exemplar mean + idx_dataset = data_manager.get_dataset([], source='train', mode='test', appendent=(dd, dt)) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ =
self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + self._class_means[class_idx, :] = mean + + def _construct_exemplar_imbalance(self, data_manager, m, random=False,imbalance=None): + increment = self._total_classes - self._known_classes + + if random: + ''' + uniform random type + ''' + selected_exemplars = [] + selected_targets = [] + logging.info("Constructing exemplars, totally random...({} total instances {} classes)".format(increment*m, increment)) + data, targets, idx_dataset = data_manager.get_dataset(np.arange(self._known_classes,self._total_classes),source="train",mode="test",ret_data=True) + selected_indices = np.random.choice(list(range(len(data))),m*increment,replace=False) + for idx in selected_indices: + selected_exemplars.append(data[idx]) + selected_targets.append(targets[idx]) + selected_exemplars = np.array(selected_exemplars)[:m*increment] + selected_targets = np.array(selected_targets)[:m*increment] + self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \ + else selected_exemplars + self._targets_memory = np.concatenate((self._targets_memory, selected_targets)) if \ + len(self._targets_memory) != 0 else selected_targets + else: + if imbalance is None: + logging.info('Constructing exemplars...({} per class)'.format(m)) + ms = np.ones(increment,dtype=int)*m + elif imbalance>=1: + ''' + half-half type + ''' + ms=[m for _ in range(increment)] + for i in range(increment//2): + ms[i]-=m//imbalance + for i in range(increment//2,increment): + ms[i]+=m//imbalance + np.random.shuffle(ms) + ms = np.array(ms,dtype=int) + logging.info("Constructing exemplars, imbalanced...({} or {} per class)".format(m-m//imbalance,(m+m//imbalance))) + elif imbalance<1: + ''' + exp type + ''' + ms = np.array([imbalance**i for i in range(increment)]) + ms = ms/ms.sum() + tot = m*increment + ms = (tot*ms).astype(int) + np.random.shuffle(ms) + + else: + assert 0, "not implemented yet" + logging.info("ms {}".format(ms)) + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + + # Select + selected_exemplars = [] + exemplar_vectors = [] # [n, feature_dim] + for k in range(1, ms[class_idx-self._known_classes]+1): + S = np.sum(exemplar_vectors, axis=0) # [feature_dim] sum of selected exemplars vectors + mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors + i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) + selected_exemplars.append(np.array(data[i])) # New object to avoid passing by reference + exemplar_vectors.append(np.array(vectors[i])) # New object to avoid passing by reference + + vectors = np.delete(vectors, i, axis=0) # Remove it to avoid duplicative selection + data = np.delete(data, i, axis=0) # Remove it to avoid duplicative selection + + # uniques = np.unique(selected_exemplars, axis=0) + selected_exemplars = np.array(selected_exemplars) + if len(selected_exemplars)==0: + continue + exemplar_targets = np.full(ms[class_idx-self._known_classes], class_idx) + self._data_memory =
np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \ + else selected_exemplars + self._targets_memory = np.concatenate((self._targets_memory, exemplar_targets)) if \ + len(self._targets_memory) != 0 else exemplar_targets + + # Exemplar mean + idx_dataset = data_manager.get_dataset([], source='train', mode='test', + appendent=(selected_exemplars, exemplar_targets)) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4,pin_memory=True) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + self._class_means[class_idx, :] = mean + # self._class_means[class_idx, :] = class_mean + + def _construct_exemplar_unified_imbalance(self, data_manager, m,random,imbalance): + logging.info('Constructing exemplars for new classes...({} per class)'.format(m)) + _class_means = np.zeros((self._total_classes, self.feature_dim)) + increment = self._total_classes - self._known_classes + + # Calculate the means of old classes with newly trained network + for class_idx in range(self._known_classes): + mask = np.where(self._targets_memory == class_idx)[0] + if len(mask) == 0: continue # len(), not sum(): mask holds indices + class_data, class_targets = self._data_memory[mask], self._targets_memory[mask] + + class_dset = data_manager.get_dataset([], source='train', mode='test', + appendent=(class_data, class_targets)) + class_loader = DataLoader(class_dset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(class_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + _class_means[class_idx, :] = mean + + if random: + ''' + uniform sample type + ''' + selected_exemplars = [] + selected_targets = [] + logging.info("Constructing exemplars, totally random...({} total instances {} classes)".format(increment*m, increment)) + data, targets, idx_dataset = data_manager.get_dataset(np.arange(self._known_classes,self._total_classes),source="train",mode="test",ret_data=True) + selected_indices = np.random.choice(list(range(len(data))),m*increment,replace=False) + for idx in selected_indices: + selected_exemplars.append(data[idx]) + selected_targets.append(targets[idx]) + selected_exemplars = np.array(selected_exemplars) + selected_targets = np.array(selected_targets) + self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \ + else selected_exemplars + self._targets_memory = np.concatenate((self._targets_memory, selected_targets)) if \ + len(self._targets_memory) != 0 else selected_targets + else: + if imbalance is None: + logging.info('Constructing exemplars...({} per class)'.format(m)) + ms = np.ones(increment,dtype=int)*m + elif imbalance>=1: + ''' + half-half type + ''' + ms=[m for _ in range(increment)] + for i in range(increment//2): + ms[i]-=m//imbalance + for i in range(increment//2,increment): + ms[i]+=m//imbalance + np.random.shuffle(ms) + ms = np.array(ms,dtype=int) + logging.info("Constructing exemplars, imbalanced...({} or {} per class)".format(m-m//imbalance,(m+m//imbalance))) + elif imbalance<1: + ''' + exp type + ''' + ms = np.array([imbalance**i for i in range(increment)]) + ms = ms/ms.sum() + tot = m*increment + ms = (tot*ms).astype(int) + np.random.shuffle(ms) + + else: + assert 0, "not implemented yet" +
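+ # ms[i] is the exemplar budget of the i-th new class: uniform when imbalance is None, half-half when imbalance >= 1, exponentially decayed when imbalance < 1.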
logging.info("ms {}".format(ms)) + # Construct exemplars for new classes and calculate the means + for class_idx in range(self._known_classes, self._total_classes): + data, targets, class_dset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + class_loader = DataLoader(class_dset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4,pin_memory=True) + + vectors, _ = self._extract_vectors(class_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + + # Select + selected_exemplars = [] + exemplar_vectors = [] + for k in range(1, ms[class_idx-self._known_classes]+1): + S = np.sum(exemplar_vectors, axis=0) # [feature_dim] sum of selected exemplars vectors + mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors + i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) + + selected_exemplars.append(np.array(data[i])) # New object to avoid passing by inference + exemplar_vectors.append(np.array(vectors[i])) # New object to avoid passing by inference + + vectors = np.delete(vectors, i, axis=0) # Remove it to avoid duplicative selection + data = np.delete(data, i, axis=0) # Remove it to avoid duplicative selection + + selected_exemplars = np.array(selected_exemplars) + if len(selected_exemplars)==0: + continue + exemplar_targets = np.full(ms[class_idx-self._known_classes], class_idx) + self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \ + else selected_exemplars + self._targets_memory = np.concatenate((self._targets_memory, exemplar_targets)) if \ + len(self._targets_memory) != 0 else exemplar_targets + + # Exemplar mean + exemplar_dset = data_manager.get_dataset([], source='train', mode='test', + appendent=(selected_exemplars, exemplar_targets)) + exemplar_loader = DataLoader(exemplar_dset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(exemplar_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + _class_means[class_idx, :] = mean + # _class_means[class_idx,:] = class_mean + + self._class_means = _class_means + diff --git a/models/bic.py b/models/bic.py new file mode 100644 index 0000000000000000000000000000000000000000..c57aba6c7a106e38aa6c3bd0aeeffee969f3a058 --- /dev/null +++ b/models/bic.py @@ -0,0 +1,206 @@ +import logging +import numpy as np +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import IncrementalNetWithBias + + +epochs = 170 +lrate = 0.1 +milestones = [60, 100, 140] +lrate_decay = 0.1 +batch_size = 128 +split_ratio = 0.1 +T = 2 +weight_decay = 2e-4 +num_workers = 8 + + +class BiC(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = IncrementalNetWithBias( + args, False, bias_correction=True + ) + self._class_means = None + + def after_task(self): + self._old_network = self._network.copy().freeze() + self._known_classes = self._total_classes + logging.info("Exemplar size: {}".format(self.exemplar_size)) + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc(self._total_classes) + logging.info( + 
"Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + if self._cur_task >= 1: + train_dset, val_dset = data_manager.get_dataset_with_split( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + val_samples_per_class=int( + split_ratio * self._memory_size / self._known_classes + ), + ) + self.val_loader = DataLoader( + val_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + logging.info( + "Stage1 dset: {}, Stage2 dset: {}".format( + len(train_dset), len(val_dset) + ) + ) + self.lamda = self._known_classes / self._total_classes + logging.info("Lambda: {:.3f}".format(self.lamda)) + else: + train_dset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + test_dset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + + self.train_loader = DataLoader( + train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + self.test_loader = DataLoader( + test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + self._log_bias_params() + self._stage1_training(self.train_loader, self.test_loader) + if self._cur_task >= 1: + self._stage2_bias_correction(self.val_loader, self.test_loader) + + self.build_rehearsal_memory(data_manager, self.samples_per_class) + + if len(self._multiple_gpus) > 1: + self._network = self._network.module + self._log_bias_params() + + def _run(self, train_loader, test_loader, optimizer, scheduler, stage): + for epoch in range(1, epochs + 1): + self._network.train() + losses = 0.0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + if stage == "training": + clf_loss = F.cross_entropy(logits, targets) + if self._old_network is not None: + old_logits = self._old_network(inputs)["logits"].detach() + hat_pai_k = F.softmax(old_logits / T, dim=1) + log_pai_k = F.log_softmax( + logits[:, : self._known_classes] / T, dim=1 + ) + distill_loss = -torch.mean( + torch.sum(hat_pai_k * log_pai_k, dim=1) + ) + loss = distill_loss * self.lamda + clf_loss * (1 - self.lamda) + else: + loss = clf_loss + elif stage == "bias_correction": + loss = F.cross_entropy(torch.softmax(logits, dim=1), targets) + else: + raise NotImplementedError() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + scheduler.step() + train_acc = self._compute_accuracy(self._network, train_loader) + test_acc = self._compute_accuracy(self._network, test_loader) + info = "{} => Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}".format( + stage, + self._cur_task, + epoch, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + logging.info(info) + + def _stage1_training(self, train_loader, test_loader): + """ + if self._cur_task == 0: + loaded_dict = torch.load('./dict_0.pkl') + self._network.load_state_dict(loaded_dict['model_state_dict']) + self._network.to(self._device) + return + """ + + ignored_params = list(map(id, self._network.bias_layers.parameters())) + base_params = filter( + lambda p: id(p) not in ignored_params, self._network.parameters() + ) + network_params = [ + {"params": base_params, "lr": lrate, "weight_decay": weight_decay}, + { + "params": self._network.bias_layers.parameters(), + "lr": 0, + "weight_decay": 0, + }, + ] + 
optimizer = optim.SGD( + network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._network.to(self._device) + if self._old_network is not None: + self._old_network.to(self._device) + + self._run(train_loader, test_loader, optimizer, scheduler, stage="training") + + def _stage2_bias_correction(self, val_loader, test_loader): + if isinstance(self._network, nn.DataParallel): + self._network = self._network.module + network_params = [ + { + "params": self._network.bias_layers[-1].parameters(), + "lr": lrate, + "weight_decay": weight_decay, + } + ] + optimizer = optim.SGD( + network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._network.to(self._device) + + self._run( + val_loader, test_loader, optimizer, scheduler, stage="bias_correction" + ) + + def _log_bias_params(self): + logging.info("Parameters of bias layer:") + params = self._network.get_bias_params() + for i, param in enumerate(params): + logging.info("{} => {:.3f}, {:.3f}".format(i, param[0], param[1])) diff --git a/models/coil.py b/models/coil.py new file mode 100644 index 0000000000000000000000000000000000000000..b000510f23ee033f60a2f9a58a73bae680ad8c63 --- /dev/null +++ b/models/coil.py @@ -0,0 +1,332 @@ +import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import ( + IncrementalNet, + CosineIncrementalNet, + SimpleCosineIncrementalNet, +) +from utils.toolkit import target2onehot, tensor2numpy +import ot +from torch import nn +import copy + +EPSILON = 1e-8 + +epochs = 100 +lrate = 0.1 +milestones = [40, 80] +lrate_decay = 0.1 +batch_size = 32 +memory_size = 2000 +T = 2 + + +class COIL(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = SimpleCosineIncrementalNet(args, False) + self.data_manager = None + self.nextperiod_initialization = None + self.sinkhorn_reg = args["sinkhorn"] + self.calibration_term = args["calibration_term"] + self.args = args + + def after_task(self): + self.nextperiod_initialization = self.solving_ot() + self._old_network = self._network.copy().freeze() + self._known_classes = self._total_classes + + def solving_ot(self): + with torch.no_grad(): + if self._total_classes == self.data_manager.get_total_classnum(): + print("training over, no more ot solving") + return None + each_time_class_num = self.data_manager.get_task_size(1) + self._extract_class_means( + self.data_manager, 0, self._total_classes + each_time_class_num + ) + former_class_means = torch.tensor( + self._ot_prototype_means[: self._total_classes] + ) + next_period_class_means = torch.tensor( + self._ot_prototype_means[ + self._total_classes : self._total_classes + each_time_class_num + ] + ) + Q_cost_matrix = torch.cdist( + former_class_means, next_period_class_means, p=self.args["norm_term"] + ) + # solving ot + _mu1_vec = ( + torch.ones(len(former_class_means)) / len(former_class_means) * 1.0 + ) + _mu2_vec = ( + torch.ones(len(next_period_class_means)) / 
len(former_class_means) * 1.0 + ) + T = ot.sinkhorn(_mu1_vec, _mu2_vec, Q_cost_matrix, self.sinkhorn_reg) + T = torch.tensor(T).float().cuda() + transformed_hat_W = torch.mm( + T.T, F.normalize(self._network.fc.weight, p=2, dim=1) + ) + oldnorm = torch.norm(self._network.fc.weight, p=2, dim=1) + newnorm = torch.norm( + transformed_hat_W * len(former_class_means), p=2, dim=1 + ) + meannew = torch.mean(newnorm) + meanold = torch.mean(oldnorm) + gamma = meanold / meannew + self.calibration_term = gamma + self._ot_new_branch = ( + transformed_hat_W * len(former_class_means) * self.calibration_term + ) + return transformed_hat_W * len(former_class_means) * self.calibration_term + + def solving_ot_to_old(self): + current_class_num = self.data_manager.get_task_size(self._cur_task) + self._extract_class_means_with_memory( + self.data_manager, self._known_classes, self._total_classes + ) + former_class_means = torch.tensor( + self._ot_prototype_means[: self._known_classes] + ) + next_period_class_means = torch.tensor( + self._ot_prototype_means[self._known_classes : self._total_classes] + ) + Q_cost_matrix = ( + torch.cdist( + next_period_class_means, former_class_means, p=self.args["norm_term"] + ) + + EPSILON + ) # in case of numerical err + _mu1_vec = torch.ones(len(former_class_means)) / len(former_class_means) * 1.0 + _mu2_vec = ( + torch.ones(len(next_period_class_means)) / len(former_class_means) * 1.0 + ) + T = ot.sinkhorn(_mu2_vec, _mu1_vec, Q_cost_matrix, self.sinkhorn_reg) + T = torch.tensor(T).float().cuda() + transformed_hat_W = torch.mm( + T.T, + F.normalize(self._network.fc.weight[-current_class_num:, :], p=2, dim=1), + ) + return transformed_hat_W * len(former_class_means) * self.calibration_term + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + + self._network.update_fc(self._total_classes, self.nextperiod_initialization) + self.data_manager = data_manager + + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + self.lamda = self._known_classes / self._total_classes + # Loader + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + self.train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=4 + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + + self._train(self.train_loader, self.test_loader) + + if self.args['fixed_memory']: + examplar_size = self.args["memory_per_class"] + else: + examplar_size = memory_size // self._total_classes + self._reduce_exemplar(data_manager, examplar_size) + self._construct_exemplar(data_manager, examplar_size) + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if self._old_network is not None: + self._old_network.to(self._device) + optimizer = optim.SGD( + self._network.parameters(), lr=lrate, momentum=0.9, weight_decay=5e-4 + ) # 1e-5 + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + self._update_representation(train_loader, test_loader, optimizer, scheduler) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(epochs)) + for _, 
epoch in enumerate(prog_bar): + weight_ot_init = max(1.0 - (epoch / 2) ** 2, 0) + weight_ot_co_tuning = (epoch / epochs) ** 2.0 + + self._network.train() + losses = 0.0 + correct, total = 0, 0 + + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + output = self._network(inputs) + logits = output["logits"] + onehots = target2onehot(targets, self._total_classes) + + clf_loss = F.cross_entropy(logits, targets) + if self._old_network is not None: + + old_logits = self._old_network(inputs)["logits"].detach() + hat_pai_k = F.softmax(old_logits / T, dim=1) + log_pai_k = F.log_softmax( + logits[:, : self._known_classes] / T, dim=1 + ) + distill_loss = -torch.mean(torch.sum(hat_pai_k * log_pai_k, dim=1)) + + if epoch < 1: + features = F.normalize(output["features"], p=2, dim=1) + current_logit_new = F.log_softmax( + logits[:, self._known_classes :] / T, dim=1 + ) + new_logit_by_wnew_init_by_ot = F.linear( + features, F.normalize(self._ot_new_branch, p=2, dim=1) + ) + new_logit_by_wnew_init_by_ot = F.softmax( + new_logit_by_wnew_init_by_ot / T, dim=1 + ) + new_branch_distill_loss = -torch.mean( + torch.sum( + current_logit_new * new_logit_by_wnew_init_by_ot, dim=1 + ) + ) + + loss = ( + distill_loss * self.lamda + + clf_loss * (1 - self.lamda) + + 0.001 * (weight_ot_init * new_branch_distill_loss) + ) + else: + features = F.normalize(output["features"], p=2, dim=1) + if i % 30 == 0: + with torch.no_grad(): + self._ot_old_branch = self.solving_ot_to_old() + old_logit_by_wold_init_by_ot = F.linear( + features, F.normalize(self._ot_old_branch, p=2, dim=1) + ) + old_logit_by_wold_init_by_ot = F.log_softmax( + old_logit_by_wold_init_by_ot / T, dim=1 + ) + old_branch_distill_loss = -torch.mean( + torch.sum(hat_pai_k * old_logit_by_wold_init_by_ot, dim=1) + ) + loss = ( + distill_loss * self.lamda + + clf_loss * (1 - self.lamda) + + self.args["reg_term"] + * (weight_ot_co_tuning * old_branch_distill_loss) + ) + else: + loss = clf_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + prog_bar.set_description(info) + + logging.info(info) + + def _extract_class_means(self, data_manager, low, high): + self._ot_prototype_means = np.zeros( + (data_manager.get_total_classnum(), self._network.feature_dim) + ) + with torch.no_grad(): + for class_idx in range(low, high): + data, targets, idx_dataset = data_manager.get_dataset( + np.arange(class_idx, class_idx + 1), + source="train", + mode="test", + ret_data=True, + ) + idx_loader = DataLoader( + idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + class_mean = class_mean / (np.linalg.norm(class_mean)) + self._ot_prototype_means[class_idx, :] = class_mean + self._network.train() + + def _extract_class_means_with_memory(self, data_manager, low, high): + + self._ot_prototype_means = np.zeros( + 
(data_manager.get_total_classnum(), self._network.feature_dim) + ) + memoryx, memoryy = self._data_memory, self._targets_memory + with torch.no_grad(): + for class_idx in range(0, low): + idxes = np.where( + np.logical_and(memoryy >= class_idx, memoryy < class_idx + 1) + )[0] + data, targets = memoryx[idxes], memoryy[idxes] + # idx_dataset=TensorDataset(data,targets) + # idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + _, _, idx_dataset = data_manager.get_dataset( + [], + source="train", + appendent=(data, targets), + mode="test", + ret_data=True, + ) + idx_loader = DataLoader( + idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + class_mean = class_mean / np.linalg.norm(class_mean) + self._ot_prototype_means[class_idx, :] = class_mean + + for class_idx in range(low, high): + data, targets, idx_dataset = data_manager.get_dataset( + np.arange(class_idx, class_idx + 1), + source="train", + mode="test", + ret_data=True, + ) + idx_loader = DataLoader( + idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + class_mean = class_mean / np.linalg.norm(class_mean) + self._ot_prototype_means[class_idx, :] = class_mean + self._network.train() diff --git a/models/der.py b/models/der.py new file mode 100644 index 0000000000000000000000000000000000000000..3943b2e023c3660f3d67420a55f031157860705f --- /dev/null +++ b/models/der.py @@ -0,0 +1,230 @@ +# Please note that the current implementation of DER only contains the dynamic expansion process, since masking and pruning are not implemented by the source repo. 
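+#
+# A minimal sketch of what "dynamic expansion" means here (names follow
+# utils/inc_net.DERNet; treat this as an illustration, not the exact API):
+#
+#     feats = torch.cat([net(x)["features"] for net in self.convnets], dim=1)
+#     logits = self.fc(feats)          # fc is re-created each task with a wider input
+#     aux_logits = self.aux_fc(feats)  # auxiliary head used for loss_aux below
+#
+# Every convnet from a previous task is frozen (requires_grad=False) in
+# incremental_train, so only the newest branch and the heads are updated.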
+import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import DERNet, IncrementalNet +from utils.toolkit import count_parameters, target2onehot, tensor2numpy + +EPSILON = 1e-8 + +init_epoch = 100 +init_lr = 0.1 +init_milestones = [40, 60, 80] +init_lr_decay = 0.1 +init_weight_decay = 0.0005 + + +epochs = 80 +lrate = 0.1 +milestones = [30, 50, 70] +lrate_decay = 0.1 +batch_size = 32 +weight_decay = 2e-4 +num_workers = 8 +T = 2 + + +class DER(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = DERNet(args, False) + + def after_task(self): + self._known_classes = self._total_classes + logging.info("Exemplar size: {}".format(self.exemplar_size)) + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc(self._total_classes) + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + if self._cur_task > 0: + for i in range(self._cur_task): + for p in self._network.convnets[i].parameters(): + p.requires_grad = False + + logging.info("All params: {}".format(count_parameters(self._network))) + logging.info( + "Trainable params: {}".format(count_parameters(self._network, True)) + ) + + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + self.train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + self.build_rehearsal_memory(data_manager, self.samples_per_class) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def train(self): + self._network.train() + if len(self._multiple_gpus) > 1 : + self._network_module_ptr = self._network.module + else: + self._network_module_ptr = self._network + self._network_module_ptr.convnets[-1].train() + if self._cur_task >= 1: + for i in range(self._cur_task): + self._network_module_ptr.convnets[i].eval() + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if self._cur_task == 0: + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), + momentum=0.9, + lr=init_lr, + weight_decay=init_weight_decay, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay + ) + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), + lr=lrate, + momentum=0.9, + weight_decay=weight_decay, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + self._update_representation(train_loader, test_loader, optimizer, scheduler) + if len(self._multiple_gpus) > 1: + self._network.module.weight_align( + self._total_classes - 
self._known_classes + ) + else: + self._network.weight_align(self._total_classes - self._known_classes) + + def _init_train(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(init_epoch)) + for _, epoch in enumerate(prog_bar): + self.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss = F.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(epochs)) + for _, epoch in enumerate(prog_bar): + self.train() + losses = 0.0 + losses_clf = 0.0 + losses_aux = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + outputs = self._network(inputs) + logits, aux_logits = outputs["logits"], outputs["aux_logits"] + loss_clf = F.cross_entropy(logits, targets) + aux_targets = targets.clone() + aux_targets = torch.where( + aux_targets - self._known_classes + 1 > 0, + aux_targets - self._known_classes + 1, + torch.tensor([0]).to(self._device), + ) + loss_aux = F.cross_entropy(aux_logits, aux_targets) + loss = loss_clf + loss_aux + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + losses_aux += loss_aux.item() + losses_clf += loss_clf.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + losses_clf / len(train_loader), + losses_aux / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + losses_clf / len(train_loader), + losses_aux / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) diff --git a/models/ewc.py b/models/ewc.py new file mode 100644 index 0000000000000000000000000000000000000000..5493e09eca5449ef299025b1518932285883a911 --- /dev/null +++ b/models/ewc.py @@ -0,0 +1,254 @@ +import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn 
import functional as F
+from torch.utils.data import DataLoader
+from models.base import BaseLearner
+from models.podnet import pod_spatial_loss
+from utils.inc_net import IncrementalNet
+from utils.toolkit import target2onehot, tensor2numpy
+
+EPSILON = 1e-8
+
+init_epoch = 200
+init_lr = 0.1
+init_milestones = [60, 120, 170]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 180
+lrate = 0.1
+milestones = [70, 120, 150]
+lrate_decay = 0.1
+batch_size = 128
+weight_decay = 2e-4
+num_workers = 4
+T = 2
+lamda = 1000
+fishermax = 0.0001
+
+
+class EWC(BaseLearner):
+    def __init__(self, args):
+        super().__init__(args)
+        self.fisher = None
+        self._network = IncrementalNet(args, False)
+
+    def after_task(self):
+        self._known_classes = self._total_classes
+
+    def incremental_train(self, data_manager):
+        self._cur_task += 1
+        self._total_classes = self._known_classes + data_manager.get_task_size(
+            self._cur_task
+        )
+        self._network.update_fc(self._total_classes)
+        logging.info(
+            "Learning on {}-{}".format(self._known_classes, self._total_classes)
+        )
+
+        train_dataset = data_manager.get_dataset(
+            np.arange(self._known_classes, self._total_classes),
+            source="train",
+            mode="train",
+        )
+        self.train_loader = DataLoader(
+            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+        )
+        test_dataset = data_manager.get_dataset(
+            np.arange(0, self._total_classes), source="test", mode="test"
+        )
+        self.test_loader = DataLoader(
+            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+        )
+
+        if len(self._multiple_gpus) > 1:
+            self._network = nn.DataParallel(self._network, self._multiple_gpus)
+        self._train(self.train_loader, self.test_loader)
+        if len(self._multiple_gpus) > 1:
+            self._network = self._network.module
+
+        if self.fisher is None:
+            self.fisher = self.getFisherDiagonal(self.train_loader)
+        else:
+            alpha = self._known_classes / self._total_classes
+            new_fisher = self.getFisherDiagonal(self.train_loader)
+            for n, p in new_fisher.items():
+                new_fisher[n][: len(self.fisher[n])] = (
+                    alpha * self.fisher[n]
+                    + (1 - alpha) * new_fisher[n][: len(self.fisher[n])]
+                )
+            self.fisher = new_fisher
+        self.mean = {
+            n: p.clone().detach()
+            for n, p in self._network.named_parameters()
+            if p.requires_grad
+        }
+
+    def _train(self, train_loader, test_loader):
+        self._network.to(self._device)
+        if self._cur_task == 0:
+            optimizer = optim.SGD(
+                self._network.parameters(),
+                momentum=0.9,
+                lr=init_lr,
+                weight_decay=init_weight_decay,
+            )
+            scheduler = optim.lr_scheduler.MultiStepLR(
+                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+            )
+            self._init_train(train_loader, test_loader, optimizer, scheduler)
+        else:
+            optimizer = optim.SGD(
+                self._network.parameters(),
+                lr=lrate,
+                momentum=0.9,
+                weight_decay=weight_decay,
+            )
+            scheduler = optim.lr_scheduler.MultiStepLR(
+                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+            )
+            self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+        prog_bar = tqdm(range(init_epoch))
+        for _, epoch in enumerate(prog_bar):
+            self._network.train()
+            losses = 0.0
+            correct, total = 0, 0
+            for i, (_, inputs, targets) in enumerate(train_loader):
+                inputs, targets = inputs.to(self._device), targets.to(self._device)
+                logits = self._network(inputs)["logits"]
+                loss = F.cross_entropy(logits, targets)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                losses += loss.item()
+
+                _,
preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + ) + + prog_bar.set_description(info) + + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(epochs)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss_clf = F.cross_entropy( + logits[:, self._known_classes :], targets - self._known_classes + ) + loss_ewc = self.compute_ewc() + loss = loss_clf + lamda * loss_ewc + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) + + def compute_ewc(self): + loss = 0 + if len(self._multiple_gpus) > 1: + for n, p in self._network.module.named_parameters(): + if n in self.fisher.keys(): + loss += ( + torch.sum( + (self.fisher[n]) + * (p[: len(self.mean[n])] - self.mean[n]).pow(2) + ) + / 2 + ) + else: + for n, p in self._network.named_parameters(): + if n in self.fisher.keys(): + loss += ( + torch.sum( + (self.fisher[n]) + * (p[: len(self.mean[n])] - self.mean[n]).pow(2) + ) + / 2 + ) + return loss + + def getFisherDiagonal(self, train_loader): + fisher = { + n: torch.zeros(p.shape).to(self._device) + for n, p in self._network.named_parameters() + if p.requires_grad + } + self._network.train() + optimizer = optim.SGD(self._network.parameters(), lr=lrate) + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + loss = torch.nn.functional.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + for n, p in self._network.named_parameters(): + if p.grad is not None: + fisher[n] += p.grad.pow(2).clone() + for n, p in fisher.items(): + fisher[n] = p / len(train_loader) + fisher[n] = torch.min(fisher[n], torch.tensor(fishermax)) + return fisher diff --git a/models/fetril.py b/models/fetril.py new file mode 100644 index 0000000000000000000000000000000000000000..9cdec8a43778a89ebdd29d34eb983dbbf9caeb0e --- /dev/null +++ b/models/fetril.py 
@@ -0,0 +1,227 @@
+'''
+results on CIFAR-100:
+
+Protocols | Reported ResNet18   | Reproduced ResNet32
+          |   FC    |   SVM     |    FC     |   SVM
+T = 5     |  64.7   |  66.3     |  65.775   |  65.375
+T = 10    |  63.4   |  65.2     |  64.91    |  65.10
+T = 60    |  50.8   |  59.8     |  62.09    |  61.72
+'''
+
+
+import logging
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch import nn
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from models.base import BaseLearner
+from utils.inc_net import CosineIncrementalNet, FOSTERNet, IncrementalNet
+from utils.toolkit import count_parameters, target2onehot, tensor2numpy
+from sklearn.svm import LinearSVC
+from torchvision import datasets, transforms
+from utils.autoaugment import CIFAR10Policy, ImageNetPolicy
+from utils.ops import Cutout
+
+EPSILON = 1e-8
+
+
+class FeTrIL(BaseLearner):
+    def __init__(self, args):
+        super().__init__(args)
+        self.args = args
+        self._network = IncrementalNet(args, False)
+        self._means = []
+        self._svm_accs = []
+
+    def after_task(self):
+        self._known_classes = self._total_classes
+
+    def incremental_train(self, data_manager):
+        self.data_manager = data_manager
+        self.data_manager._train_trsf = [
+            transforms.RandomHorizontalFlip(),
+            transforms.ColorJitter(brightness=63/255),
+            ImageNetPolicy(),
+            Cutout(n_holes=1, length=16),
+        ]
+        self._cur_task += 1
+        self._total_classes = self._known_classes + \
+            data_manager.get_task_size(self._cur_task)
+        self._network.update_fc(self._total_classes)
+        self._network_module_ptr = self._network
+        logging.info(
+            'Learning on {}-{}'.format(self._known_classes, self._total_classes))
+
+        if self._cur_task > 0:
+            for p in self._network.convnet.parameters():
+                p.requires_grad = False
+
+        logging.info('All params: {}'.format(count_parameters(self._network)))
+        logging.info('Trainable params: {}'.format(
+            count_parameters(self._network, True)))
+
+        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',
+                                                 mode='train', appendent=self._get_memory())
+        self.train_loader = DataLoader(
+            train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True)
+        test_dataset = data_manager.get_dataset(
+            np.arange(0, self._total_classes), source='test', mode='test')
+        self.test_loader = DataLoader(
+            test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"])
+
+        if len(self._multiple_gpus) > 1:
+            self._network = nn.DataParallel(self._network, self._multiple_gpus)
+        self._train(self.train_loader, self.test_loader)
+
+        if len(self._multiple_gpus) > 1:
+            self._network = self._network.module
+
+    def _train(self, train_loader, test_loader):
+        self._network.to(self._device)
+        if hasattr(self._network, "module"):
+            self._network_module_ptr = self._network.module
+        if self._cur_task == 0:
+            self._epoch_num = self.args["init_epochs"]
+            optimizer = optim.SGD(filter(lambda p: p.requires_grad, self._network.parameters(
+            )), momentum=0.9, lr=self.args["init_lr"], weight_decay=self.args["init_weight_decay"])
+            scheduler = optim.lr_scheduler.CosineAnnealingLR(
+                optimizer=optimizer, T_max=self.args["init_epochs"])
+            self._train_function(train_loader, test_loader, optimizer, scheduler)
+            self._compute_means()
+            self._build_feature_set()
+        else:
+            self._epoch_num = self.args["epochs"]
+            self._compute_means()
+            self._compute_relations()
+            self._build_feature_set()
+
+            train_loader
= DataLoader(self._feature_trainset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True) + optimizer = optim.SGD(self._network_module_ptr.fc.parameters(),momentum=0.9,lr=self.args["lr"],weight_decay=self.args["weight_decay"]) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max = self.args["epochs"]) + + self._train_function(train_loader, test_loader, optimizer, scheduler) + self._train_svm(self._feature_trainset,self._feature_testset) + + + def _compute_means(self): + with torch.no_grad(): + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + class_mean = np.mean(vectors, axis=0) + self._means.append(class_mean) + + def _compute_relations(self): + old_means = np.array(self._means[:self._known_classes]) + new_means = np.array(self._means[self._known_classes:]) + self._relations=np.argmax((old_means/np.linalg.norm(old_means,axis=1)[:,None])@(new_means/np.linalg.norm(new_means,axis=1)[:,None]).T,axis=1)+self._known_classes + def _build_feature_set(self): + self.vectors_train = [] + self.labels_train = [] + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + self.vectors_train.append(vectors) + self.labels_train.append([class_idx]*len(vectors)) + for class_idx in range(0,self._known_classes): + new_idx = self._relations[class_idx] + self.vectors_train.append(self.vectors_train[new_idx-self._known_classes]-self._means[new_idx]+self._means[class_idx]) + self.labels_train.append([class_idx]*len(self.vectors_train[-1])) + + self.vectors_train = np.concatenate(self.vectors_train) + self.labels_train = np.concatenate(self.labels_train) + self._feature_trainset = FeatureDataset(self.vectors_train,self.labels_train) + + self.vectors_test = [] + self.labels_test = [] + for class_idx in range(0, self._total_classes): + data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='test', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + self.vectors_test.append(vectors) + self.labels_test.append([class_idx]*len(vectors)) + self.vectors_test = np.concatenate(self.vectors_test) + self.labels_test = np.concatenate(self.labels_test) + + self._feature_testset = FeatureDataset(self.vectors_test,self.labels_test) + + def _train_function(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self._epoch_num)) + for _, epoch in enumerate(prog_bar): + if self._cur_task == 0: + self._network.train() + else: + self._network.eval() + losses = 0. 
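+            # From task 1 onward train_loader yields cached feature vectors
+            # (see _build_feature_set), so only the linear head is trained and
+            # the frozen backbone is kept in eval mode above.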
+            correct, total = 0, 0
+            for i, (_, inputs, targets) in enumerate(train_loader):
+                inputs, targets = inputs.to(
+                    self._device, non_blocking=True), targets.to(self._device, non_blocking=True)
+                if self._cur_task == 0:
+                    logits = self._network(inputs)['logits']
+                else:
+                    logits = self._network_module_ptr.fc(inputs)['logits']
+                loss = F.cross_entropy(logits, targets)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                losses += loss.item()
+                _, preds = torch.max(logits, dim=1)
+                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+                total += len(targets)
+            scheduler.step()
+            train_acc = np.around(tensor2numpy(
+                correct)*100 / total, decimals=2)
+            if epoch % 5 != 0:
+                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
+                    self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), train_acc)
+            else:
+                test_acc = self._compute_accuracy(self._network, test_loader)
+                info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
+                    self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), train_acc, test_acc)
+            prog_bar.set_description(info)
+            logging.info(info)
+
+    def _train_svm(self, train_set, test_set):
+        train_features = train_set.features.numpy()
+        train_labels = train_set.labels.numpy()
+        test_features = test_set.features.numpy()
+        test_labels = test_set.labels.numpy()
+        train_features = train_features/np.linalg.norm(train_features, axis=1)[:, None]
+        test_features = test_features/np.linalg.norm(test_features, axis=1)[:, None]
+        svm_classifier = LinearSVC(random_state=42)
+        svm_classifier.fit(train_features, train_labels)
+        logging.info("svm train: acc: {}".format(np.around(svm_classifier.score(train_features, train_labels)*100, decimals=2)))
+        acc = svm_classifier.score(test_features, test_labels)
+        self._svm_accs.append(np.around(acc*100, decimals=2))
+        logging.info("svm evaluation: acc_list: {}".format(self._svm_accs))
+
+
+class FeatureDataset(Dataset):
+    def __init__(self, features, labels):
+        assert len(features) == len(labels), "Data size error!"
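+        # Stores pre-extracted feature vectors; __getitem__ returns an
+        # (idx, feature, label) triple to mirror the (index, image, target)
+        # batches produced by DataManager datasets, so _train_function can
+        # consume either loader unchanged.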
+ self.features = torch.from_numpy(features) + self.labels = torch.from_numpy(labels) + + def __len__(self): + return len(self.features) + + def __getitem__(self, idx): + feature = self.features[idx] + label = self.labels[idx] + + return idx, feature, label diff --git a/models/finetune.py b/models/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..b85338f62e163b432f06dce8a16c5206330bf6af --- /dev/null +++ b/models/finetune.py @@ -0,0 +1,206 @@ +import logging +import numpy as np +import torch +import copy +from torch import nn +from torch.serialization import load +from tqdm import tqdm +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from utils.inc_net import IncrementalNet +from models.base import BaseLearner +from utils.toolkit import target2onehot, tensor2numpy + + +init_epoch = 100 +init_lr = 0.1 +init_milestones = [40, 60, 80] +init_lr_decay = 0.1 +init_weight_decay = 0.0005 + + +epochs = 80 +lrate = 0.1 +milestones = [40, 70] +lrate_decay = 0.1 +batch_size = 32 +weight_decay = 2e-4 +num_workers = 8 + + +class Finetune(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = IncrementalNet(args, False) + + def after_task(self): + self._known_classes = self._total_classes + + def save_checkpoint(self, test_acc): + assert self.args['model_name'] == 'finetune' + checkpoint_name = f"models/finetune/{self.args['csv_name']}" + _checkpoint_cpu = copy.deepcopy(self._network) + if isinstance(_checkpoint_cpu, nn.DataParallel): + _checkpoint_cpu = _checkpoint_cpu.module + _checkpoint_cpu.cpu() + save_dict = { + "tasks": self._cur_task, + "convnet": _checkpoint_cpu.convnet.state_dict(), + "fc":_checkpoint_cpu.fc.state_dict(), + "test_acc": test_acc + } + torch.save(save_dict, "{}_{}.pkl".format(checkpoint_name, self._cur_task)) + + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc(self._total_classes) + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + ) + self.train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if self._cur_task == 0: + optimizer = optim.SGD( + self._network.parameters(), + momentum=0.9, + lr=init_lr, + weight_decay=init_weight_decay, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay + ) + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + optimizer = optim.SGD( + self._network.parameters(), + lr=lrate, + momentum=0.9, + weight_decay=weight_decay, + ) # 1e-5 + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + 
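+            # No distillation and no rehearsal here: _update_representation
+            # only applies cross-entropy over the new-class logits, which is
+            # what makes this the naive fine-tuning baseline (and why it
+            # forgets old classes).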
self._update_representation(train_loader, test_loader, optimizer, scheduler) + + def _init_train(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(init_epoch)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss = F.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + ) + + prog_bar.set_description(info) + + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + + prog_bar = tqdm(range(epochs)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + fake_targets = targets - self._known_classes + loss_clf = F.cross_entropy( + logits[:, self._known_classes :], fake_targets + ) + + loss = loss_clf + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) diff --git a/models/foster.py b/models/foster.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c2ea78bcd4e3b47bcf8ec54d5ce84bef62b61a --- /dev/null +++ b/models/foster.py @@ -0,0 +1,435 @@ +import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import FOSTERNet +from utils.toolkit import count_parameters, target2onehot, tensor2numpy + +# Please refer to https://github.com/G-U-N/ECCV22-FOSTER for the full source code to reproduce foster. 
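+#
+# Rough outline of the two FOSTER stages implemented below (illustrative
+# summary, not extra API):
+#   1. _feature_boosting: expand FOSTERNet with a fresh convnet and train it
+#      with loss = loss_clf + loss_fe + lambda_okd * loss_kd, where loss_clf
+#      is re-weighted by per-class effective numbers (beta1).
+#   2. _feature_compression: distill the expanded teacher into a single-branch
+#      student (self._snet) with the balanced distillation loss BKD (beta2).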
+ +EPSILON = 1e-8 + + +class FOSTER(BaseLearner): + def __init__(self, args): + super().__init__(args) + self.args = args + self._network = FOSTERNet(args, False) + self._snet = None + self.beta1 = args["beta1"] + self.beta2 = args["beta2"] + self.per_cls_weights = None + self.is_teacher_wa = args["is_teacher_wa"] + self.is_student_wa = args["is_student_wa"] + self.lambda_okd = args["lambda_okd"] + self.wa_value = args["wa_value"] + self.oofc = args["oofc"].lower() + + def after_task(self): + self._known_classes = self._total_classes + logging.info("Exemplar size: {}".format(self.exemplar_size)) + + def load_checkpoint(self, filename): + checkpoint = torch.load(filename) + self._known_classes = len(checkpoint["classes"]) + self.class_list = np.array(checkpoint["classes"]) + self.label_list = checkpoint["label_list"] + self._network.update_fc(self._known_classes) + self._network.load_checkpoint(checkpoint["network"]) + self._network.to(self._device) + self._cur_task = 0 + + def save_checkpoint(self, filename): + self._network.cpu() + save_dict = { + "classes": self.data_manager.get_class_list(self._cur_task), + "network": { + "convnet": self._network.convnets[0].state_dict(), + "fc": self._network.fc.state_dict() + }, + "label_list": self.data_manager.get_label_list(self._cur_task), + } + torch.save(save_dict, "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task)) + def incremental_train(self, data_manager): + self.data_manager = data_manager + if hasattr(self.data_manager,'label_list') and hasattr(self,'label_list'): + self.data_manager.label_list = list(self.label_list.values()) + self.data_manager.label_list + self._cur_task += 1 + if self._cur_task > 1: + self._network = self._snet + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc(self._total_classes) + self._network_module_ptr = self._network + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + if self._cur_task > 0: + for p in self._network.convnets[0].parameters(): + p.requires_grad = False + for p in self._network.oldfc.parameters(): + p.requires_grad = False + + logging.info("All params: {}".format(count_parameters(self._network))) + logging.info( + "Trainable params: {}".format(count_parameters(self._network, True)) + ) + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + + self.train_loader = DataLoader( + train_dataset, + batch_size=self.args["batch_size"], + shuffle=True, + num_workers=self.args["num_workers"], + pin_memory=True, + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, + batch_size=self.args["batch_size"], + shuffle=False, + num_workers=self.args["num_workers"], + ) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + #self.build_rehearsal_memory(data_manager, self.samples_per_class) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def train(self): + self._network_module_ptr.train() + self._network_module_ptr.convnets[-1].train() + if self._cur_task >= 1: + self._network_module_ptr.convnets[0].eval() + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if hasattr(self._network, "module"): + 
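+            # nn.DataParallel wraps the net in .module; keep a direct pointer
+            # so later attribute access (fc, convnets, weight_align) is uniform.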
self._network_module_ptr = self._network.module + if self._cur_task == 0: + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), + momentum=0.9, + lr=self.args["init_lr"], + weight_decay=self.args["init_weight_decay"], + ) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=self.args["init_epochs"] + ) + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + cls_num_list = [self.samples_old_class] * self._known_classes + [ + self.samples_new_class(i) + for i in range(self._known_classes, self._total_classes) + ] + effective_num = 1.0 - np.power(self.beta1, cls_num_list) + per_cls_weights = (1.0 - self.beta1) / np.array(effective_num) + per_cls_weights = ( + per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list) + ) + + logging.info("per cls weights : {}".format(per_cls_weights)) + self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device) + + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), + lr=self.args["lr"], + momentum=0.9, + weight_decay=self.args["weight_decay"], + ) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=self.args["boosting_epochs"] + ) + if self.oofc == "az": + for i, p in enumerate(self._network_module_ptr.fc.parameters()): + if i == 0: + p.data[ + self._known_classes :, : self._network_module_ptr.out_dim + ] = torch.tensor(0.0) + elif self.oofc != "ft": + assert 0, "not implemented" + self._feature_boosting(train_loader, test_loader, optimizer, scheduler) + if self.is_teacher_wa: + self._network_module_ptr.weight_align( + self._known_classes, + self._total_classes - self._known_classes, + self.wa_value, + ) + else: + logging.info("do not weight align teacher!") + + cls_num_list = [self.samples_old_class] * self._known_classes + [ + self.samples_new_class(i) + for i in range(self._known_classes, self._total_classes) + ] + effective_num = 1.0 - np.power(self.beta2, cls_num_list) + per_cls_weights = (1.0 - self.beta2) / np.array(effective_num) + per_cls_weights = ( + per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list) + ) + logging.info("per cls weights : {}".format(per_cls_weights)) + self.per_cls_weights = torch.FloatTensor(per_cls_weights).to(self._device) + self._feature_compression(train_loader, test_loader) + + def _init_train(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self.args["init_epochs"])) + for _, epoch in enumerate(prog_bar): + self.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True + ), targets.to(self._device, non_blocking=True) + logits = self._network(inputs)["logits"] + loss = F.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.args["init_epochs"], + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + 
self.args["init_epochs"], + losses / len(train_loader), + train_acc, + ) + + prog_bar.set_description(info) + logging.info(info) + + def _feature_boosting(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self.args["boosting_epochs"])) + for _, epoch in enumerate(prog_bar): + self.train() + losses = 0.0 + losses_clf = 0.0 + losses_fe = 0.0 + losses_kd = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True + ), targets.to(self._device, non_blocking=True) + outputs = self._network(inputs) + logits, fe_logits, old_logits = ( + outputs["logits"], + outputs["fe_logits"], + outputs["old_logits"].detach(), + ) + loss_clf = F.cross_entropy(logits / self.per_cls_weights, targets) + loss_fe = F.cross_entropy(fe_logits, targets) + loss_kd = self.lambda_okd * _KD_loss( + logits[:, : self._known_classes], old_logits, self.args["T"] + ) + loss = loss_clf + loss_fe + loss_kd + optimizer.zero_grad() + loss.backward() + if self.oofc == "az": + for i, p in enumerate(self._network_module_ptr.fc.parameters()): + if i == 0: + p.grad.data[ + self._known_classes :, + : self._network_module_ptr.out_dim, + ] = torch.tensor(0.0) + elif self.oofc != "ft": + assert 0, "not implemented" + optimizer.step() + losses += loss.item() + losses_fe += loss_fe.item() + losses_clf += loss_clf.item() + losses_kd += ( + self._known_classes / self._total_classes + ) * loss_kd.item() + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.args["boosting_epochs"], + losses / len(train_loader), + losses_clf / len(train_loader), + losses_fe / len(train_loader), + losses_kd / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fe {:.3f}, Loss_kd {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.args["boosting_epochs"], + losses / len(train_loader), + losses_clf / len(train_loader), + losses_fe / len(train_loader), + losses_kd / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) + + def _feature_compression(self, train_loader, test_loader): + self._snet = FOSTERNet(self.args, False) + self._snet.update_fc(self._total_classes) + if len(self._multiple_gpus) > 1: + self._snet = nn.DataParallel(self._snet, self._multiple_gpus) + if hasattr(self._snet, "module"): + self._snet_module_ptr = self._snet.module + else: + self._snet_module_ptr = self._snet + self._snet.to(self._device) + self._snet_module_ptr.convnets[0].load_state_dict( + self._network_module_ptr.convnets[0].state_dict() + ) + self._snet_module_ptr.copy_fc(self._network_module_ptr.oldfc) + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._snet.parameters()), + lr=self.args["lr"], + momentum=0.9, + ) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=self.args["compression_epochs"] + ) + self._network.eval() + prog_bar = tqdm(range(self.args["compression_epochs"])) + for _, epoch in enumerate(prog_bar): + self._snet.train() + losses = 0.0 + correct, total = 0, 0 + 
for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True + ), targets.to(self._device, non_blocking=True) + dark_logits = self._snet(inputs)["logits"] + with torch.no_grad(): + outputs = self._network(inputs) + logits, old_logits, fe_logits = ( + outputs["logits"], + outputs["old_logits"], + outputs["fe_logits"], + ) + loss_dark = self.BKD(dark_logits, logits, self.args["T"]) + loss = loss_dark + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + _, preds = torch.max(dark_logits[: targets.shape[0]], dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._snet, test_loader) + info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.args["compression_epochs"], + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "SNet: Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + self.args["compression_epochs"], + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) + if len(self._multiple_gpus) > 1: + self._snet = self._snet.module + if self.is_student_wa: + self._snet.weight_align( + self._known_classes, + self._total_classes - self._known_classes, + self.wa_value, + ) + else: + logging.info("do not weight align student!") + if self._cur_task > 1: + self._network = self._snet + self._snet.eval() + y_pred, y_true = [], [] + for _, (_, inputs, targets) in enumerate(test_loader): + inputs = inputs.to(self._device, non_blocking=True) + with torch.no_grad(): + outputs = self._snet(inputs)["logits"] + predicts = torch.topk( + outputs, k=self.topk, dim=1, largest=True, sorted=True + )[1] + y_pred.append(predicts.cpu().numpy()) + y_true.append(targets.cpu().numpy()) + y_pred = np.concatenate(y_pred) + y_true = np.concatenate(y_true) + cnn_accy = self._evaluate(y_pred, y_true) + logging.info("darknet eval: ") + logging.info("CNN top1 curve: {}".format(cnn_accy["top1"])) + logging.info("CNN top5 curve: {}".format(cnn_accy["top5"])) + + @property + def samples_old_class(self): + if self._fixed_memory: + return self._memory_per_class + else: + assert self._total_classes != 0, "Total classes is 0" + return self._memory_size // self._known_classes + + def samples_new_class(self, index): + if self.args["dataset"] == "cifar100": + return 500 + else: + return self.data_manager.getlen(index) + + def BKD(self, pred, soft, T): + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + soft = soft * self.per_cls_weights + soft = soft / soft.sum(1)[:, None] + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] + + +def _KD_loss(pred, soft, T): + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] diff --git a/models/gem.py b/models/gem.py new file mode 100644 index 0000000000000000000000000000000000000000..6d42769ec12ac43a8a7d2fa1590eedbab89c6d39 --- /dev/null +++ b/models/gem.py @@ -0,0 +1,304 @@ +import logging +import numpy as np +from torch._C import device +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from 
models.base import BaseLearner
+from utils.inc_net import IncrementalNet
+from utils.inc_net import CosineIncrementalNet
+from utils.toolkit import target2onehot, tensor2numpy
+try:
+    from quadprog import solve_qp
+except ImportError:
+    pass
+
+
+EPSILON = 1e-8
+
+
+init_epoch = 1
+init_lr = 0.1
+init_milestones = [40, 60, 80]
+init_lr_decay = 0.1
+init_weight_decay = 0.0005
+
+
+epochs = 1
+lrate = 0.1
+milestones = [20, 40, 60]
+lrate_decay = 0.1
+batch_size = 16
+weight_decay = 2e-4
+num_workers = 4
+
+
+class GEM(BaseLearner):
+    def __init__(self, args):
+        super().__init__(args)
+        self._network = IncrementalNet(args, False)
+        self.previous_data = None
+        self.previous_label = None
+
+    def after_task(self):
+        self._old_network = self._network.copy().freeze()
+        self._known_classes = self._total_classes
+        logging.info("Exemplar size: {}".format(self.exemplar_size))
+
+    def incremental_train(self, data_manager):
+        self._cur_task += 1
+        self._total_classes = self._known_classes + data_manager.get_task_size(
+            self._cur_task
+        )
+        self._network.update_fc(self._total_classes)
+        logging.info(
+            "Learning on {}-{}".format(self._known_classes, self._total_classes)
+        )
+
+        train_dataset = data_manager.get_dataset(
+            np.arange(self._known_classes, self._total_classes),
+            source="train",
+            mode="train",
+        )
+        self.train_loader = DataLoader(
+            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+        )
+        test_dataset = data_manager.get_dataset(
+            np.arange(0, self._total_classes), source="test", mode="test"
+        )
+        self.test_loader = DataLoader(
+            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+        )
+
+        if self._cur_task > 0:
+            previous_dataset = data_manager.get_dataset(
+                [], source="train", mode="train", appendent=self._get_memory()
+            )
+
+            self.previous_data = []
+            self.previous_label = []
+            for i in previous_dataset:
+                _, data_, label_ = i
+                self.previous_data.append(data_)
+                self.previous_label.append(label_)
+            self.previous_data = torch.stack(self.previous_data)
+            self.previous_label = torch.tensor(self.previous_label)
+        # Procedure
+        if len(self._multiple_gpus) > 1:
+            self._network = nn.DataParallel(self._network, self._multiple_gpus)
+        self._train(self.train_loader, self.test_loader)
+        self.build_rehearsal_memory(data_manager, self.samples_per_class)
+        if len(self._multiple_gpus) > 1:
+            self._network = self._network.module
+
+    def _train(self, train_loader, test_loader):
+        self._network.to(self._device)
+        if self._old_network is not None:
+            self._old_network.to(self._device)
+
+        if self._cur_task == 0:
+            optimizer = optim.SGD(
+                self._network.parameters(),
+                momentum=0.9,
+                lr=init_lr,
+                weight_decay=init_weight_decay,
+            )
+            scheduler = optim.lr_scheduler.MultiStepLR(
+                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
+            )
+            self._init_train(train_loader, test_loader, optimizer, scheduler)
+        else:
+            optimizer = optim.SGD(
+                self._network.parameters(),
+                lr=lrate,
+                momentum=0.9,
+                weight_decay=weight_decay,
+            )  # 1e-5
+            scheduler = optim.lr_scheduler.MultiStepLR(
+                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
+            )
+            self._update_representation(train_loader, test_loader, optimizer, scheduler)
+
+    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+        prog_bar = tqdm(range(init_epoch))
+        for _, epoch in enumerate(prog_bar):
+            self._network.train()
+            losses = 0.0
+            correct, total = 0, 0
+            for i, (_, inputs, targets) in enumerate(train_loader):
+                inputs, targets = inputs.to(self._device),
targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss = F.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + ) + + prog_bar.set_description(info) + + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(epochs)) + grad_numels = [] + for params in self._network.parameters(): + grad_numels.append(params.data.numel()) + G = torch.zeros((sum(grad_numels), self._cur_task + 1)).to(self._device) + + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + incremental_step = self._total_classes - self._known_classes + for k in range(0, self._cur_task): + optimizer.zero_grad() + mask = torch.where( + (self.previous_label >= k * incremental_step) + & (self.previous_label < (k + 1) * incremental_step) + )[0] + data_ = self.previous_data[mask].to(self._device) + label_ = self.previous_label[mask].to(self._device) + pred_ = self._network(data_)["logits"] + pred_[:, : k * incremental_step].data.fill_(-10e10) + pred_[:, (k + 1) * incremental_step :].data.fill_(-10e10) + loss_ = F.cross_entropy(pred_, label_) + loss_.backward() + + j = 0 + for params in self._network.parameters(): + if params is not None: + if j == 0: + stpt = 0 + else: + stpt = sum(grad_numels[:j]) + + endpt = sum(grad_numels[: j + 1]) + G[stpt:endpt, k].data.copy_(params.grad.data.view(-1)) + j += 1 + + optimizer.zero_grad() + + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + logits[:, : self._known_classes].data.fill_(-10e10) + loss_clf = F.cross_entropy(logits, targets) + + loss = loss_clf + + optimizer.zero_grad() + loss.backward() + + j = 0 + for params in self._network.parameters(): + if params is not None: + if j == 0: + stpt = 0 + else: + stpt = sum(grad_numels[:j]) + + endpt = sum(grad_numels[: j + 1]) + G[stpt:endpt, self._cur_task].data.copy_( + params.grad.data.view(-1) + ) + j += 1 + + dotprod = torch.mm( + G[:, self._cur_task].unsqueeze(0), G[:, : self._cur_task] + ) + + if (dotprod < 0).sum() > 0: + + old_grad = G[:, : self._cur_task].cpu().t().double().numpy() + cur_grad = G[:, self._cur_task].cpu().contiguous().double().numpy() + + C = old_grad @ old_grad.T + p = old_grad @ cur_grad + A = np.eye(old_grad.shape[0]) + b = np.zeros(old_grad.shape[0]) + + v = solve_qp(C, -p, A, b)[0] + + new_grad = old_grad.T @ v + cur_grad + new_grad = torch.tensor(new_grad).float().to(self._device) + + new_dotprod = torch.mm( + new_grad.unsqueeze(0), G[:, : self._cur_task] + ) + if (new_dotprod < -0.01).sum() > 0: + assert 0 + j = 0 + for params in self._network.parameters(): + if params is not None: + if j == 0: + stpt = 0 + else: + stpt = 
sum(grad_numels[:j]) + + endpt = sum(grad_numels[: j + 1]) + params.grad.data.copy_( + new_grad[stpt:endpt] + .contiguous() + .view(params.grad.data.size()) + ) + j += 1 + + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) diff --git a/models/icarl.py b/models/icarl.py new file mode 100644 index 0000000000000000000000000000000000000000..cd400b0b7d2a62ed58603eea5bf58f2a60c9545a --- /dev/null +++ b/models/icarl.py @@ -0,0 +1,205 @@ +import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import IncrementalNet +from utils.inc_net import CosineIncrementalNet +from utils.toolkit import target2onehot, tensor2numpy + +EPSILON = 1e-8 + +init_epoch = 100 +init_lr = 0.1 +init_milestones = [40, 60, 80] +init_lr_decay = 0.1 +init_weight_decay = 0.0005 + + +epochs = 80 +lrate = 0.1 +milestones = [40, 60] +lrate_decay = 0.1 +batch_size = 32 +weight_decay = 2e-4 +num_workers = 8 +T = 2 + + +class iCaRL(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = IncrementalNet(args, False) + + def after_task(self): + self._old_network = self._network.copy().freeze() + self._known_classes = self._total_classes + logging.info("Exemplar size: {}".format(self.exemplar_size)) + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc(self._total_classes) + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + self.train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + self.build_rehearsal_memory(data_manager, self.samples_per_class) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if self._old_network is not None: + self._old_network.to(self._device) + + if self._cur_task == 0: + optimizer = optim.SGD( + self._network.parameters(), + momentum=0.9, + lr=init_lr, + weight_decay=init_weight_decay, + ) + scheduler = 
optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay + ) + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + optimizer = optim.SGD( + self._network.parameters(), + lr=lrate, + momentum=0.9, + weight_decay=weight_decay, + ) # 1e-5 + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + self._update_representation(train_loader, test_loader, optimizer, scheduler) + + def _init_train(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(init_epoch)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss = F.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + ) + + prog_bar.set_description(info) + + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(epochs)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss_clf = F.cross_entropy(logits, targets) + loss_kd = _KD_loss( + logits[:, : self._known_classes], + self._old_network(inputs)["logits"], + T, + ) + + loss = loss_clf + loss_kd + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) + + +def _KD_loss(pred, soft, T): + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] diff --git a/models/il2a.py b/models/il2a.py new file mode 100644 index 0000000000000000000000000000000000000000..45a9ded99ecaf32ad5f51cca7cd691a2f9774021 --- /dev/null +++ b/models/il2a.py @@ -0,0 +1,250 @@ +import logging +import numpy as np 
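A quick note on the `_KD_loss` helper that closes icarl.py above (lwf.py below defines the identical function): it is temperature-scaled soft cross-entropy between the frozen old network's logits and the student's, averaged over the batch. iCaRL applies it only to the first `_known_classes` logit columns, so the new classes are supervised by plain cross-entropy alone. A self-contained equivalent, with illustrative shapes:

```python
import torch

def kd_loss(pred, soft, T=2.0):
    # Student log-probabilities and teacher probabilities at temperature T.
    log_p = torch.log_softmax(pred / T, dim=1)
    q = torch.softmax(soft / T, dim=1)
    # Soft cross-entropy averaged over the batch; equal to _KD_loss above,
    # since sum over all entries / batch size == per-row sum, then mean.
    return -(q * log_p).sum(dim=1).mean()

student = torch.randn(4, 10)      # logits over the old classes
teacher = torch.randn(4, 10)      # frozen old network's logits
print(kd_loss(student, teacher))  # scalar distillation loss
```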
+from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader,Dataset +from models.base import BaseLearner +from utils.inc_net import CosineIncrementalNet, FOSTERNet, IL2ANet, IncrementalNet +from utils.toolkit import count_parameters, target2onehot, tensor2numpy + +EPSILON = 1e-8 + + +class IL2A(BaseLearner): + def __init__(self, args): + super().__init__(args) + self.args = args + self._network = IL2ANet(args, False) + self._protos = [] + self._covs = [] + + + def after_task(self): + self._known_classes = self._total_classes + self._old_network = self._network.copy().freeze() + if hasattr(self._old_network,"module"): + self.old_network_module_ptr = self._old_network.module + else: + self.old_network_module_ptr = self._old_network + #self.save_checkpoint("{}_{}_{}".format(self.args["model_name"],self.args["init_cls"],self.args["increment"])) + def incremental_train(self, data_manager): + self.data_manager = data_manager + self._cur_task += 1 + + task_size = self.data_manager.get_task_size(self._cur_task) + self._total_classes = self._known_classes + task_size + self._network.update_fc(self._known_classes,self._total_classes,int((task_size-1)*task_size/2)) + self._network_module_ptr = self._network + logging.info( + 'Learning on {}-{}'.format(self._known_classes, self._total_classes)) + + + logging.info('All params: {}'.format(count_parameters(self._network))) + logging.info('Trainable params: {}'.format( + count_parameters(self._network, True))) + + train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train', + mode='train', appendent=self._get_memory()) + self.train_loader = DataLoader( + train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source='test', mode='test') + self.test_loader = DataLoader( + test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"]) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + + def _train(self, train_loader, test_loader): + + resume = False + if self._cur_task in []: + self._network.load_state_dict(torch.load("{}_{}_{}_{}.pkl".format(self.args["model_name"],self.args["init_cls"],self.args["increment"],self._cur_task))["model_state_dict"]) + resume = True + self._network.to(self._device) + if hasattr(self._network, "module"): + self._network_module_ptr = self._network.module + if not resume: + self._epoch_num = self.args["epochs"] + optimizer = torch.optim.Adam(self._network.parameters(), lr=self.args["lr"], weight_decay=self.args["weight_decay"]) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args["step_size"], gamma=self.args["gamma"]) + self._train_function(train_loader, test_loader, optimizer, scheduler) + self._build_protos() + + + def _build_protos(self): + with torch.no_grad(): + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = 
self._extract_vectors(idx_loader) + class_mean = np.mean(vectors, axis=0) + self._protos.append(class_mean) + cov = np.cov(vectors.T) + self._covs.append(cov) + + def _train_function(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self._epoch_num)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0. + losses_clf, losses_fkd, losses_proto = 0., 0., 0. + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True), targets.to(self._device, non_blocking=True) + inputs,targets = self._class_aug(inputs,targets) + logits, loss_clf, loss_fkd, loss_proto = self._compute_il2a_loss(inputs,targets) + loss = loss_clf + loss_fkd + loss_proto + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + losses_clf += loss_clf.item() + losses_fkd += loss_fkd.item() + losses_proto += loss_proto.item() + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy( + correct)*100 / total, decimals=2) + if epoch % 5 != 0: + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}'.format( + self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc) + else: + test_acc = self._compute_accuracy(self._network, test_loader) + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format( + self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc, test_acc) + prog_bar.set_description(info) + logging.info(info) + + def _compute_il2a_loss(self,inputs, targets): + logits = self._network(inputs)["logits"] + loss_clf = F.cross_entropy(logits/self.args["temp"], targets) + + if self._cur_task == 0: + return logits, loss_clf, torch.tensor(0.), torch.tensor(0.) + + features = self._network_module_ptr.extract_vector(inputs) + features_old = self.old_network_module_ptr.extract_vector(inputs) + loss_fkd = self.args["lambda_fkd"] * torch.dist(features, features_old, 2) + + index = np.random.choice(range(self._known_classes),size=self.args["batch_size"],replace=True) + + proto_features = np.array(self._protos)[index] + proto_targets = index + proto_features = torch.from_numpy(proto_features).float().to(self._device,non_blocking=True) + proto_targets = torch.from_numpy(proto_targets).to(self._device,non_blocking=True) + + proto_logits = self._network_module_ptr.fc(proto_features)["logits"][:,:self._total_classes] + + + proto_logits = self._semantic_aug(proto_logits,proto_targets,self.args["ratio"]) + + loss_proto = self.args["lambda_proto"] * F.cross_entropy(proto_logits/self.args["temp"], proto_targets) + return logits, loss_clf, loss_fkd, loss_proto + + + def _semantic_aug(self,proto_logits,proto_targets,ratio): + # weight_fc = self._network_module_ptr.fc.weight.data[:self._total_classes] # don't use it ! 
data is not involved in back propagation + weight_fc = self._network_module_ptr.fc.weight[:self._total_classes] + N,C,D = self.args["batch_size"], self._total_classes, weight_fc.shape[1] + + N_weight = weight_fc.expand(N,C,D) # NCD + N_target_weight = torch.gather(N_weight, 1, proto_targets[:,None,None].expand(N,C,D)) # NCD + N_v = N_weight-N_target_weight + N_cov = torch.from_numpy(np.array(self._covs))[proto_targets].float().to(self._device) # NDD + + proto_logits = proto_logits + ratio/2* torch.diagonal(N_v @ N_cov @ N_v.permute(0,2,1),dim1=1,dim2=2) # NC + + return proto_logits + + + + + def _class_aug(self,inputs,targets,alpha=20., mix_time=4): + + mixup_inputs = [] + mixup_targets = [] + for _ in range(mix_time): + index = torch.randperm(inputs.shape[0]) + perm_inputs = inputs[index] + perm_targets = targets[index] + mask = perm_targets!= targets + + select_inputs = inputs[mask] + select_targets = targets[mask] + perm_inputs = perm_inputs[mask] + perm_targets = perm_targets[mask] + + lams = np.random.beta(alpha,alpha,sum(mask)) + lams = np.where((lams<0.4)|(lams>0.6),0.5,lams) + lams = torch.from_numpy(lams).to(self._device)[:,None,None,None].float() + + + mixup_inputs.append(lams*select_inputs+(1-lams)*perm_inputs) + mixup_targets.append(self._map_targets(select_targets,perm_targets)) + mixup_inputs = torch.cat(mixup_inputs,dim=0) + mixup_targets = torch.cat(mixup_targets,dim=0) + + inputs = torch.cat([inputs,mixup_inputs],dim=0) + targets = torch.cat([targets,mixup_targets],dim=0) + return inputs,targets + + def _map_targets(self,select_targets,perm_targets): + assert (select_targets != perm_targets).all() + large_targets = torch.max(select_targets,perm_targets)-self._known_classes + small_targets = torch.min(select_targets,perm_targets)-self._known_classes + + mixup_targets = large_targets*(large_targets-1) // 2 + small_targets + self._total_classes + return mixup_targets + def _compute_accuracy(self, model, loader): + model.eval() + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = model(inputs)["logits"][:,:self._total_classes] + predicts = torch.max(outputs, dim=1)[1] + correct += (predicts.cpu() == targets).sum() + total += len(targets) + + return np.around(tensor2numpy(correct)*100 / total, decimals=2) + + def _eval_cnn(self, loader): + self._network.eval() + y_pred, y_true = [], [] + for _, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = self._network(inputs)["logits"][:,:self._total_classes] + predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1] + y_pred.append(predicts.cpu().numpy()) + y_true.append(targets.cpu().numpy()) + + return np.concatenate(y_pred), np.concatenate(y_true) + + def eval_task(self, save_conf=False): + y_pred, y_true = self._eval_cnn(self.test_loader) + cnn_accy = self._evaluate(y_pred, y_true) + + if hasattr(self, '_class_means'): + y_pred, y_true = self._eval_nme(self.test_loader, self._class_means) + nme_accy = self._evaluate(y_pred, y_true) + elif hasattr(self, '_protos'): + y_pred, y_true = self._eval_nme(self.test_loader, self._protos/np.linalg.norm(self._protos,axis=1)[:,None]) + nme_accy = self._evaluate(y_pred, y_true) + else: + nme_accy = None + + return cnn_accy, nme_accy diff --git a/models/lwf.py b/models/lwf.py new file mode 100644 index 0000000000000000000000000000000000000000..e618803cab955b43160ea91e7c631b602105cbb9 --- /dev/null +++ b/models/lwf.py 
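Stepping back to il2a.py above before the LwF code: `_class_aug` mixes only sample pairs whose labels differ (the `mask` filter), and `_map_targets` assigns each unordered pair of new classes its own virtual label. With `s` and `l` the smaller and larger task-local labels, a pair maps to `l*(l-1)//2 + s + total_classes`, a lower-triangular enumeration appended after the real classes; this is exactly why `update_fc` reserves `task_size*(task_size-1)/2` extra outputs. A quick check of the mapping for a hypothetical first task with 4 classes:

```python
# Enumerate virtual labels for every unordered pair of new classes,
# mirroring IL2A's _map_targets (here _known_classes=0, _total_classes=4).
task_size, total_classes = 4, 4
labels = {}
for small in range(task_size):
    for large in range(small + 1, task_size):
        labels[(small, large)] = large * (large - 1) // 2 + small + total_classes

print(labels)
# {(0, 1): 4, (0, 2): 5, (0, 3): 7, (1, 2): 6, (1, 3): 8, (2, 3): 9}
# Six distinct virtual labels, one per pair: task_size*(task_size-1)//2 == 6.
```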
@@ -0,0 +1,205 @@ +import logging +import numpy as np +import torch +from torch import nn +from torch.serialization import load +from tqdm import tqdm +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from utils.inc_net import IncrementalNet +from models.base import BaseLearner +from utils.toolkit import target2onehot, tensor2numpy + +init_epoch = 200 +init_lr = 0.1 +init_milestones = [60, 120, 160] +init_lr_decay = 0.1 +init_weight_decay = 0.0005 + + +epochs = 250 +lrate = 0.1 +milestones = [60, 120, 180, 220] +lrate_decay = 0.1 +batch_size = 128 +weight_decay = 2e-4 +num_workers = 8 +T = 2 +lamda = 3 + + +class LwF(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = IncrementalNet(args, False) + + def after_task(self): + self._old_network = self._network.copy().freeze() + self._known_classes = self._total_classes + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc(self._total_classes) + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + ) + self.train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if self._old_network is not None: + self._old_network.to(self._device) + + if self._cur_task == 0: + optimizer = optim.SGD( + self._network.parameters(), + momentum=0.9, + lr=init_lr, + weight_decay=init_weight_decay, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay + ) + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + optimizer = optim.SGD( + self._network.parameters(), + lr=lrate, + momentum=0.9, + weight_decay=weight_decay, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + self._update_representation(train_loader, test_loader, optimizer, scheduler) + + def _init_train(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(init_epoch)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss = F.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, 
test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + + prog_bar = tqdm(range(epochs)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + fake_targets = targets - self._known_classes + loss_clf = F.cross_entropy( + logits[:, self._known_classes :], fake_targets + ) + loss_kd = _KD_loss( + logits[:, : self._known_classes], + self._old_network(inputs)["logits"], + T, + ) + + loss = lamda * loss_kd + loss_clf + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + with torch.no_grad(): + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) + + +def _KD_loss(pred, soft, T): + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] diff --git a/models/memo.py b/models/memo.py new file mode 100644 index 0000000000000000000000000000000000000000..a26e5ce29e79f6a57e4f4c2c737f1d262648a534 --- /dev/null +++ b/models/memo.py @@ -0,0 +1,337 @@ +import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import nn +import copy +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import AdaptiveNet +from utils.toolkit import count_parameters, target2onehot, tensor2numpy + +num_workers=8 +EPSILON = 1e-8 +batch_size = 32 + +class MEMO(BaseLearner): + + def __init__(self, args): + super().__init__(args) + self.args = args + self._old_base = None + self._network = AdaptiveNet(args, True) + logging.info(f'>>> train generalized blocks:{self.args["train_base"]} train_adaptive:{self.args["train_adaptive"]}') + + def after_task(self): + self._known_classes = self._total_classes + if self._cur_task == 0: + if self.args['train_base']: + logging.info("Train Generalized Blocks...") + self._network.TaskAgnosticExtractor.train() + for param in self._network.TaskAgnosticExtractor.parameters(): + param.requires_grad = True + else: + logging.info("Fix Generalized Blocks...") + self._network.TaskAgnosticExtractor.eval() + for param in self._network.TaskAgnosticExtractor.parameters(): + param.requires_grad = False + + logging.info('Exemplar 
size: {}'.format(self.exemplar_size)) + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task) + self._network.update_fc(self._total_classes) + + logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes)) + + if self._cur_task>0: + for i in range(self._cur_task): + for p in self._network.AdaptiveExtractors[i].parameters(): + if self.args['train_adaptive'] and i == self._cur_task: + p.requires_grad = True + else: + p.requires_grad = False + + logging.info('All params: {}'.format(count_parameters(self._network))) + logging.info('Trainable params: {}'.format(count_parameters(self._network, True))) + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source='train', + mode='train', + appendent=self._get_memory() + ) + self.train_loader = DataLoader( + train_dataset, + batch_size=self.args["batch_size"], + shuffle=True, + num_workers=num_workers + ) + + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), + source='test', + mode='test' + ) + self.test_loader = DataLoader( + test_dataset, + batch_size=self.args["batch_size"], + shuffle=False, + num_workers=num_workers + ) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + self.build_rehearsal_memory(data_manager, self.samples_per_class) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def set_network(self): + if len(self._multiple_gpus) > 1: + self._network = self._network.module + self._network.train() #All status from eval to train + if self.args['train_base']: + self._network.TaskAgnosticExtractor.train() + else: + self._network.TaskAgnosticExtractor.eval() + + # set adaptive extractor's status + self._network.AdaptiveExtractors[-1].train() + if self._cur_task >= 1: + for i in range(self._cur_task): + if self.args['train_adaptive']: + self._network.AdaptiveExtractors[i].train() + else: + self._network.AdaptiveExtractors[i].eval() + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if self._cur_task==0: + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), + momentum=0.9, + lr=self.args["init_lr"], + weight_decay=self.args["init_weight_decay"] + ) + if self.args['scheduler'] == 'steplr': + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, + milestones=self.args['init_milestones'], + gamma=self.args['init_lr_decay'] + ) + elif self.args['scheduler'] == 'cosine': + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, + T_max=self.args['init_epoch'] + ) + else: + raise NotImplementedError + + if not self.args['skip']: + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + if isinstance(self._network, nn.DataParallel): + self._network = self._network.module + load_acc = self._network.load_checkpoint(self.args) + self._network.to(self._device) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + + cur_test_acc = self._compute_accuracy(self._network, self.test_loader) + logging.info(f"Loaded_Test_Acc:{load_acc} Cur_Test_Acc:{cur_test_acc}") + else: + optimizer = optim.SGD( + filter(lambda p: p.requires_grad, self._network.parameters()), 
+ lr=self.args['lrate'], + momentum=0.9, + weight_decay=self.args['weight_decay'] + ) + if self.args['scheduler'] == 'steplr': + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, + milestones=self.args['milestones'], + gamma=self.args['lrate_decay'] + ) + elif self.args['scheduler'] == 'cosine': + assert self.args['t_max'] is not None + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, + T_max=self.args['t_max'] + ) + else: + raise NotImplementedError + self._update_representation(train_loader, test_loader, optimizer, scheduler) + if len(self._multiple_gpus) > 1: + self._network.module.weight_align(self._total_classes-self._known_classes) + else: + self._network.weight_align(self._total_classes-self._known_classes) + + + def _init_train(self,train_loader,test_loader,optimizer,scheduler): + prog_bar = tqdm(range(self.args["init_epoch"])) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0. + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)['logits'] + + loss=F.cross_entropy(logits,targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct)*100 / total, decimals=2) + if epoch%5==0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format( + self._cur_task, epoch+1, self.args['init_epoch'], losses/len(train_loader), train_acc, test_acc) + else: + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format( + self._cur_task, epoch+1, self.args['init_epoch'], losses/len(train_loader), train_acc) + # prog_bar.set_description(info) + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self.args["epochs"])) + for _, epoch in enumerate(prog_bar): + self.set_network() + losses = 0. + losses_clf=0. + losses_aux=0. 
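# Note on the auxiliary head trained below: aux_targets remaps the joint
# labels so that every old class collapses to 0 while the current task's
# classes get labels 1..task_size. For example, with _known_classes = 10 a
# target of 12 becomes aux label 3 and any target < 10 becomes 0, so the
# auxiliary classifier only separates the new classes from one "old" bucket.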
+ correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + + outputs= self._network(inputs) + logits,aux_logits=outputs["logits"],outputs["aux_logits"] + loss_clf=F.cross_entropy(logits,targets) + aux_targets = targets.clone() + aux_targets=torch.where(aux_targets-self._known_classes+1.0>0, aux_targets-self._known_classes+1.0,torch.Tensor([.0]).to(self.args["device"][0])) + loss_aux=F.cross_entropy(aux_logits,aux_targets.long()) + loss=loss_clf+self.args['alpha_aux']*loss_aux + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + losses_aux+=loss_aux.item() + losses_clf+=loss_clf.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct)*100 / total, decimals=2) + if epoch%5==0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format( + self._cur_task, epoch+1, self.args["epochs"], losses/len(train_loader),losses_clf/len(train_loader),losses_aux/len(train_loader),train_acc, test_acc) + else: + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}'.format( + self._cur_task, epoch+1, self.args["epochs"], losses/len(train_loader), losses_clf/len(train_loader),losses_aux/len(train_loader),train_acc) + prog_bar.set_description(info) + logging.info(info) + + def save_checkpoint(self, test_acc): + assert self.args['model_name'] == 'finetune' + checkpoint_name = f"checkpoints/finetune_{self.args['csv_name']}" + _checkpoint_cpu = copy.deepcopy(self._network) + if isinstance(_checkpoint_cpu, nn.DataParallel): + _checkpoint_cpu = _checkpoint_cpu.module + _checkpoint_cpu.cpu() + save_dict = { + "tasks": self._cur_task, + "convnet": _checkpoint_cpu.convnet.state_dict(), + "fc":_checkpoint_cpu.fc.state_dict(), + "test_acc": test_acc + } + torch.save(save_dict, "{}_{}.pkl".format(checkpoint_name, self._cur_task)) + + def _construct_exemplar(self, data_manager, m): + logging.info("Constructing exemplars...({} per classes)".format(m)) + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = data_manager.get_dataset( + np.arange(class_idx, class_idx + 1), + source="train", + mode="test", + ret_data=True, + ) + idx_loader = DataLoader( + idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + + # Select + selected_exemplars = [] + exemplar_vectors = [] # [n, feature_dim] + for k in range(1, m + 1): + S = np.sum( + exemplar_vectors, axis=0 + ) # [feature_dim] sum of selected exemplars vectors + mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors + i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) + selected_exemplars.append( + np.array(data[i]) + ) # New object to avoid passing by inference + exemplar_vectors.append( + np.array(vectors[i]) + ) # New object to avoid passing by inference + + vectors = np.delete( + vectors, i, axis=0 + ) # Remove it to avoid duplicative selection + data = np.delete( + data, i, axis=0 + ) # Remove it to avoid duplicative selection + + if len(vectors) == 0: + break + # uniques = 
np.unique(selected_exemplars, axis=0) + # print('Unique elements: {}'.format(len(uniques))) + selected_exemplars = np.array(selected_exemplars) + # exemplar_targets = np.full(m, class_idx) + exemplar_targets = np.full(selected_exemplars.shape[0], class_idx) + self._data_memory = ( + np.concatenate((self._data_memory, selected_exemplars)) + if len(self._data_memory) != 0 + else selected_exemplars + ) + self._targets_memory = ( + np.concatenate((self._targets_memory, exemplar_targets)) + if len(self._targets_memory) != 0 + else exemplar_targets + ) + + # Exemplar mean + idx_dataset = data_manager.get_dataset( + [], + source="train", + mode="test", + appendent=(selected_exemplars, exemplar_targets), + ) + idx_loader = DataLoader( + idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4 + ) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + self._class_means[class_idx, :] = mean \ No newline at end of file diff --git a/models/pa2s.py b/models/pa2s.py new file mode 100644 index 0000000000000000000000000000000000000000..2caa6ca579d9cd537377cd660b8c18605db7bf60 --- /dev/null +++ b/models/pa2s.py @@ -0,0 +1,216 @@ +import logging +import numpy as np +from tqdm import tqdm +import os +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader,Dataset +from models.base import BaseLearner +from utils.inc_net import CosineIncrementalNet, FOSTERNet, IncrementalNet +from utils.toolkit import count_parameters, target2onehot, tensor2numpy + +EPSILON = 1e-8 + + +class PASS(BaseLearner): + def __init__(self, args): + super().__init__(args) + self.args = args + self._network = IncrementalNet(args, False) + self._protos = [] + self._radius = 0 + self._radiuses = [] + + + def after_task(self): + self._known_classes = self._total_classes + self._old_network = self._network.copy().freeze() + if hasattr(self._old_network,"module"): + self.old_network_module_ptr = self._old_network.module + else: + self.old_network_module_ptr = self._old_network + #self.save_checkpoint("{}_{}_{}".format(self.args["model_name"],self.args["init_cls"],self.args["increment"])) + def incremental_train(self, data_manager): + self.data_manager = data_manager + self._cur_task += 1 + + self._total_classes = self._known_classes + \ + data_manager.get_task_size(self._cur_task) + self._network.update_fc(self._total_classes*4) + self._network_module_ptr = self._network + logging.info( + 'Learning on {}-{}'.format(self._known_classes, self._total_classes)) + + + logging.info('All params: {}'.format(count_parameters(self._network))) + logging.info('Trainable params: {}'.format( + count_parameters(self._network, True))) + + train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train', + mode='train', appendent=self._get_memory()) + self.train_loader = DataLoader( + train_dataset, batch_size=self.args["batch_size"], shuffle=True, num_workers=self.args["num_workers"], pin_memory=True) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source='test', mode='test') + self.test_loader = DataLoader( + test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"]) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, 
self.test_loader) + + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + + def _train(self, train_loader, test_loader): + + resume = False + if self._cur_task in []: + self._network.load_state_dict(torch.load("{}_{}_{}_{}.pkl".format(self.args["model_name"],self.args["init_cls"],self.args["increment"],self._cur_task))["model_state_dict"]) + resume = True + self._network.to(self._device) + if hasattr(self._network, "module"): + self._network_module_ptr = self._network.module + if not resume: + self._epoch_num = self.args["epochs"] + optimizer = torch.optim.Adam(self._network.parameters(), lr=self.args["lr"], weight_decay=self.args["weight_decay"]) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args["step_size"], gamma=self.args["gamma"]) + self._train_function(train_loader, test_loader, optimizer, scheduler) + self._build_protos() + + + def _build_protos(self): + with torch.no_grad(): + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + class_mean = np.mean(vectors, axis=0) + self._protos.append(class_mean) + cov = np.cov(vectors.T) + self._radiuses.append(np.trace(cov)/vectors.shape[1]) + self._radius = np.sqrt(np.mean(self._radiuses)) + + def _train_function(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self._epoch_num)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0. + losses_clf, losses_fkd, losses_proto = 0., 0., 0. + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True), targets.to(self._device, non_blocking=True) + inputs = torch.stack([torch.rot90(inputs, k, (2, 3)) for k in range(4)], 1) + inputs = inputs.view(-1, 3, 320, 320) + targets = torch.stack([targets * 4 + k for k in range(4)], 1).view(-1) + logits, loss_clf, loss_fkd, loss_proto = self._compute_pass_loss(inputs,targets) + loss = loss_clf + loss_fkd + loss_proto + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + losses_clf += loss_clf.item() + losses_fkd += loss_fkd.item() + losses_proto += loss_proto.item() + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy( + correct)*100 / total, decimals=2) + if epoch % 5 != 0: + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}'.format( + self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc) + else: + test_acc = self._compute_accuracy(self._network, test_loader) + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format( + self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc, test_acc) + prog_bar.set_description(info) + logging.info(info) + + def _compute_pass_loss(self,inputs, targets): + logits = self._network(inputs)["logits"] + 
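# Reminder: by this point inputs/targets were expanded 4x by the rotation
# self-supervision in _train_function above (label = class*4 + k for a
# k*90-degree rotation), so these logits span 4*num_classes outputs and the
# temperature-scaled cross-entropy below is taken over the rotation-expanded
# label space; evaluation later keeps only the upright columns via [:, ::4].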
loss_clf = F.cross_entropy(logits/self.args["temp"], targets) + + if self._cur_task == 0: + return logits, loss_clf, torch.tensor(0.), torch.tensor(0.) + + features = self._network_module_ptr.extract_vector(inputs) + features_old = self.old_network_module_ptr.extract_vector(inputs) + loss_fkd = self.args["lambda_fkd"] * torch.dist(features, features_old, 2) + + # index = np.random.choice(range(self._known_classes),size=self.args["batch_size"],replace=True) + + index = np.random.choice(range(self._known_classes),size=self.args["batch_size"]*int(self._known_classes/(self._total_classes-self._known_classes)),replace=True) + # print(index) + # print(np.concatenate(self._protos)) + proto_features = np.array(self._protos)[index] + # print(proto_features) + proto_targets = 4*index + proto_features = proto_features + np.random.normal(0,1,proto_features.shape)*self._radius + proto_features = torch.from_numpy(proto_features).float().to(self._device,non_blocking=True) + proto_targets = torch.from_numpy(proto_targets).to(self._device,non_blocking=True) + + + proto_logits = self._network_module_ptr.fc(proto_features)["logits"] + loss_proto = self.args["lambda_proto"] * F.cross_entropy(proto_logits/self.args["temp"], proto_targets) + return logits, loss_clf, loss_fkd, loss_proto + + + + def _compute_accuracy(self, model, loader): + model.eval() + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = model(inputs)["logits"][:,::4] + predicts = torch.max(outputs, dim=1)[1] + correct += (predicts.cpu() == targets).sum() + total += len(targets) + + return np.around(tensor2numpy(correct)*100 / total, decimals=2) + + def _eval_cnn(self, loader): + self._network.eval() + y_pred, y_true = [], [] + for _, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = self._network(inputs)["logits"][:,::4] + predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1] + y_pred.append(predicts.cpu().numpy()) + y_true.append(targets.cpu().numpy()) + + return np.concatenate(y_pred), np.concatenate(y_true) + + def eval_task(self, save_conf=True): + y_pred, y_true = self._eval_cnn(self.test_loader) + cnn_accy = self._evaluate(y_pred, y_true) + + if hasattr(self, '_class_means'): + y_pred, y_true = self._eval_nme(self.test_loader, self._class_means) + nme_accy = self._evaluate(y_pred, y_true) + elif hasattr(self, '_protos'): + y_pred, y_true = self._eval_nme(self.test_loader, self._protos/np.linalg.norm(self._protos,axis=1)[:,None]) + nme_accy = self._evaluate(y_pred, y_true) + else: + nme_accy = None + if save_conf: + _pred = y_pred.T[0] + _pred_path = os.path.join(self.args['logfilename'], "pred.npy") + _target_path = os.path.join(self.args['logfilename'], "target.npy") + np.save(_pred_path, _pred) + np.save(_target_path, y_true) + + _save_dir = os.path.join(f"./results/{self.args['model_name']}/conf_matrix/{self.args['prefix']}") + os.makedirs(_save_dir, exist_ok=True) + _save_path = os.path.join(_save_dir, f"{self.args['csv_name']}.csv") + with open(_save_path, "a+") as f: + f.write(f"{self.args['model_name']},{_pred_path},{_target_path} \n") + return cnn_accy, nme_accy \ No newline at end of file diff --git a/models/podnet.py b/models/podnet.py new file mode 100644 index 0000000000000000000000000000000000000000..847090b5bfca6acc144fb45492b9dc69050b7995 --- /dev/null +++ b/models/podnet.py @@ -0,0 +1,324 @@ +import math +import logging +import numpy as np 
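Before the PODNet code below, one note on PASS's exemplar-free replay above: `_build_protos` stores each class mean and accumulates `trace(cov)/dim` per class, whose square-rooted average becomes `self._radius`; `_compute_pass_loss` then samples old-class pseudo-features as the prototype plus isotropic Gaussian noise at that scale, with targets multiplied by 4 because labels are rotation-expanded. A condensed sketch of the idea (dimensions and seed are illustrative, not taken from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, n_old = 512, 10
protos = rng.normal(size=(n_old, dim))        # stored class means
radiuses = rng.uniform(0.5, 1.5, size=n_old)  # trace(cov)/dim per class
radius = np.sqrt(radiuses.mean())             # one global radius, as in PASS

# Pseudo-features for replaying old classes: prototype + N(0, 1) * radius.
idx = rng.choice(n_old, size=32, replace=True)
pseudo = protos[idx] + rng.normal(size=(32, dim)) * radius
pseudo_targets = 4 * idx  # *4: PASS labels are rotation-expanded (k=0 slot)
```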
+import torch +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import CosineIncrementalNet +from utils.toolkit import tensor2numpy + +epochs = 100 +lrate = 0.1 +ft_epochs = 20 +ft_lrate = 0.005 +batch_size = 32 +lambda_c_base = 5 +lambda_f_base = 1 +nb_proxy = 10 +weight_decay = 5e-4 +num_workers = 4 + +""" +Distillation losses: POD-flat (lambda_f=1) + POD-spatial (lambda_c=5) +NME results are shown. +The reproduced results are not in line with the reported results. +Maybe I missed something... ++--------------------+--------------------+--------------------+--------------------+ +| Classifier | Steps | Reported (%) | Reproduced (%) | ++--------------------+--------------------+--------------------+--------------------+ +| Cosine (k=1) | 50 | 56.69 | 55.49 | ++--------------------+--------------------+--------------------+--------------------+ +| LSC-CE (k=10) | 50 | 59.86 | 55.69 | ++--------------------+--------------------+--------------------+--------------------+ +| LSC-NCA (k=10) | 50 | 61.40 | 56.50 | ++--------------------+--------------------+--------------------+--------------------+ +| LSC-CE (k=10) | 25 | ----- | 59.16 | ++--------------------+--------------------+--------------------+--------------------+ +| LSC-NCA (k=10) | 25 | 62.71 | 59.79 | ++--------------------+--------------------+--------------------+--------------------+ +| LSC-CE (k=10) | 10 | ----- | 62.59 | ++--------------------+--------------------+--------------------+--------------------+ +| LSC-NCA (k=10) | 10 | 64.03 | 62.81 | ++--------------------+--------------------+--------------------+--------------------+ +| LSC-CE (k=10) | 5 | ----- | 64.16 | ++--------------------+--------------------+--------------------+--------------------+ +| LSC-NCA (k=10) | 5 | 64.48 | 64.37 | ++--------------------+--------------------+--------------------+--------------------+ +""" + + +class PODNet(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = CosineIncrementalNet( + args, pretrained=False, nb_proxy=nb_proxy + ) + self._class_means = None + + def after_task(self): + self._old_network = self._network.copy().freeze() + self._known_classes = self._total_classes + logging.info("Exemplar size: {}".format(self.exemplar_size)) + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self.task_size = self._total_classes - self._known_classes + self._network.update_fc(self._total_classes, self._cur_task) + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + train_dset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + test_dset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.train_loader = DataLoader( + train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + self.test_loader = DataLoader( + test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + self._train(data_manager, self.train_loader, self.test_loader) + self.build_rehearsal_memory(data_manager, self.samples_per_class) + + def _train(self, data_manager, train_loader, test_loader): + if self._cur_task == 0: + self.factor = 0 + else: + self.factor = math.sqrt( + self._total_classes 
/ (self._total_classes - self._known_classes) + ) + logging.info("Adaptive factor: {}".format(self.factor)) + + self._network.to(self._device) + if self._old_network is not None: + self._old_network.to(self._device) + + if self._cur_task == 0: + network_params = self._network.parameters() + else: + ignored_params = list(map(id, self._network.fc.fc1.parameters())) + base_params = filter( + lambda p: id(p) not in ignored_params, self._network.parameters() + ) + network_params = [ + {"params": base_params, "lr": lrate, "weight_decay": weight_decay}, + { + "params": self._network.fc.fc1.parameters(), + "lr": 0, + "weight_decay": 0, + }, + ] + optimizer = optim.SGD( + network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay + ) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=epochs + ) + self._run(train_loader, test_loader, optimizer, scheduler, epochs) + + if self._cur_task == 0: + return + logging.info( + "Finetune the network (classifier part) with the undersampled dataset!" + ) + if self._fixed_memory: + finetune_samples_per_class = self._memory_per_class + self._construct_exemplar_unified(data_manager, finetune_samples_per_class) + else: + finetune_samples_per_class = self._memory_size // self._known_classes + self._reduce_exemplar(data_manager, finetune_samples_per_class) + self._construct_exemplar(data_manager, finetune_samples_per_class) + + finetune_train_dataset = data_manager.get_dataset( + [], source="train", mode="train", appendent=self._get_memory() + ) + finetune_train_loader = DataLoader( + finetune_train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + ) + logging.info( + "The size of finetune dataset: {}".format(len(finetune_train_dataset)) + ) + + ignored_params = list(map(id, self._network.fc.fc1.parameters())) + base_params = filter( + lambda p: id(p) not in ignored_params, self._network.parameters() + ) + network_params = [ + {"params": base_params, "lr": ft_lrate, "weight_decay": weight_decay}, + {"params": self._network.fc.fc1.parameters(), "lr": 0, "weight_decay": 0}, + ] + optimizer = optim.SGD( + network_params, lr=ft_lrate, momentum=0.9, weight_decay=weight_decay + ) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=ft_epochs + ) + self._run(finetune_train_loader, test_loader, optimizer, scheduler, ft_epochs) + + if self._fixed_memory: + self._data_memory = self._data_memory[ + : -self._memory_per_class * self.task_size + ] + self._targets_memory = self._targets_memory[ + : -self._memory_per_class * self.task_size + ] + assert ( + len( + np.setdiff1d( + self._targets_memory, np.arange(0, self._known_classes) + ) + ) + == 0 + ), "Exemplar error!" 
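The training loop in `_run` below combines three terms: the LSC/NCA classification loss, POD-flat (a cosine-embedding loss between old and new final features, weighted by `self.factor * lambda_f_base`), and POD-spatial (weighted by `self.factor * lambda_c_base`), where `factor = sqrt(total/new)` scales distillation up as old classes accumulate. A small usage sketch of `pod_spatial_loss`, which is defined at the end of this file; shapes are illustrative, and note the loss is symmetric in its two arguments, so the swapped call order in `_run` is harmless:

```python
import torch

# Two stages of feature maps [bs, c, w, h] from the frozen old backbone
# and the current one; here the new maps are a small perturbation.
old_fmaps = [torch.randn(8, 64, 16, 16), torch.randn(8, 128, 8, 8)]
new_fmaps = [f + 0.1 * torch.randn_like(f) for f in old_fmaps]

loss = pod_spatial_loss(old_fmaps, new_fmaps)  # small, since maps are close
print(loss)
```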
+ + def _run(self, train_loader, test_loader, optimizer, scheduler, epk): + for epoch in range(1, epk + 1): + self._network.train() + lsc_losses = 0.0 + spatial_losses = 0.0 + flat_losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + outputs = self._network(inputs) + logits = outputs["logits"] + features = outputs["features"] + fmaps = outputs["fmaps"] + lsc_loss = nca(logits, targets) + + spatial_loss = 0.0 + flat_loss = 0.0 + if self._old_network is not None: + with torch.no_grad(): + old_outputs = self._old_network(inputs) + old_features = old_outputs["features"] + old_fmaps = old_outputs["fmaps"] + flat_loss = ( + F.cosine_embedding_loss( + features, + old_features.detach(), + torch.ones(inputs.shape[0]).to(self._device), + ) + * self.factor + * lambda_f_base + ) + spatial_loss = ( + pod_spatial_loss(fmaps, old_fmaps) * self.factor * lambda_c_base + ) + + loss = lsc_loss + flat_loss + spatial_loss + optimizer.zero_grad() + loss.backward() + optimizer.step() + + lsc_losses += lsc_loss.item() + spatial_losses += ( + spatial_loss.item() if self._cur_task != 0 else spatial_loss + ) + flat_losses += flat_loss.item() if self._cur_task != 0 else flat_loss + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + if scheduler is not None: + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + test_acc = self._compute_accuracy(self._network, test_loader) + info1 = "Task {}, Epoch {}/{} (LR {:.5f}) => ".format( + self._cur_task, epoch, epk, optimizer.param_groups[0]["lr"] + ) + info2 = "LSC_loss {:.2f}, Spatial_loss {:.2f}, Flat_loss {:.2f}, Train_acc {:.2f}, Test_acc {:.2f}".format( + lsc_losses / (i + 1), + spatial_losses / (i + 1), + flat_losses / (i + 1), + train_acc, + test_acc, + ) + logging.info(info1 + info2) + + +def pod_spatial_loss(old_fmaps, fmaps, normalize=True): + """ + a, b: list of [bs, c, w, h] + """ + loss = torch.tensor(0.0).to(fmaps[0].device) + for i, (a, b) in enumerate(zip(old_fmaps, fmaps)): + assert a.shape == b.shape, "Shape error" + + a = torch.pow(a, 2) + b = torch.pow(b, 2) + + a_h = a.sum(dim=3).view(a.shape[0], -1) # [bs, c*w] + b_h = b.sum(dim=3).view(b.shape[0], -1) # [bs, c*w] + a_w = a.sum(dim=2).view(a.shape[0], -1) # [bs, c*h] + b_w = b.sum(dim=2).view(b.shape[0], -1) # [bs, c*h] + + a = torch.cat([a_h, a_w], dim=-1) + b = torch.cat([b_h, b_w], dim=-1) + + if normalize: + a = F.normalize(a, dim=1, p=2) + b = F.normalize(b, dim=1, p=2) + + layer_loss = torch.mean(torch.frobenius_norm(a - b, dim=-1)) + loss += layer_loss + + return loss / len(fmaps) + + +def nca( + similarities, + targets, + class_weights=None, + focal_gamma=None, + scale=1.0, + margin=0.6, + exclude_pos_denominator=True, + hinge_proxynca=False, + memory_flags=None, +): + margins = torch.zeros_like(similarities) + margins[torch.arange(margins.shape[0]), targets] = margin + similarities = scale * (similarities - margin) + + if exclude_pos_denominator: + similarities = similarities - similarities.max(1)[0].view(-1, 1) + + disable_pos = torch.zeros_like(similarities) + disable_pos[torch.arange(len(similarities)), targets] = similarities[ + torch.arange(len(similarities)), targets + ] + + numerator = similarities[torch.arange(similarities.shape[0]), targets] + denominator = similarities - disable_pos + + losses = numerator - torch.log(torch.exp(denominator).sum(-1)) + if 
class_weights is not None: + losses = class_weights[targets] * losses + + losses = -losses + if hinge_proxynca: + losses = torch.clamp(losses, min=0.0) + + loss = torch.mean(losses) + return loss + + return F.cross_entropy( + similarities, targets, weight=class_weights, reduction="mean" + ) diff --git a/models/replay.py b/models/replay.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3e9cf9f88bf21147e569b3a36764fbb8afb7ce --- /dev/null +++ b/models/replay.py @@ -0,0 +1,193 @@ +import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import IncrementalNet +from utils.toolkit import target2onehot, tensor2numpy + +EPSILON = 1e-8 + + +init_epoch = 100 +init_lr = 0.1 +init_milestones = [40, 60, 80] +init_lr_decay = 0.1 +init_weight_decay = 0.0005 + + +epochs = 70 +lrate = 0.1 +milestones = [30, 50] +lrate_decay = 0.1 +batch_size = 32 +weight_decay = 2e-4 +num_workers = 8 +T = 2 + + +class Replay(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = IncrementalNet(args, False) + + def after_task(self): + self._known_classes = self._total_classes + logging.info("Exemplar size: {}".format(self.exemplar_size)) + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc(self._total_classes) + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + # Loader + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + self.train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + # Procedure + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + + self.build_rehearsal_memory(data_manager, self.samples_per_class) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if self._cur_task == 0: + optimizer = optim.SGD( + self._network.parameters(), + momentum=0.9, + lr=init_lr, + weight_decay=init_weight_decay, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay + ) + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + optimizer = optim.SGD( + self._network.parameters(), + lr=lrate, + momentum=0.9, + weight_decay=weight_decay, + ) # 1e-5 + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + self._update_representation(train_loader, test_loader, optimizer, scheduler) + + def _init_train(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(init_epoch)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = 
inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss = F.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + ) + + prog_bar.set_description(info) + + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(epochs)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss_clf = F.cross_entropy(logits, targets) + loss = loss_clf + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + # acc + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) diff --git a/models/rmm.py b/models/rmm.py new file mode 100644 index 0000000000000000000000000000000000000000..bacddb6628b2db13dedc853adf066b873bb0afed --- /dev/null +++ b/models/rmm.py @@ -0,0 +1,285 @@ +import copy +import logging +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from models.foster import FOSTER +from utils.toolkit import count_parameters, tensor2numpy, accuracy +from utils.inc_net import IncrementalNet +from scipy.spatial.distance import cdist +from models.base import BaseLearner +from models.icarl import iCaRL +from tqdm import tqdm +import torch.optim as optim + + +EPSILON = 1e-8 +batch_size = 32 +weight_decay = 2e-4 +num_workers = 8 + + +class RMMBase(BaseLearner): + def __init__(self, args): + self._args = args + self._m_rate_list = args.get("m_rate_list", []) + self._c_rate_list = args.get("c_rate_list", []) + + @property + def samples_per_class(self): + return int(self.memory_size // self._total_classes) + + @property + def memory_size(self): + if self._args["dataset"] == "cifar100": + img_per_cls = 500 + else: + img_per_cls = 1300 + + if self._m_rate_list[self._cur_task] != 0: + print(self._total_classes) + self._memory_size = min(int(self._total_classes*img_per_cls-1),self._args["memory_size"] + 
int(
+                self._m_rate_list[self._cur_task]
+                * self._args["increment"]
+                * img_per_cls
+            ))
+        return self._memory_size
+
+    @property
+    def new_memory_size(self):
+        if self._args["dataset"] == "cifar100":
+            img_per_cls = 500
+        else:
+            img_per_cls = 1300
+        return int(
+            (1 - self._m_rate_list[self._cur_task])
+            * self._args["increment"]
+            * img_per_cls
+        )
+
+    def build_rehearsal_memory(self, data_manager, per_class):
+        self._reduce_exemplar(data_manager, per_class)
+        self._construct_exemplar(data_manager, per_class)
+
+    def _construct_exemplar(self, data_manager, m):
+        if self._args["dataset"] == "cifar100":
+            img_per_cls = 500
+        else:
+            img_per_cls = 1300
+        ns = [
+            min(img_per_cls, int(m * (1 - self._c_rate_list[self._cur_task]))),
+            min(img_per_cls, int(m * (1 + self._c_rate_list[self._cur_task]))),
+        ]
+        logging.info(
+            "Constructing exemplars... ({} or {} per class)".format(ns[0], ns[1])
+        )
+
+        all_cls_entropies = []
+        ms = []
+        for class_idx in range(self._known_classes, self._total_classes):
+            data, targets, idx_dataset = data_manager.get_dataset(
+                np.arange(class_idx, class_idx + 1),
+                source="train",
+                mode="test",
+                ret_data=True,
+            )
+            idx_loader = DataLoader(
+                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+            )
+            with torch.no_grad():
+                cidx_cls_entropies = []
+                for idx, (_, inputs, targets) in enumerate(idx_loader):
+                    inputs, targets = inputs.to(self._device), targets.to(self._device)
+                    logits = self._network(inputs)["logits"]
+                    cross_entropy = (
+                        F.cross_entropy(logits, targets, reduction="none")
+                        .detach()
+                        .cpu()
+                        .numpy()
+                    )
+                    cidx_cls_entropies.append(cross_entropy)
+                # print(cidx_cls_entropies)
+                cidx_cls_entropies = np.mean(np.concatenate(cidx_cls_entropies))
+            all_cls_entropies.append(cidx_cls_entropies)
+        entropy_median = np.median(all_cls_entropies)
+        for the_entropy in all_cls_entropies:
+            if the_entropy > entropy_median:
+                ms.append(ns[0])
+            else:
+                ms.append(ns[1])
+
+        logging.info(f"ms: {ms}")
+        for class_idx in range(self._known_classes, self._total_classes):
+            data, targets, idx_dataset = data_manager.get_dataset(
+                np.arange(class_idx, class_idx + 1),
+                source="train",
+                mode="test",
+                ret_data=True,
+            )
+            idx_loader = DataLoader(
+                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+            )
+            vectors, _ = self._extract_vectors(idx_loader)
+            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+            class_mean = np.mean(vectors, axis=0)
+            # Select
+            selected_exemplars = []
+            exemplar_vectors = []  # [n, feature_dim]
+            for k in range(1, ms[class_idx - self._known_classes] + 1):
+                S = np.sum(
+                    exemplar_vectors, axis=0
+                )  # [feature_dim] sum of the selected exemplar vectors
+                mu_p = (vectors + S) / k  # [n, feature_dim] candidate means if each remaining vector were added
+                i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
+                selected_exemplars.append(
+                    np.array(data[i])
+                )  # New object to avoid passing by reference
+                exemplar_vectors.append(
+                    np.array(vectors[i])
+                )  # New object to avoid passing by reference
+
+                vectors = np.delete(
+                    vectors, i, axis=0
+                )  # Remove it to avoid duplicate selection
+                data = np.delete(
+                    data, i, axis=0
+                )  # Remove it to avoid duplicate selection
+
+            selected_exemplars = np.array(selected_exemplars)
+            exemplar_targets = np.full(ms[class_idx - self._known_classes], class_idx)
+            self._data_memory = (
+                np.concatenate((self._data_memory, selected_exemplars))
+                if len(self._data_memory) != 0
+                else selected_exemplars
+            )
+            self._targets_memory = (
+                np.concatenate((self._targets_memory, exemplar_targets))
+                if len(self._targets_memory) != 0
+                else exemplar_targets
+            )
+
+            # Exemplar mean
+            idx_dataset = data_manager.get_dataset(
+                [],
+                source="train",
+                mode="test",
+                appendent=(selected_exemplars, exemplar_targets),
+            )
+            idx_loader = DataLoader(
+                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
+            )
+            vectors, _ = self._extract_vectors(idx_loader)
+            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
+            mean = np.mean(vectors, axis=0)
+            mean = mean / np.linalg.norm(mean)
+
+            self._class_means[class_idx, :] = mean
+
+
+class RMM_iCaRL(
+    RMMBase, iCaRL
+):  # RMMBase must precede the original method in the MRO so its memory logic takes priority.
+    def __init__(self, args):
+        RMMBase.__init__(self, args)
+        iCaRL.__init__(self, args)
+
+    def incremental_train(self, data_manager):
+        self._cur_task += 1
+        self._total_classes = self._known_classes + data_manager.get_task_size(
+            self._cur_task
+        )
+        self._network.update_fc(self._total_classes)
+        logging.info(
+            "Learning on {}-{}".format(self._known_classes, self._total_classes)
+        )
+
+        train_dataset = data_manager.get_dataset(
+            np.arange(self._known_classes, self._total_classes),
+            source="train",
+            mode="train",
+            appendent=self._get_memory(),
+            m_rate=self._m_rate_list[self._cur_task] if self._cur_task > 0 else None,
+        )
+        self.train_loader = DataLoader(
+            train_dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            num_workers=num_workers,
+            pin_memory=True,
+        )
+        test_dataset = data_manager.get_dataset(
+            np.arange(0, self._total_classes), source="test", mode="test"
+        )
+        self.test_loader = DataLoader(
+            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+        )
+        if len(self._multiple_gpus) > 1:
+            self._network = nn.DataParallel(self._network, self._multiple_gpus)
+        self._train(self.train_loader, self.test_loader)
+        self.build_rehearsal_memory(data_manager, self.samples_per_class)
+        if len(self._multiple_gpus) > 1:
+            self._network = self._network.module
+
+
+class RMM_FOSTER(RMMBase, FOSTER):
+    def __init__(self, args):
+        RMMBase.__init__(self, args)
+        FOSTER.__init__(self, args)
+
+    def incremental_train(self, data_manager):
+        self.data_manager = data_manager
+        self._cur_task += 1
+        if self._cur_task > 1:
+            self._network = self._snet
+        self._total_classes = self._known_classes + data_manager.get_task_size(
+            self._cur_task
+        )
+        self._network.update_fc(self._total_classes)
+        self._network_module_ptr = self._network
+        logging.info(
+            "Learning on {}-{}".format(self._known_classes, self._total_classes)
+        )
+
+        if self._cur_task > 0:
+            for p in self._network.convnets[0].parameters():
+                p.requires_grad = False
+            for p in self._network.oldfc.parameters():
+                p.requires_grad = False
+
+        logging.info("All params: {}".format(count_parameters(self._network)))
+        logging.info(
+            "Trainable params: {}".format(count_parameters(self._network, True))
+        )
+
+        train_dataset = data_manager.get_dataset(
+            np.arange(self._known_classes, self._total_classes),
+            source="train",
+            mode="train",
+            appendent=self._get_memory(),
+            m_rate=self._m_rate_list[self._cur_task] if self._cur_task > 0 else None,
+        )
+        self.train_loader = DataLoader(
+            train_dataset,
+            batch_size=self.args["batch_size"],
+            shuffle=True,
+            num_workers=self.args["num_workers"],
+            pin_memory=True,
+        )
+        test_dataset = data_manager.get_dataset(
+            np.arange(0, self._total_classes), source="test", mode="test"
+        )
+        self.test_loader = DataLoader(
+            test_dataset,
+            batch_size=self.args["batch_size"],
+            shuffle=False,
+            num_workers=self.args["num_workers"],
) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + self.build_rehearsal_memory(data_manager, self.samples_per_class) + if len(self._multiple_gpus) > 1: + self._network = self._network.module diff --git a/models/simplecil.py b/models/simplecil.py new file mode 100644 index 0000000000000000000000000000000000000000..f62cb40ef17fcda62952bc714f734b8c2cfed791 --- /dev/null +++ b/models/simplecil.py @@ -0,0 +1,175 @@ +''' +Re-implementation of SimpleCIL (https://arxiv.org/abs/2303.07338) without pre-trained weights. +The training process is as follows: train the model with cross-entropy in the first stage and replace the classifier with prototypes for all the classes in the subsequent stages. +Please refer to the original implementation (https://github.com/zhoudw-zdw/RevisitingCIL) if you are using pre-trained weights. +''' +import logging +import numpy as np +import torch +from torch import nn +from torch.serialization import load +from tqdm import tqdm +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from utils.inc_net import SimpleCosineIncrementalNet +from models.base import BaseLearner +from utils.toolkit import target2onehot, tensor2numpy + + +num_workers = 8 +batch_size = 32 +milestones = [40, 80] + +class SimpleCIL(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = SimpleCosineIncrementalNet(args, False) + self.min_lr = args['min_lr'] if args['min_lr'] is not None else 1e-8 + self.args = args + + def load_checkpoint(self, filename): + checkpoint = torch.load(filename) + self._total_classes = len(checkpoint["classes"]) + self.class_list = np.array(checkpoint["classes"]) + self.label_list = checkpoint["label_list"] + print("Class list: ", self.class_list) + self._network.update_fc(self._total_classes) + self._network.load_checkpoint(checkpoint["network"]) + self._network.to(self._device) + + def after_task(self): + self._known_classes = self._total_classes + + def save_checkpoint(self, filename): + self._network.cpu() + save_dict = { + "classes": self.data_manager.get_class_list(self._cur_task), + "network": { + "convnet": self._network.convnet.state_dict(), + "fc": self._network.fc.state_dict() + }, + "label_list": self.data_manager.get_label_list(self._cur_task), + } + torch.save(save_dict, "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task)) + + def replace_fc(self,trainloader, model, args): + model = model.eval() + embedding_list = [] + label_list = [] + with torch.no_grad(): + for i, batch in enumerate(trainloader): + (_,data,label) = batch + data = data.cuda() + label = label.cuda() + embedding = model(data)["features"] + embedding_list.append(embedding.cpu()) + label_list.append(label.cpu()) + embedding_list = torch.cat(embedding_list, dim=0) + label_list = torch.cat(label_list, dim=0) + + class_list = np.unique(self.train_dataset.labels) + proto_list = [] + for class_index in class_list: + # print('Replacing...',class_index) + data_index = torch.nonzero(label_list == class_index).squeeze(-1) + embedding = embedding_list[data_index] + proto = embedding.mean(0) + if len(self._multiple_gpus) > 1: + self._network.module.fc.weight.data[class_index] = proto + else: + self._network.fc.weight.data[class_index] = proto + return model + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + 
data_manager.get_task_size(self._cur_task)
+        self._network.update_fc(self._total_classes)
+        logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
+        self.class_list = np.array(data_manager.get_class_list(self._cur_task))
+        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="train")
+        self.train_dataset = train_dataset
+        self.data_manager = data_manager
+        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
+        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+        train_dataset_for_protonet = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="test")
+        self.train_loader_for_protonet = DataLoader(train_dataset_for_protonet, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+
+        if len(self._multiple_gpus) > 1:
+            print('Multiple GPUs')
+            self._network = nn.DataParallel(self._network, self._multiple_gpus)
+        self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet)
+        if len(self._multiple_gpus) > 1:
+            self._network = self._network.module
+
+    def _train(self, train_loader, test_loader, train_loader_for_protonet):
+        self._network.to(self._device)
+        if self._cur_task == 0:
+            optimizer = optim.SGD(
+                self._network.parameters(),
+                momentum=0.9,
+                lr=self.args["init_lr"],
+                weight_decay=self.args["init_weight_decay"]
+            )
+            scheduler = optim.lr_scheduler.CosineAnnealingLR(
+                optimizer=optimizer, T_max=self.args['init_epoch'], eta_min=self.min_lr
+            )
+            self._init_train(train_loader, test_loader, optimizer, scheduler)
+        self.replace_fc(train_loader_for_protonet, self._network, None)
+
+    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+        prog_bar = tqdm(range(self.args["init_epoch"]))
+        for _, epoch in enumerate(prog_bar):
+            self._network.train()
+            losses = 0.0
+            correct, total = 0, 0
+            for i, (_, inputs, targets) in enumerate(train_loader):
+                inputs, targets = inputs.to(self._device), targets.to(self._device)
+                logits = self._network(inputs)["logits"]
+
+                loss = F.cross_entropy(logits, targets)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                losses += loss.item()
+
+                _, preds = torch.max(logits, dim=1)
+                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+                total += len(targets)
+
+            scheduler.step()
+            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+            if epoch % 5 == 0:
+                test_acc = self._compute_accuracy(self._network, test_loader)
+                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+                    self._cur_task,
+                    epoch + 1,
+                    self.args['init_epoch'],
+                    losses / len(train_loader),
+                    train_acc,
+                    test_acc,
+                )
+            else:
+                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
+                    self._cur_task,
+                    epoch + 1,
+                    self.args['init_epoch'],
+                    losses / len(train_loader),
+                    train_acc,
+                )
+            elapsed = prog_bar.format_dict["elapsed"]
+            rate = prog_bar.format_dict["rate"]
+            remaining = (prog_bar.total - prog_bar.n) / rate if rate and prog_bar.total else 0  # estimated seconds left
+            prog_bar.set_description(info)
+            logging.info("Working on task {}: elapsed {:.2f}s, remaining {:.2f}s".format(
+                self._cur_task,
+                elapsed,
+                remaining))
+            logging.info(info)
+        logging.info("Finished task {}: elapsed {:.2f}s".format(
+            self._cur_task, elapsed))
+
+
diff --git
a/models/ssre.py b/models/ssre.py new file mode 100644 index 0000000000000000000000000000000000000000..37cd93f752530f9c4581190b3d9df93b803629ca --- /dev/null +++ b/models/ssre.py @@ -0,0 +1,253 @@ +import logging +import numpy as np +import os +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader,Dataset +from models.base import BaseLearner +from utils.inc_net import CosineIncrementalNet, FOSTERNet, IncrementalNet +from utils.toolkit import count_parameters, target2onehot, tensor2numpy +from utils.autoaugment import CIFAR10Policy,ImageNetPolicy +from utils.ops import Cutout +from torchvision import datasets, transforms + +EPSILON = 1e-8 + + +class SSRE(BaseLearner): + def __init__(self, args): + super().__init__(args) + self.args = args + self._network = IncrementalNet(args, False) + self._protos = [] + + + + def after_task(self): + self._known_classes = self._total_classes + self._old_network = self._network.copy().freeze() + if hasattr(self._old_network,"module"): + self.old_network_module_ptr = self._old_network.module + else: + self.old_network_module_ptr = self._old_network + #self.save_checkpoint("{}_{}_{}".format(self.args["model_name"],self.args["init_cls"],self.args["increment"])) + def incremental_train(self, data_manager): + self.data_manager = data_manager + if self._cur_task == 0: + self.data_manager._train_trsf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63/255), + CIFAR10Policy(), + Cutout(n_holes=1, length=16) + ] + else: + self.data_manager._train_trsf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63/255), + ] + self._cur_task += 1 + self._total_classes = self._known_classes + \ + data_manager.get_task_size(self._cur_task) + self._network.update_fc(self._total_classes) + self._network_module_ptr = self._network + + logging.info("Model Expansion!") + self._network_expansion() + + logging.info( + 'Learning on {}-{}'.format(self._known_classes, self._total_classes)) + + + logging.info('All params: {}'.format(count_parameters(self._network))) + logging.info('Trainable params: {}'.format( + count_parameters(self._network, True))) + + train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',mode='train', appendent=self._get_memory()) + if self._cur_task == 0: + batch_size = 64 + else: + batch_size = self.args["batch_size"] + self.train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=self.args["num_workers"], pin_memory=True) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source='test', mode='test') + self.test_loader = DataLoader( + test_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=self.args["num_workers"]) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + + self._train(self.train_loader, self.test_loader) + + + + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + logging.info("Model Compression!") + + self._network_compression() + def _train(self, train_loader, test_loader): + + resume = False + if self._cur_task in []: + self._network.load_state_dict(torch.load("{}_{}_{}_{}.pkl".format(self.args["model_name"],self.args["init_cls"],self.args["increment"],self._cur_task))["model_state_dict"]) + 
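+            # Editor's note: `self._cur_task in []` above is always False, so
+            # this resume branch is dead code as written; list task ids here
+            # (e.g. [0]) to load the saved checkpoints instead of retraining.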
resume = True + self._network.to(self._device) + if hasattr(self._network, "module"): + self._network_module_ptr = self._network.module + if not resume: + self._epoch_num = self.args["epochs"] + optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self._network.parameters( + )), lr=self.args["lr"], weight_decay=self.args["weight_decay"]) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args["step_size"], gamma=self.args["gamma"]) + self._train_function(train_loader, test_loader, optimizer, scheduler) + self._build_protos() + + + def _build_protos(self): + with torch.no_grad(): + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = self.data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=self.args["batch_size"], shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + class_mean = np.mean(vectors, axis=0) + self._protos.append(class_mean) + + def train(self): + if self._cur_task > 0: + self._network.eval() + return + self._network.train() + def _train_function(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(self._epoch_num)) + for _, epoch in enumerate(prog_bar): + self.train() + losses = 0. + losses_clf, losses_fkd, losses_proto = 0., 0., 0. + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to( + self._device, non_blocking=True), targets.to(self._device, non_blocking=True) + logits, loss_clf, loss_fkd, loss_proto = self._compute_ssre_loss(inputs,targets) + loss = loss_clf + loss_fkd + loss_proto + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + losses_clf += loss_clf.item() + losses_fkd += loss_fkd.item() + losses_proto += loss_proto.item() + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + scheduler.step() + train_acc = np.around(tensor2numpy( + correct)*100 / total, decimals=2) + if epoch % 5 != 0: + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}'.format( + self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc) + else: + test_acc = self._compute_accuracy(self._network, test_loader) + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_fkd {:.3f}, Loss_proto {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format( + self._cur_task, epoch+1, self._epoch_num, losses/len(train_loader), losses_clf/len(train_loader), losses_fkd/len(train_loader), losses_proto/len(train_loader), train_acc, test_acc) + prog_bar.set_description(info) + logging.info(info) + + def _compute_ssre_loss(self,inputs, targets): + if self._cur_task == 0: + logits = self._network(inputs)["logits"] + loss_clf = F.cross_entropy(logits/self.args["temp"], targets) + return logits, loss_clf, torch.tensor(0.), torch.tensor(0.) 
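+
+        # Editor's note on the incremental-task loss computed below:
+        #   * loss_clf: per-sample cross-entropy re-weighted by (1 - mask),
+        #     where mask is each sample's highest cosine similarity to the
+        #     stored old-class prototypes, so old-looking samples count less;
+        #   * loss_fkd: L2 feature distillation against the frozen old network,
+        #     weighted by the same mask (strongest for old-looking samples);
+        #   * loss_proto: cross-entropy on stored class-mean prototypes pushed
+        #     through the current fc head, replaying old classes without data.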
+ + features = self._network_module_ptr.extract_vector(inputs) # N D + + with torch.no_grad(): + features_old = self.old_network_module_ptr.extract_vector(inputs) + + protos = torch.from_numpy(np.array(self._protos)).to(self._device) # C D + with torch.no_grad(): + weights = F.normalize(features,p=2,dim=1,eps=1e-12) @ F.normalize(protos,p=2,dim=1,eps=1e-12).T + weights = torch.max(weights,dim=1)[0] + # mask = weights > self.args["threshold"] + mask = weights + logits = self._network(inputs)["logits"] + loss_clf = F.cross_entropy(logits/self.args["temp"],targets,reduction="none") + # loss_clf = torch.mean(loss_clf * ~mask) + loss_clf = torch.mean(loss_clf * (1-mask)) + + loss_fkd = torch.norm(features - features_old, p=2, dim=1) + loss_fkd = self.args["lambda_fkd"] * torch.sum(loss_fkd * mask) + + index = np.random.choice(range(self._known_classes),size=self.args["batch_size"],replace=True) + + proto_features = np.array(self._protos)[index] + proto_targets = index + proto_features = proto_features + proto_features = torch.from_numpy(proto_features).float().to(self._device,non_blocking=True) + proto_targets = torch.from_numpy(proto_targets).to(self._device,non_blocking=True) + + + proto_logits = self._network_module_ptr.fc(proto_features)["logits"] + loss_proto = self.args["lambda_proto"] * F.cross_entropy(proto_logits/self.args["temp"], proto_targets) + return logits, loss_clf, loss_fkd, loss_proto + + + def eval_task(self, save_conf=False): + y_pred, y_true = self._eval_cnn(self.test_loader) + cnn_accy = self._evaluate(y_pred, y_true) + + if hasattr(self, '_class_means'): + y_pred, y_true = self._eval_nme(self.test_loader, self._class_means) + nme_accy = self._evaluate(y_pred, y_true) + elif hasattr(self, '_protos'): + y_pred, y_true = self._eval_nme(self.test_loader, self._protos/np.linalg.norm(self._protos,axis=1)[:,None]) + nme_accy = self._evaluate(y_pred, y_true) + else: + nme_accy = None + if save_conf: + _pred = y_pred.T[0] + _pred_path = os.path.join(self.args['logfilename'], "pred.npy") + _target_path = os.path.join(self.args['logfilename'], "target.npy") + np.save(_pred_path, _pred) + np.save(_target_path, y_true) + + _save_dir = os.path.join(f"./results/{self.args['model_name']}/conf_matrix/{self.args['prefix']}") + os.makedirs(_save_dir, exist_ok=True) + _save_path = os.path.join(_save_dir, f"{self.args['csv_name']}.csv") + with open(_save_path, "a+") as f: + f.write(f"{self.args['model_name']},{_pred_path},{_target_path} \n") + return cnn_accy, nme_accy + + def _network_expansion(self): + if self._cur_task > 0: + for p in self._network.convnet.parameters(): + p.requires_grad = True + for k, v in self._network.convnet.named_parameters(): + if 'adapter' not in k: + v.requires_grad = False + # self._network.convnet.re_init_params() # do not use! 
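+        # Editor's note: "parallel_adapters" mode (assumed to be implemented by
+        # this repo's convnet) runs a trainable 1x1 conv beside each frozen 3x3
+        # conv; only 'adapter' parameters keep requires_grad=True above. The
+        # _network_compression step below folds each adapter back into its 3x3
+        # conv by zero-padding the 1x1 kernel to 3x3 and summing the weights,
+        # which is exact because convolution is linear in its kernel.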
+ self._network.convnet.switch("parallel_adapters") + + def _network_compression(self): + + model_dict = self._network.state_dict() + for k, v in model_dict.items(): + if 'adapter' in k: + k_conv3 = k.replace('adapter', 'conv') + if 'weight' in k: + model_dict[k_conv3] = model_dict[k_conv3] + F.pad(v, [1, 1, 1, 1], 'constant', 0) + model_dict[k] = torch.zeros_like(v) + elif 'bias' in k: + model_dict[k_conv3] = model_dict[k_conv3] + v + model_dict[k] = torch.zeros_like(v) + else: + assert 0 + self._network.load_state_dict(model_dict) + self._network.convnet.switch("normal") \ No newline at end of file diff --git a/models/wa.py b/models/wa.py new file mode 100644 index 0000000000000000000000000000000000000000..23de65687736c4882fc9d18d0fe7068e57bfd730 --- /dev/null +++ b/models/wa.py @@ -0,0 +1,217 @@ +import logging +import numpy as np +from tqdm import tqdm +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import IncrementalNet +from utils.toolkit import target2onehot, tensor2numpy + +EPSILON = 1e-8 + + +init_epoch = 200 +init_lr = 0.1 +init_milestones = [60, 120, 170] +init_lr_decay = 0.1 +init_weight_decay = 0.0005 + + +epochs = 170 +lrate = 0.1 +milestones = [60, 100, 140] +lrate_decay = 0.1 +batch_size = 128 +weight_decay = 2e-4 +num_workers = 8 +T = 2 + + +class WA(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = IncrementalNet(args, False) + + def after_task(self): + if self._cur_task > 0: + self._network.weight_align(self._total_classes - self._known_classes) + self._old_network = self._network.copy().freeze() + self._known_classes = self._total_classes + logging.info("Exemplar size: {}".format(self.exemplar_size)) + + def incremental_train(self, data_manager): + self._cur_task += 1 + self._total_classes = self._known_classes + data_manager.get_task_size( + self._cur_task + ) + self._network.update_fc(self._total_classes) + logging.info( + "Learning on {}-{}".format(self._known_classes, self._total_classes) + ) + + # Loader + train_dataset = data_manager.get_dataset( + np.arange(self._known_classes, self._total_classes), + source="train", + mode="train", + appendent=self._get_memory(), + ) + self.train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_dataset = data_manager.get_dataset( + np.arange(0, self._total_classes), source="test", mode="test" + ) + self.test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + # Procedure + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + self._train(self.train_loader, self.test_loader) + self.build_rehearsal_memory(data_manager, self.samples_per_class) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + def _train(self, train_loader, test_loader): + self._network.to(self._device) + if self._old_network is not None: + self._old_network.to(self._device) + + if self._cur_task == 0: + optimizer = optim.SGD( + self._network.parameters(), + momentum=0.9, + lr=init_lr, + weight_decay=init_weight_decay, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay + ) + self._init_train(train_loader, test_loader, optimizer, scheduler) + else: + optimizer = optim.SGD( + self._network.parameters(), + lr=lrate, + 
momentum=0.9, + weight_decay=weight_decay, + ) # 1e-5 + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=lrate_decay + ) + self._update_representation(train_loader, test_loader, optimizer, scheduler) + if len(self._multiple_gpus) > 1: + self._network.module.weight_align( + self._total_classes - self._known_classes + ) + else: + self._network.weight_align(self._total_classes - self._known_classes) + + def _init_train(self, train_loader, test_loader, optimizer, scheduler): + prog_bar = tqdm(range(init_epoch)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss = F.cross_entropy(logits, targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + init_epoch, + losses / len(train_loader), + train_acc, + ) + + prog_bar.set_description(info) + + logging.info(info) + + def _update_representation(self, train_loader, test_loader, optimizer, scheduler): + kd_lambda = self._known_classes / self._total_classes + prog_bar = tqdm(range(epochs)) + for _, epoch in enumerate(prog_bar): + self._network.train() + losses = 0.0 + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs)["logits"] + + loss_clf = F.cross_entropy(logits, targets) + loss_kd = _KD_loss( + logits[:, : self._known_classes], + self._old_network(inputs)["logits"], + T, + ) + + loss = (1-kd_lambda) * loss_clf + kd_lambda * loss_kd + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + # acc + _, preds = torch.max(logits, dim=1) + correct += preds.eq(targets.expand_as(preds)).cpu().sum() + total += len(targets) + + scheduler.step() + train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) + if epoch % 5 == 0: + test_acc = self._compute_accuracy(self._network, test_loader) + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + test_acc, + ) + else: + info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format( + self._cur_task, + epoch + 1, + epochs, + losses / len(train_loader), + train_acc, + ) + prog_bar.set_description(info) + logging.info(info) + + +def _KD_loss(pred, soft, T): + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c08ad666431d8cc241f603b279817c23a6a0375e --- /dev/null +++ b/requirements.txt @@ -0,0 
+1,12 @@
+kaggle
+numpy==1.21.0
+Pillow==10.3.0
+POT==0.4.0
+quadprog==0.1.12
+scikit-learn
+scipy==1.3.3
+tqdm==4.66.2
+Flask
+flask_autoindex
+boto3
+python-dotenv
\ No newline at end of file
diff --git a/resources/ImageNet100.png b/resources/ImageNet100.png
new file mode 100644
index 0000000000000000000000000000000000000000..ccde98b17c48ccc3cc45faf9fb468058537f4da3
Binary files /dev/null and b/resources/ImageNet100.png differ
diff --git a/resources/cifar100.png b/resources/cifar100.png
new file mode 100644
index 0000000000000000000000000000000000000000..2f4ce302a4eb092487aac62dada1056a022fd7da
Binary files /dev/null and b/resources/cifar100.png differ
diff --git a/resources/imagenet20st5.png b/resources/imagenet20st5.png
new file mode 100644
index 0000000000000000000000000000000000000000..e1206cbfe140694ba16747ca567648c07b9520c0
Binary files /dev/null and b/resources/imagenet20st5.png differ
diff --git a/resources/logo.png b/resources/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..66c6d5a887d4439158653a73bbe6c19ae312688e
Binary files /dev/null and b/resources/logo.png differ
diff --git a/rmm_train.py b/rmm_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9ecd6f1307471545cc91d31f7a60e03cfa6e92
--- /dev/null
+++ b/rmm_train.py
@@ -0,0 +1,232 @@
+'''
+`iCaRL+RMM` and `FOSTER+RMM` are implemented in [rmm.py](models/rmm.py); the `Pretraining Stage` of `RMM` is implemented here in [rmm_train.py](rmm_train.py).
+Use the following training script to run it:
+```bash
+python rmm_train.py --config=./exps/rmm-pretrain.json
+```
+'''
+import json
+import argparse
+from trainer import train
+import sys
+import logging
+import copy
+import torch
+from utils import factory
+from utils.data_manager import DataManager
+from utils.rl_utils.ddpg import DDPG
+from utils.rl_utils.rl_utils import ReplayBuffer
+from utils.toolkit import count_parameters
+import os
+import numpy as np
+import random
+
+
+class CILEnv:
+    def __init__(self, args) -> None:
+        self._args = copy.deepcopy(args)
+        self.settings = [(50, 2), (50, 5), (50, 10), (50, 20), (10, 10), (20, 20), (5, 5)]
+        # self.settings = [(5,5)] # Debug
+        self._args["init_cls"], self._args["increment"] = self.settings[np.random.randint(len(self.settings))]
+        self.data_manager = DataManager(
+            self._args["dataset"],
+            self._args["shuffle"],
+            self._args["seed"],
+            self._args["init_cls"],
+            self._args["increment"],
+        )
+        self.model = factory.get_model(self._args["model_name"], self._args)
+
+    @property
+    def nb_task(self):
+        return self.data_manager.nb_tasks
+
+    @property
+    def cur_task(self):
+        return self.model._cur_task
+
+    def get_task_size(self, task_id):
+        return self.data_manager.get_task_size(task_id)
+
+    def reset(self):
+        self._args["init_cls"], self._args["increment"] = self.settings[np.random.randint(len(self.settings))]
+        self.data_manager = DataManager(
+            self._args["dataset"],
+            self._args["shuffle"],
+            self._args["seed"],
+            self._args["init_cls"],
+            self._args["increment"],
+        )
+        self.model = factory.get_model(self._args["model_name"], self._args)
+
+        info = "start new task: dataset: {}, init_cls: {}, increment: {}".format(
+            self._args["dataset"], self._args["init_cls"], self._args["increment"]
+        )
+        return np.array([self.get_task_size(0) / 100, 0]), None, False, info
+
+    def step(self, action):
+        self.model._m_rate_list.append(action[0])
+        self.model._c_rate_list.append(action[1])
+        self.model.incremental_train(self.data_manager)
+        cnn_accy, nme_accy = self.model.eval_task()
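+        # Editor's note: this environment casts RMM's memory management as RL.
+        # The 2-d state is (next task size / 100, share of the budget held by
+        # old exemplars), the action supplies the (m_rate, c_rate) appended
+        # above, and the reward returned below is top-1 accuracy / 100; an
+        # episode ends after the last task of a randomly drawn CIL setting.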
+        self.model.after_task()
+        done = self.cur_task == self.nb_task - 1
+        info = "running task [{}/{}]: dataset: {}, increment: {}, cnn_accy top1: {}, top5: {}".format(
+            self.model._known_classes,
+            100,
+            self._args["dataset"],
+            self._args["increment"],
+            cnn_accy["top1"],
+            cnn_accy["top5"],
+        )
+        return (
+            np.array(
+                [
+                    self.get_task_size(self.cur_task + 1) / 100 if not done else 0.,
+                    self.model.memory_size
+                    / (self.model.memory_size + self.model.new_memory_size),
+                ]
+            ),
+            cnn_accy["top1"] / 100,
+            done,
+            info,
+        )
+
+
+def _train(args):
+
+    logs_name = "logs/RL-CIL/{}/".format(args["model_name"])
+    if not os.path.exists(logs_name):
+        os.makedirs(logs_name)
+
+    logfilename = "logs/RL-CIL/{}/{}_{}_{}_{}_{}".format(
+        args["model_name"],
+        args["prefix"],
+        args["seed"],
+        args["model_name"],
+        args["convnet_type"],
+        args["dataset"],
+    )
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s [%(filename)s] => %(message)s",
+        handlers=[
+            logging.FileHandler(filename=logfilename + ".log"),
+            logging.StreamHandler(sys.stdout),
+        ],
+    )
+
+    _set_random()
+    _set_device(args)
+    print_args(args)
+
+    actor_lr = 5e-4
+    critic_lr = 5e-3
+    num_episodes = 200
+    hidden_dim = 32
+    gamma = 0.98
+    tau = 0.005
+    buffer_size = 1000
+    minimal_size = 50
+    batch_size = 32
+    sigma = 0.2  # action noise, encouraging the off-policy algorithm to explore.
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    env = CILEnv(args)
+    replay_buffer = ReplayBuffer(buffer_size)
+    agent = DDPG(
+        2, 1, 4, hidden_dim, False, 1, sigma, actor_lr, critic_lr, tau, gamma, device
+    )
+    for iteration in range(num_episodes):
+        state, *_, info = env.reset()
+        logging.info(info)
+        done = False
+        while not done:
+            action = agent.take_action(state)
+            logging.info(f"take action: m_rate {action[0]}, c_rate {action[1]}")
+            next_state, reward, done, info = env.step(action)
+            logging.info(info)
+            replay_buffer.add(state, action, reward, next_state, done)
+            state = next_state
+            if replay_buffer.size() > minimal_size:
+                b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
+                transition_dict = {
+                    "states": b_s,
+                    "actions": b_a,
+                    "next_states": b_ns,
+                    "rewards": b_r,
+                    "dones": b_d,
+                }
+                agent.update(transition_dict)
+
+
+def _set_device(args):
+    device_type = args["device"]
+    gpus = []
+
+    for device in device_type:
+        if device == -1:
+            device = torch.device("cpu")
+        else:
+            device = torch.device("cuda:{}".format(device))
+
+        gpus.append(device)
+
+    args["device"] = gpus
+
+
+def _set_random():
+    random.seed(1)
+    torch.manual_seed(1)
+    torch.cuda.manual_seed(1)
+    torch.cuda.manual_seed_all(1)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def print_args(args):
+    for key, value in args.items():
+        logging.info("{}: {}".format(key, value))
+
+
+def train(args):
+    seed_list = copy.deepcopy(args["seed"])
+    device = copy.deepcopy(args["device"])
+
+    for seed in seed_list:
+        args["seed"] = seed
+        args["device"] = device
+        _train(args)
+
+
+def main():
+    args = setup_parser().parse_args()
+    param = load_json(args.config)
+    args = vars(args)  # Converting argparse Namespace to a dict.
+    args.update(param)  # Add parameters from json
+
+    train(args)
+
+
+def load_json(settings_path):
+    with open(settings_path) as data_file:
+        param = json.load(data_file)
+
+    return param
+
+
+def setup_parser():
+    parser = argparse.ArgumentParser(
+        description="Reproduction of multiple continual learning algorithms."
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="./exps/finetune.json",
+        help="Json file of settings.",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    main()
diff --git a/server.py b/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fcc0f6bf4b339b73ec177dab05b3df3ae93941e
--- /dev/null
+++ b/server.py
@@ -0,0 +1,89 @@
+from flask import Flask, send_from_directory, request, send_file
+from flask_autoindex import AutoIndex
+import subprocess, os
+
+from download_s3_path import download_s3_folder
+from download_file_from_s3 import download_from_s3
+from split import split_data
+import os
+import shutil
+import json
+import time
+
+app = Flask(__name__)
+app.config["UPLOAD_FOLDER"] = "upload"
+AutoIndex(app, browse_root=os.path.curdir)
+
+
+@app.route("/train", methods=["GET"])
+def train():
+    try:
+        subprocess.Popen(["./simple_train.sh"])
+        return "Bash script triggered successfully!"
+    except subprocess.CalledProcessError as e:
+        return f"An error occurred: {str(e)}", 500
+
+
+@app.route("/train/workings/<working_id>", methods=["GET"])
+def train_with_working_id(working_id):
+    path = f"working/{working_id}"
+    delete_folder(path)
+    download_s3_folder(os.getenv("S3_BUCKET_NAME", "pycil.com"), path, path)
+
+    data_path = path + "/data"
+    config_path = path + "/config.json"
+    output_path = f"s3://pycil.com/output/{working_id}"
+
+    split_data(data_path)
+
+    subprocess.Popen(
+        [
+            "./train_from_working.sh",
+            config_path,
+            data_path,
+            "models",
+            f"s3://pycil.com/output/{working_id}/{int(time.time())}",
+        ]
+    )
+
+    return f"Training started with working id {working_id}!"
+
+
+@app.route("/inference", methods=["POST"])
+def inference():
+    file = request.files["image"]
+    file.save(os.path.join(app.config["UPLOAD_FOLDER"], file.filename))
+
+    input_path = os.path.join(app.config["UPLOAD_FOLDER"], file.filename)
+    config_path = request.form["config_path"]
+    checkpoint_path = request.form["checkpoint_path"]
+
+    download_from_s3("pycil.com", config_path, "config.json")
+    download_from_s3("pycil.com", checkpoint_path, "checkpoint.pkl")
+    subprocess.call(
+        [
+            "python",
+            "inference.py",
+            "--config",
+            "config.json",
+            "--checkpoint",
+            "checkpoint.pkl",
+            "--input",
+            input_path,
+            "--output",
+            "output.json",
+        ]
+    )
+    return send_file("output.json")
+
+
+def delete_folder(folder_path):
+    if os.path.exists(folder_path):
+        shutil.rmtree(folder_path)
+        print(f"Folder '{folder_path}' has been deleted.")
+    else:
+        print(f"Folder '{folder_path}' does not exist.")
+
+
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=7860, debug=True)
diff --git a/simple_train.sh b/simple_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2bc939de85c343ea30af06ead4a428d9b6b7288a
--- /dev/null
+++ b/simple_train.sh
@@ -0,0 +1,5 @@
+#!/bin/sh
+
+python main.py --config ./exps/simplecil_general.json --data ./car_data/car_data
+
+./upload_s3.sh
diff --git a/split.py b/split.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e4360d9e773e3882729d2465258c02e9ccf5eb7
--- /dev/null
+++ b/split.py
@@ -0,0 +1,55 @@
+import os
+import shutil
+import sys
+from sklearn.model_selection import train_test_split
+
+
+def split_data(data_dir, train_ratio=0.8, seed=42):
+    train_dir = os.path.join(data_dir, "train")
+    val_dir = os.path.join(data_dir, "val")
+
+    # Ensure the train and val directories exist
+    os.makedirs(train_dir, exist_ok=True)
+    os.makedirs(val_dir, exist_ok=True)
+
+    # Iterate over each class folder
+    for class_name in os.listdir(data_dir):
+        class_path = os.path.join(data_dir, class_name)
+        if os.path.isdir(class_path) and class_name not in ["train", "val"]:
+            # Get a list of all files in the class directory
+            files = os.listdir(class_path)
+            files = [f for f in files if os.path.isfile(os.path.join(class_path, f))]
+
+            # Split the files into training and validation sets
+            train_files, val_files = train_test_split(
+                files, train_size=train_ratio, random_state=seed
+            )
+
+            # Create class directories in train and val directories
+            train_class_dir = os.path.join(train_dir, class_name)
+            val_class_dir = os.path.join(val_dir, class_name)
+            os.makedirs(train_class_dir, exist_ok=True)
+            os.makedirs(val_class_dir, exist_ok=True)
+
+            # Move training files
+            for file in train_files:
+                shutil.move(
+                    os.path.join(class_path, file), os.path.join(train_class_dir, file)
+                )
+
+            # Move validation files
+            for file in val_files:
+                shutil.move(
+                    os.path.join(class_path, file), os.path.join(val_class_dir, file)
+                )
+
+    print("Data split complete.")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 2:
+        print("Usage: python split.py <data_dir>")
+        sys.exit(1)
+
+    data_dir = sys.argv[1]
+    split_data(data_dir)
diff --git a/static/test.log b/static/test.log
new file mode 100644
index 0000000000000000000000000000000000000000..c7b8f380dcfd08cdcdece7a95198f653ea0915ee
--- /dev/null
+++ b/static/test.log
@@ -0,0 +1 @@
+this is a test line
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/test.py
@@ -0,0 +1 @@
+
diff --git a/test_blur.py b/test_blur.py
new file mode 100644
index 0000000000000000000000000000000000000000..95eda4a6e1d0fd08f3625b9fcf27d1cca677ee57
--- /dev/null
+++ b/test_blur.py
@@ -0,0 +1,30 @@
+from torchvision import transforms
+from PIL import Image
+import argparse
+
+def main():
+    args = setup_parser().parse_args()
+    path = args.path
+    img = Image.open(path)
+    trf = transforms.Compose([
+        transforms.RandomAffine(25, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=8),
+        transforms.RandomResizedCrop(224),
+        transforms.RandomHorizontalFlip(),
+        transforms.ColorJitter(brightness=0.3, saturation=0.2),
+        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.5, 2.0))], p=1),  # Always apply Gaussian blur (p=1); sigma is sampled randomly
+    ])
+    img = trf(img)
+    img.save("blur.jpg")
+
+def setup_parser():
+    parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.')
+    parser.add_argument('--path', type=str,
+                        help='Image file.')
+
+    return parser
+
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/test_upload/test.txt b/test_upload/test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4f872797f1bb936f015eaa34908d5744b420a7b8
--- /dev/null
+++ b/test_upload/test.txt
@@ -0,0 +1 @@
+this file will be uploaded to s3
\ No newline at end of file
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..67d0586705cbc27e4aa3a87af770a135ea399dce
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,7 @@
+#!/bin/sh
+for arg in "$@"; do
+    python ./main.py --config="$arg"
+    # Run one training per config file passed on the command line
+done
+
+./upload_s3.sh
\ No newline at end of file
diff --git a/train_from_working.sh b/train_from_working.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c3159358f8c5a648a8688d08b986679cc8f3333c
--- /dev/null
+++
b/train_from_working.sh @@ -0,0 +1,16 @@ +#!/bin/sh + +# Ensure the script exits on the first error and prints each command before executing it +set -e +set -x + +# Check if config, data, upload_s3_arg, and s3_path arguments were provided, if not, set default values +config=${1:-./exps/simplecil_general.json} +data=${2:-./car_data/car_data} +upload_s3_arg=${3:-./models} +s3_path=${4:-s3://pycil.com/"$(date -u +"%Y-%m-%dT%H:%M:%SZ")"} + +# Run the training script with the provided or default config and data arguments +python main.py --config "$config" --data "$data" + +./upload_s3.sh "$upload_s3_arg" "$s3_path" diff --git a/train_memo.py b/train_memo.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d3f79e1853188e3aca0edbad7664a6e043aeff --- /dev/null +++ b/train_memo.py @@ -0,0 +1,187 @@ +import sys +import logging +import copy +import torch +from utils import factory +from utils.data_manager import DataManager +from utils.toolkit import count_parameters +import os +import numpy as np + + +def train(args): + seed_list = copy.deepcopy(args["seed"]) + device = copy.deepcopy(args["device"]) + + for seed in seed_list: + args["seed"] = seed + args["device"] = device + _train(args) + + +def _train(args): + + init_cls = 0 if args ["init_cls"] == args["increment"] else args["init_cls"] + logs_name = "logs/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment']) + + if not os.path.exists(logs_name): + os.makedirs(logs_name) + + save_name = "models/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment']) + + if not os.path.exists(save_name): + os.makedirs(save_name) + logfilename = "logs/{}/{}/{}/{}/{}_{}_{}".format( + args["model_name"], + args["dataset"], + init_cls, + args["increment"], + args["prefix"], + args["seed"], + args["convnet_type"], + ) + if not os.path.exists(logs_name): + os.makedirs(logs_name) + args['logfilename'] = logs_name + args['csv_name'] = "{}_{}_{}".format( + args["prefix"], + args["seed"], + args["convnet_type"], + ) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(filename)s] => %(message)s", + handlers=[ + logging.FileHandler(filename=logfilename + ".log"), + logging.StreamHandler(sys.stdout), + ], + ) + + _set_random() + _set_device(args) + print_args(args) + data_manager = DataManager( + args["dataset"], + args["shuffle"], + args["seed"], + args["init_cls"], + args["increment"], + ) + model = factory.get_model(args["model_name"], args) + + cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []} + cnn_matrix, nme_matrix = [], [] + + for task in range(data_manager.nb_tasks): + print(args["device"]) + logging.info("All params: {}".format(count_parameters(model._network))) + logging.info( + "Trainable params: {}".format(count_parameters(model._network, True)) + ) + model.incremental_train(data_manager) + cnn_accy, nme_accy = model.eval_task(save_conf=True) + model.after_task() + + if nme_accy is not None: + logging.info("CNN: {}".format(cnn_accy["grouped"])) + logging.info("NME: {}".format(nme_accy["grouped"])) + + cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key] + cnn_keys_sorted = sorted(cnn_keys) + cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted] + cnn_matrix.append(cnn_values) + + nme_keys = [key for key in nme_accy["grouped"].keys() if '-' in key] + nme_keys_sorted = sorted(nme_keys) + nme_values = [nme_accy["grouped"][key] for key in nme_keys_sorted] + nme_matrix.append(nme_values) + + + 
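+            # Editor's note: each evaluation appends one row of per-group
+            # accuracies to cnn_matrix / nme_matrix. After the task loop these
+            # rows form an accuracy table A (transposed so rows are class
+            # groups and columns are incremental steps), and forgetting is the
+            # mean over all but the last group of max_j A[t, j] - A[t, last].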
cnn_curve["top1"].append(cnn_accy["top1"]) + cnn_curve["top5"].append(cnn_accy["top5"]) + + nme_curve["top1"].append(nme_accy["top1"]) + nme_curve["top5"].append(nme_accy["top5"]) + + logging.info("CNN top1 curve: {}".format(cnn_curve["top1"])) + logging.info("CNN top5 curve: {}".format(cnn_curve["top5"])) + logging.info("NME top1 curve: {}".format(nme_curve["top1"])) + logging.info("NME top5 curve: {}\n".format(nme_curve["top5"])) + + print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"])) + print('Average Accuracy (NME):', sum(nme_curve["top1"])/len(nme_curve["top1"])) + + logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"]))) + logging.info("Average Accuracy (NME): {}".format(sum(nme_curve["top1"])/len(nme_curve["top1"]))) + else: + logging.info("No NME accuracy.") + logging.info("CNN: {}".format(cnn_accy["grouped"])) + + cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key] + cnn_keys_sorted = sorted(cnn_keys) + cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted] + cnn_matrix.append(cnn_values) + + cnn_curve["top1"].append(cnn_accy["top1"]) + cnn_curve["top5"].append(cnn_accy["top5"]) + + logging.info("CNN top1 curve: {}".format(cnn_curve["top1"])) + logging.info("CNN top5 curve: {}\n".format(cnn_curve["top5"])) + + print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"])) + logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"]))) + model.save_checkpoint(save_name) + + if len(cnn_matrix)>0: + np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))]) + for idxx, line in enumerate(cnn_matrix): + idxy = len(line) + np_acctable[idxx, :idxy] = np.array(line) + np_acctable = np_acctable.T + forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1]) + logging.info('Forgetting (CNN): {}'.format(forgetting)) + logging.info('Accuracy Matrix (CNN): {}'.format(np_acctable)) + print('Accuracy Matrix (CNN):') + print(np_acctable) + print('Forgetting (CNN):', forgetting) + if len(nme_matrix)>0: + np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))]) + for idxx, line in enumerate(nme_matrix): + idxy = len(line) + np_acctable[idxx, :idxy] = np.array(line) + np_acctable = np_acctable.T + forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1]) + logging.info('Forgetting (NME): {}'.format(forgetting)) + logging.info('Accuracy Matrix (NME): {}'.format(np_acctable)) + print('Accuracy Matrix (NME):') + print(np_acctable) + print('Forgetting (NME):', forgetting) + + +def _set_device(args): + device_type = args["device"] + gpus = [] + + for device in device_type: + if device_type == -1: + device = torch.device("cpu") + else: + device = torch.device("cuda:{}".format(device)) + + gpus.append(device) + + args["device"] = gpus + + +def _set_random(): + torch.manual_seed(1) + torch.cuda.manual_seed(1) + torch.cuda.manual_seed_all(1) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def print_args(args): + for key, value in args.items(): + logging.info("{}: {}".format(key, value)) + diff --git a/train_more.py b/train_more.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2a0cd24bdbe11e7ee37c73663c57b5e6058712 --- /dev/null +++ b/train_more.py @@ -0,0 +1,186 @@ +import sys +import logging +import copy +import torch +from utils import factory +from utils.data_manager 
import DataManager +from utils.toolkit import count_parameters +import os +import numpy as np +from load_model import load_model, get_methods + +def train_more(args): + seed_list = copy.deepcopy(args["seed"]) + device = copy.deepcopy(args["device"]) + + for seed in seed_list: + args["seed"] = seed + args["device"] = device + _train_more(args) + + +def _train_more(args): + + init_cls = 0 if args ["init_cls"] == args["increment"] else args["init_cls"] + logs_name = "logs/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment']) + + if not os.path.exists(logs_name): + os.makedirs(logs_name) + + save_name = "models/{}/{}/{}/{}".format(args["model_name"],args["dataset"], init_cls, args['increment']) + + if not os.path.exists(save_name): + os.makedirs(save_name) + logfilename = "logs/{}/{}/{}/{}/{}_{}_{}".format( + args["model_name"], + args["dataset"], + init_cls, + args["increment"], + args["prefix"], + args["seed"], + args["convnet_type"], + ) + if not os.path.exists(logs_name): + os.makedirs(logs_name) + args['logfilename'] = logs_name + args['csv_name'] = "{}_{}_{}".format( + args["prefix"], + args["seed"], + args["convnet_type"], + ) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(filename)s] => %(message)s", + handlers=[ + logging.FileHandler(filename=logfilename + ".log"), + logging.StreamHandler(sys.stdout), + ], + ) + + _set_random() + print_args(args) + model = load_model(args) + data_manager = DataManager( + args["dataset"], + args["shuffle"], + args["seed"], + args["init_cls"], + args["increment"], + resume = True, + path = args["data"], + class_list = model.class_list + ) + cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []} + cnn_matrix, nme_matrix = [], [] + + for task in range(data_manager.nb_tasks): + print(args["device"]) + logging.info("All params: {}".format(count_parameters(model._network))) + logging.info( + "Trainable params: {}".format(count_parameters(model._network, True)) + ) + model.incremental_train(data_manager) + cnn_accy, nme_accy = model.eval_task(save_conf=True) + model.after_task() + + if nme_accy is not None: + logging.info("CNN: {}".format(cnn_accy["grouped"])) + logging.info("NME: {}".format(nme_accy["grouped"])) + + cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key] + cnn_keys_sorted = sorted(cnn_keys) + cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted] + cnn_matrix.append(cnn_values) + + nme_keys = [key for key in nme_accy["grouped"].keys() if '-' in key] + nme_keys_sorted = sorted(nme_keys) + nme_values = [nme_accy["grouped"][key] for key in nme_keys_sorted] + nme_matrix.append(nme_values) + + + cnn_curve["top1"].append(cnn_accy["top1"]) + cnn_curve["top5"].append(cnn_accy["top5"]) + + nme_curve["top1"].append(nme_accy["top1"]) + nme_curve["top5"].append(nme_accy["top5"]) + + logging.info("CNN top1 curve: {}".format(cnn_curve["top1"])) + logging.info("CNN top5 curve: {}".format(cnn_curve["top5"])) + logging.info("NME top1 curve: {}".format(nme_curve["top1"])) + logging.info("NME top5 curve: {}\n".format(nme_curve["top5"])) + + print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"])) + print('Average Accuracy (NME):', sum(nme_curve["top1"])/len(nme_curve["top1"])) + + logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"]))) + logging.info("Average Accuracy (NME): {}".format(sum(nme_curve["top1"])/len(nme_curve["top1"]))) + else: + logging.info("No NME accuracy.") + logging.info("CNN: 
{}".format(cnn_accy["grouped"])) + + cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key] + cnn_keys_sorted = sorted(cnn_keys) + cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted] + cnn_matrix.append(cnn_values) + + cnn_curve["top1"].append(cnn_accy["top1"]) + cnn_curve["top5"].append(cnn_accy["top5"]) + + logging.info("CNN top1 curve: {}".format(cnn_curve["top1"])) + logging.info("CNN top5 curve: {}\n".format(cnn_curve["top5"])) + + print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"])) + logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"]))) + model.save_checkpoint(save_name) + if len(cnn_matrix)>0: + np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))]) + for idxx, line in enumerate(cnn_matrix): + idxy = len(line) + np_acctable[idxx, :idxy] = np.array(line) + np_acctable = np_acctable.T + forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1]) + logging.info('Forgetting (CNN): {}'.format(forgetting)) + logging.info('Accuracy Matrix (CNN): {}'.format(np_acctable)) + print('Accuracy Matrix (CNN):') + print(np_acctable) + print('Forgetting (CNN):', forgetting) + if len(nme_matrix)>0: + np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))]) + for idxx, line in enumerate(nme_matrix): + idxy = len(line) + np_acctable[idxx, :idxy] = np.array(line) + np_acctable = np_acctable.T + forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1]) + logging.info('Forgetting (NME): {}'.format(forgetting)) + logging.info('Accuracy Matrix (NME): {}'.format(np_acctable)) + print('Accuracy Matrix (NME):') + print(np_acctable) + print('Forgetting (NME):', forgetting) + +def _set_device(args): + device_type = args["device"] + gpus = [] + + for device in device_type: + if device == -1: + device = torch.device("cpu") + else: + device = torch.device("cuda:{}".format(device)) + + gpus.append(device) + + args["device"] = gpus + + +def _set_random(): + torch.manual_seed(1) + torch.cuda.manual_seed(1) + torch.cuda.manual_seed_all(1) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def print_args(args): + for key, value in args.items(): + logging.info("{}: {}".format(key, value)) + diff --git a/train_more.sh b/train_more.sh new file mode 100644 index 0000000000000000000000000000000000000000..e3d10e2617d8cf36b85348dbcf2adb5b067fd83b --- /dev/null +++ b/train_more.sh @@ -0,0 +1,5 @@ +#! 
/bin/sh
+# Run main.py once per config file passed on the command line; "$@" and "$arg"
+# are quoted so config paths containing spaces survive word splitting.
+for arg in "$@"; do
+    python ./main.py --config="$arg" --resume
+done
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..09cd4ee86e3e48deea77add4d9b8679c3113dc3c
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,192 @@
+import sys
+import logging
+import copy
+import torch
+from utils import factory
+from utils.data_manager import DataManager
+from utils.toolkit import count_parameters
+import os
+import numpy as np
+
+
+def train(args):
+    seed_list = copy.deepcopy(args["seed"])
+    device = copy.deepcopy(args["device"])
+
+    for seed in seed_list:
+        args["seed"] = seed
+        args["device"] = device
+        _train(args)
+
+
+def _train(args):
+
+    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
+    logs_name = "logs/{}/{}_{}/{}/{}".format(args["model_name"], args["dataset"], args['data'], init_cls, args['increment'])
+
+    if not os.path.exists(logs_name):
+        os.makedirs(logs_name)
+
+    save_name = "models/{}/{}_{}/{}/{}".format(args["model_name"], args["dataset"], args['data'], init_cls, args['increment'])
+
+    if not os.path.exists(save_name):
+        os.makedirs(save_name)
+    logfilename = "logs/{}/{}_{}/{}/{}/{}_{}_{}".format(
+        args["model_name"],
+        args["dataset"],
+        args['data'],
+        init_cls,
+        args["increment"],
+        args["prefix"],
+        args["seed"],
+        args["convnet_type"],
+    )
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s [%(filename)s] => %(message)s",
+        handlers=[
+            logging.FileHandler(filename=logfilename + ".log"),
+            logging.StreamHandler(sys.stdout),
+        ],
+        force=True
+    )
+    args['logfilename'] = logs_name
+    args['csv_name'] = "{}_{}_{}".format(
+        args["prefix"],
+        args["seed"],
+        args["convnet_type"],
+    )
+
+    _set_random()
+    _set_device(args)
+    print_args(args)
+    model = factory.get_model(args["model_name"], args)
+    data_manager = DataManager(
+        args["dataset"],
+        args["shuffle"],
+        args["seed"],
+        args["init_cls"],
+        args["increment"],
+        path=args["data"],
+    )
+    if data_manager.get_task_size(0) < 5:
+        top_string = "top{}".format(data_manager.get_task_size(0))
+    else:
+        top_string = "top5"
+    cnn_curve, nme_curve = {"top1": [], top_string: []}, {"top1": [], top_string: []}
+    cnn_matrix, nme_matrix = [], []
+
+    for task in range(data_manager.nb_tasks):
+        print(args["device"])
+        logging.info("All params: {}".format(count_parameters(model._network)))
+        logging.info(
+            "Trainable params: {}".format(count_parameters(model._network, True))
+        )
+        model.incremental_train(data_manager)
+        cnn_accy, nme_accy = model.eval_task(save_conf=True)
+        model.after_task()
+
+        if nme_accy is not None:
+            logging.info("CNN: {}".format(cnn_accy["grouped"]))
+            logging.info("NME: {}".format(nme_accy["grouped"]))
+
+            cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
+            cnn_keys_sorted = sorted(cnn_keys)
+            cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted]
+            cnn_matrix.append(cnn_values)
+
+            nme_keys = [key for key in nme_accy["grouped"].keys() if '-' in key]
+            nme_keys_sorted = sorted(nme_keys)
+            nme_values = [nme_accy["grouped"][key] for key in nme_keys_sorted]
+            nme_matrix.append(nme_values)
+
+            cnn_curve["top1"].append(cnn_accy["top1"])
+            cnn_curve[top_string].append(cnn_accy["top{}".format(model.topk)])
+
+            nme_curve["top1"].append(nme_accy["top1"])
+            nme_curve[top_string].append(nme_accy["top{}".format(model.topk)])
+
+            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
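+            # Note (added): top_string is usually "top5" but falls back to
+            # "top{k}" when the first task has fewer than 5 classes (see the
+            # get_task_size(0) check above), so the literal "top5" wording in
+            # the log lines below is only approximate for very small tasks.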
logging.info("CNN top5 curve: {}".format(cnn_curve[top_string])) + logging.info("NME top1 curve: {}".format(nme_curve["top1"])) + logging.info("NME top5 curve: {}\n".format(nme_curve[top_string])) + + print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"])) + print('Average Accuracy (NME):', sum(nme_curve["top1"])/len(nme_curve["top1"])) + + logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"]))) + logging.info("Average Accuracy (NME): {}".format(sum(nme_curve["top1"])/len(nme_curve["top1"]))) + else: + logging.info("No NME accuracy.") + logging.info("CNN: {}".format(cnn_accy["grouped"])) + + cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key] + cnn_keys_sorted = sorted(cnn_keys) + cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys_sorted] + cnn_matrix.append(cnn_values) + + cnn_curve["top1"].append(cnn_accy["top1"]) + cnn_curve[top_string].append(cnn_accy["top{}".format(model.topk)]) + + logging.info("CNN top1 curve: {}".format(cnn_curve["top1"])) + logging.info("CNN top5 curve: {}\n".format(cnn_curve[top_string])) + + print('Average Accuracy (CNN):', sum(cnn_curve["top1"])/len(cnn_curve["top1"])) + logging.info("Average Accuracy (CNN): {}".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"]))) + model.save_checkpoint(save_name) + if len(cnn_matrix)>0: + np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))]) + for idxx, line in enumerate(cnn_matrix): + idxy = len(line) + np_acctable[idxx, :idxy] = np.array(line) + np_acctable = np_acctable.T + forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1]) + logging.info('Forgetting (CNN): {}'.format(forgetting)) + logging.info('Accuracy Matrix (CNN): {}'.format(np_acctable)) + print('Accuracy Matrix (CNN):') + print(np_acctable) + print('Forgetting (CNN):', forgetting) + if len(nme_matrix)>0: + np_acctable = np.zeros([ task + 1, int((args["init_cls"] // 10) + task * (args["increment"] // 10))]) + for idxx, line in enumerate(nme_matrix): + idxy = len(line) + np_acctable[idxx, :idxy] = np.array(line) + np_acctable = np_acctable.T + forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, -1])[:-1]) + logging.info('Forgetting (NME): {}'.format(forgetting)) + logging.info('Accuracy Matrix (NME): {}'.format(np_acctable)) + print('Accuracy Matrix (NME):') + print(np_acctable) + print('Forgetting (NME):', forgetting) + +def _set_device(args): + device_type = args["device"] + gpus = [] + + for device in device_type: + if device == -1: + device = torch.device("cpu") + else: + device = torch.device("cuda:{}".format(device)) + + gpus.append(device) + + args["device"] = gpus + + +def _set_random(): + torch.manual_seed(1) + torch.cuda.manual_seed(1) + torch.cuda.manual_seed_all(1) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def print_args(args): + for key, value in args.items(): + logging.info("{}: {}".format(key, value)) + diff --git a/upload_s3.sh b/upload_s3.sh new file mode 100644 index 0000000000000000000000000000000000000000..97d4e956d07481b9e060b6f7353bfd04f111de84 --- /dev/null +++ b/upload_s3.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +# Ensure the script exits on the first error and prints each command before executing it +set -e +set -x + +# Check if local directory and s3 path arguments were provided, if not, set default values +local_dir=${1:-./models} +s3_path=${2:-s3://pycil.com/"$(date -u +"%Y-%m-%dT%H:%M:%SZ")"} + +# Perform the S3 copy 
operation with the provided or default s3 path +aws s3 cp "$local_dir" "$s3_path" --recursive diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/autoaugment.py b/utils/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..ee8c90b1eba9867324212132edbdb5c570c911f5 --- /dev/null +++ b/utils/autoaugment.py @@ -0,0 +1,215 @@ +import numpy as np +from .ops import * + + +class ImageNetPolicy(object): + """ Randomly choose one of the best 24 Sub-policies on ImageNet. + + Example: + >>> policy = ImageNetPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform = transforms.Compose([ + >>> transforms.Resize(256), + >>> ImageNetPolicy(), + >>> transforms.ToTensor()]) + """ + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + + SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), + SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), + SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), + SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), + + SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), + SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), + SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + + SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), + SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), + SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), + SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), + SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), + + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) + ] + + def __call__(self, img): + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment ImageNet Policy" + + +class CIFAR10Policy(object): + """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
+ + Example: + >>> policy = CIFAR10Policy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> CIFAR10Policy(), + >>> transforms.ToTensor()]) + """ + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), + SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), + SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), + SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), + + SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), + SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), + SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), + SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), + + SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), + SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), + SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), + SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), + SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), + + SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), + SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), + SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), + SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), + SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), + + SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), + SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) + ] + + def __call__(self, img): + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment CIFAR10 Policy" + + +class SVHNPolicy(object): + """ Randomly choose one of the best 25 Sub-policies on SVHN. 
+ + Example: + >>> policy = SVHNPolicy() + >>> transformed = policy(image) + + Example as a PyTorch Transform: + >>> transform=transforms.Compose([ + >>> transforms.Resize(256), + >>> SVHNPolicy(), + >>> transforms.ToTensor()]) + """ + def __init__(self, fillcolor=(128, 128, 128)): + self.policies = [ + SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), + SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), + + SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), + SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), + SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), + SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), + + SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), + SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), + SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), + SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), + SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), + + SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), + SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), + SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), + SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), + SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), + + SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), + SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), + SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), + SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), + SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) + ] + + def __call__(self, img): + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + return "AutoAugment SVHN Policy" + + +class SubPolicy(object): + def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): + ranges = { + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 150 / 331, 10), + "translateY": np.linspace(0, 150 / 331, 10), + "rotate": np.linspace(0, 30, 10), + "color": np.linspace(0.0, 0.9, 10), + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int), + "solarize": np.linspace(256, 0, 10), + "contrast": np.linspace(0.0, 0.9, 10), + "sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.9, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10 + } + + func = { + "shearX": ShearX(fillcolor=fillcolor), + "shearY": ShearY(fillcolor=fillcolor), + "translateX": TranslateX(fillcolor=fillcolor), + "translateY": TranslateY(fillcolor=fillcolor), + "rotate": Rotate(), + "color": Color(), + "posterize": Posterize(), + "solarize": Solarize(), + "contrast": Contrast(), + "sharpness": Sharpness(), + "brightness": Brightness(), + "autocontrast": AutoContrast(), + "equalize": Equalize(), + "invert": Invert() + } + + self.p1 = p1 + self.operation1 = func[operation1] + self.magnitude1 = ranges[operation1][magnitude_idx1] + self.p2 = p2 + self.operation2 = func[operation2] + self.magnitude2 = ranges[operation2][magnitude_idx2] + + def __call__(self, img): + if random.random() < self.p1: + img = self.operation1(img, self.magnitude1) + if 
random.random() < self.p2: + img = self.operation2(img, self.magnitude2) + return img diff --git a/utils/data.py b/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..c161973d316392412f346520c729e19664d04a5a --- /dev/null +++ b/utils/data.py @@ -0,0 +1,199 @@ +import numpy as np +from torchvision import datasets, transforms +from utils.toolkit import split_images_labels + +import os + +class iData(object): + train_trsf = [] + test_trsf = [] + common_trsf = [] + class_order = None + +class iCIFAR10(iData): + use_path = False + train_trsf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ColorJitter(brightness=63 / 255), + transforms.ToTensor(), + ] + test_trsf = [transforms.ToTensor()] + common_trsf = [ + transforms.Normalize( + mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010) + ), + ] + + class_order = np.arange(10).tolist() + + def download_data(self): + train_dataset = datasets.cifar.CIFAR10("./data", train=True, download=True) + test_dataset = datasets.cifar.CIFAR10("./data", train=False, download=True) + self.train_data, self.train_targets = train_dataset.data, np.array( + train_dataset.targets + ) + self.test_data, self.test_targets = test_dataset.data, np.array( + test_dataset.targets + ) + + +class iCIFAR100(iData): + use_path = False + train_trsf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63 / 255), + transforms.ToTensor() + ] + test_trsf = [transforms.ToTensor()] + common_trsf = [ + transforms.Normalize( + mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761) + ), + ] + + class_order = np.arange(100).tolist() + + def download_data(self): + train_dataset = datasets.cifar.CIFAR100("./data", train=True, download=True) + test_dataset = datasets.cifar.CIFAR100("./data", train=False, download=True) + self.train_data, self.train_targets = train_dataset.data, np.array( + train_dataset.targets + ) + self.test_data, self.test_targets = test_dataset.data, np.array( + test_dataset.targets + ) + + +class iImageNet1000(iData): + use_path = True + train_trsf = [ + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(25, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=8), + transforms.ColorJitter(), + ] + test_trsf = [ + transforms.Resize(256), + transforms.CenterCrop(224), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.470, 0.460, 0.455], + std=[0.267, 0.266, 0.270] + ), + ] + + class_order = np.arange(1000).tolist() + + def download_data(self): + assert 0, "You should specify the folder of your dataset" + train_dir = "[DATA-PATH]/train/" + test_dir = "[DATA-PATH]/val/" + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) + + +class StanfordCar(iData): + use_path = True + train_trsf = [ + transforms.Resize(320), + transforms.CenterCrop(320), + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(25, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=8), + transforms.ColorJitter(), + ] + test_trsf = [ + transforms.Resize(320), + transforms.CenterCrop(320), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.470, 0.460, 0.455], + std=[0.267, 0.266, 0.270] + ), + ] + class_order = np.arange(196).tolist() + def download_data(self): + path = './car_data/car_data' + 
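+        # Assumed layout (comment added for clarity): the Kaggle-style Stanford
+        # Cars tree with one sub-folder per class, as ImageFolder requires:
+        #   ./car_data/car_data/train/<class_name>/*.jpg
+        #   ./car_data/car_data/test/<class_name>/*.jpg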
train_dset = datasets.ImageFolder(os.path.join(path, "train")) + test_dset = datasets.ImageFolder(os.path.join(path, "test")) + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) + +class GeneralDataset(iData): + def __init__( + self, + path, + init_class_list = [-1], + train_transform = None, + test_transform = None, + common_transform = None): + self.use_path = True + self.path = path + self.train_trsf = train_transform + if self.train_trsf == None: + self.train_trsf = [ + transforms.RandomAffine(25, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=8), + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness = 0.3, saturation = 0.2), + ] + self.test_trsf = test_transform + if self.test_trsf == None: + self.test_trsf = [ + transforms.Resize(224), + transforms.CenterCrop(224), + ] + self.common_trsf = common_transform + if self.common_trsf == None: + self.common_trsf = [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5] + ), + ] + self.init_index = max(init_class_list) + 1 + self.class_order = np.arange(self.init_index, self.init_index + len(os.listdir(os.path.join(self.path, "train")))) + + def download_data(self): + train_dset = datasets.ImageFolder(os.path.join(self.path, "train")) + test_dset = datasets.ImageFolder(os.path.join(self.path, "val")) + self.train_data, self.train_targets = split_images_labels(train_dset.imgs, start_index = self.init_index) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs, start_index = self.init_index) + return train_dset.classes + +class iImageNet100(iData): + use_path = True + train_trsf = [ + transforms.Resize(320), + transforms.CenterCrop(320), + ] + test_trsf = [ + transforms.Resize(320), + transforms.CenterCrop(320), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + + class_order = np.arange(1000).tolist() + + def download_data(self): + assert 0, "You should specify the folder of your dataset" + train_dir = "[DATA-PATH]/train/" + test_dir = "[DATA-PATH]/val/" + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) diff --git a/utils/data_manager.py b/utils/data_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..78b46eb9d9f3c2a3a192e68a472288a12b2f6056 --- /dev/null +++ b/utils/data_manager.py @@ -0,0 +1,335 @@ +import logging +import numpy as np +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000, StanfordCar, GeneralDataset +from tqdm import tqdm +class DataManager(object): + def __init__(self, dataset_name, shuffle, seed, init_cls, increment, resume = False, path = None, class_list = [-1]): + self.dataset_name = dataset_name + self.init_class_list = class_list + if not resume: + data = { + "path": path, + "class_list": [-1], + } + self._setup_data(dataset_name, shuffle, seed, data = data) + if len(self._class_order) < init_cls: + self._increments = [len(self._class_order)] + else: + self._increments = [init_cls] + while sum(self._increments) + increment < len(self._class_order): + self._increments.append(increment) + offset 
= len(self._class_order) - sum(self._increments) + if offset > 0: + self._increments.append(offset) + else: + self._increments = [max(class_list)] + data = { + "path": path, + "class_list": class_list, + } + self._setup_data(dataset_name, shuffle, seed, data = data) + while sum(self._increments) + increment < len(self._class_order): + self._increments.append(increment) + offset = len(self._class_order) - sum(self._increments) - 1 + if offset > 0: + self._increments.append(offset) + def get_class_list(self, task): + return self._class_order[: sum(self._increments[: task + 1])] + def get_label_list(self, task): + cls_list = self.get_class_list(task) + start_index = max(self.init_class_list) + 1 + result = {i:self.label_list[i] for i in cls_list} + return result + @property + def nb_tasks(self): + return len(self._increments) + + def get_task_size(self, task): + return self._increments[task] + + def get_accumulate_tasksize(self,task): + return float(sum(self._increments[:task+1])) + + def get_total_classnum(self): + return len(self._class_order) + + def get_dataset( + self, indices, source, mode, appendent=None, ret_data=False, m_rate=None + ): + if source == "train": + x, y = self._train_data, self._train_targets + elif source == "test": + x, y = self._test_data, self._test_targets + else: + raise ValueError("Unknown data source {}.".format(source)) + + if mode == "train": + trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) + elif mode == "flip": + trsf = transforms.Compose( + [ + *self._test_trsf, + transforms.RandomHorizontalFlip(p=1.0), + *self._common_trsf, + ] + ) + elif mode == "test": + trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) + else: + raise ValueError("Unknown mode {}.".format(mode)) + + data, targets = [], [] + for idx in indices: + if m_rate is None: + class_data, class_targets = self._select( + x, y, low_range=idx, high_range=idx + 1 + ) + else: + class_data, class_targets = self._select_rmm( + x, y, low_range=idx, high_range=idx + 1, m_rate=m_rate + ) + data.append(class_data) + targets.append(class_targets) + + if appendent is not None and len(appendent) != 0: + appendent_data, appendent_targets = appendent + data.append(appendent_data) + targets.append(appendent_targets) + + data, targets = np.concatenate(data), np.concatenate(targets) + if ret_data: + return data, targets, DummyDataset(data, targets, trsf, self.use_path) + else: + return DummyDataset(data, targets, trsf, self.use_path) + + + def get_finetune_dataset(self,known_classes,total_classes,source,mode,appendent,type="ratio"): + if source == 'train': + x, y = self._train_data, self._train_targets + elif source == 'test': + x, y = self._test_data, self._test_targets + else: + raise ValueError('Unknown data source {}.'.format(source)) + + if mode == 'train': + trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) + elif mode == 'test': + trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) + else: + raise ValueError('Unknown mode {}.'.format(mode)) + val_data = [] + val_targets = [] + + old_num_tot = 0 + appendent_data, appendent_targets = appendent + + for idx in range(0, known_classes): + append_data, append_targets = self._select(appendent_data, appendent_targets, + low_range=idx, high_range=idx+1) + num=len(append_data) + if num == 0: + continue + old_num_tot += num + val_data.append(append_data) + val_targets.append(append_targets) + if type == "ratio": + new_num_tot = int(old_num_tot*(total_classes-known_classes)/known_classes) + elif type == 
"same": + new_num_tot = old_num_tot + else: + assert 0, "not implemented yet" + new_num_average = int(new_num_tot/(total_classes-known_classes)) + for idx in range(known_classes,total_classes): + class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1) + val_indx = np.random.choice(len(class_data),new_num_average, replace=False) + val_data.append(class_data[val_indx]) + val_targets.append(class_targets[val_indx]) + val_data=np.concatenate(val_data) + val_targets = np.concatenate(val_targets) + return DummyDataset(val_data, val_targets, trsf, self.use_path) + + def get_dataset_with_split( + self, indices, source, mode, appendent=None, val_samples_per_class=0 + ): + if source == "train": + x, y = self._train_data, self._train_targets + elif source == "test": + x, y = self._test_data, self._test_targets + else: + raise ValueError("Unknown data source {}.".format(source)) + + if mode == "train": + trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) + elif mode == "test": + trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) + else: + raise ValueError("Unknown mode {}.".format(mode)) + + train_data, train_targets = [], [] + val_data, val_targets = [], [] + for idx in indices: + class_data, class_targets = self._select( + x, y, low_range=idx, high_range=idx + 1 + ) + val_indx = np.random.choice( + len(class_data), val_samples_per_class, replace=False + ) + train_indx = list(set(np.arange(len(class_data))) - set(val_indx)) + val_data.append(class_data[val_indx]) + val_targets.append(class_targets[val_indx]) + train_data.append(class_data[train_indx]) + train_targets.append(class_targets[train_indx]) + + if appendent is not None: + appendent_data, appendent_targets = appendent + for idx in range(0, int(np.max(appendent_targets)) + 1): + append_data, append_targets = self._select( + appendent_data, appendent_targets, low_range=idx, high_range=idx + 1 + ) + val_indx = np.random.choice( + len(append_data), val_samples_per_class, replace=False + ) + train_indx = list(set(np.arange(len(append_data))) - set(val_indx)) + val_data.append(append_data[val_indx]) + val_targets.append(append_targets[val_indx]) + train_data.append(append_data[train_indx]) + train_targets.append(append_targets[train_indx]) + + train_data, train_targets = np.concatenate(train_data), np.concatenate( + train_targets + ) + val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets) + + return DummyDataset( + train_data, train_targets, trsf, self.use_path + ), DummyDataset(val_data, val_targets, trsf, self.use_path) + + def _setup_data(self, dataset_name, shuffle, seed, data = None): + idata = _get_idata(dataset_name, data = data) + self.label_list = idata.download_data() + # Data + self._train_data, self._train_targets = idata.train_data, idata.train_targets + self._test_data, self._test_targets = idata.test_data, idata.test_targets + self.use_path = idata.use_path + # Transforms + self._train_trsf = idata.train_trsf + self._test_trsf = idata.test_trsf + self._common_trsf = idata.common_trsf + + # Order + order = np.unique(self._train_targets) + if shuffle: + np.random.seed(seed) + order = np.random.permutation(order).tolist() + else: + order = idata.class_order.tolist() + if data['class_list'][0] != -1: + self._class_order = np.concatenate((np.array(data['class_list']), order)).tolist() + else: + self._class_order = order + logging.info(self._class_order) + # Map indices + self._train_targets = _map_new_class_index( + self._train_targets, self._class_order, + ) 
+
+        self._test_targets = _map_new_class_index(self._test_targets, self._class_order)
+
+    def _select(self, x, y, low_range, high_range):
+        idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
+        if isinstance(x, np.ndarray):
+            x_return = x[idxes]
+        else:
+            x_return = []
+            for id in idxes:
+                x_return.append(x[id])
+        return x_return, y[idxes]
+
+    def _select_rmm(self, x, y, low_range, high_range, m_rate):
+        assert m_rate is not None
+        if m_rate != 0:
+            idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
+            selected_idxes = np.random.randint(
+                0, len(idxes), size=int((1 - m_rate) * len(idxes))
+            )
+            new_idxes = idxes[selected_idxes]
+            new_idxes = np.sort(new_idxes)
+        else:
+            new_idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
+        return x[new_idxes], y[new_idxes]
+
+    def getlen(self, index):
+        # Fixed: count the samples of a class; the previous
+        # np.sum(np.where(y == index)) summed the matching indices themselves.
+        y = self._train_targets
+        return np.sum(y == index)
+
+
+class DummyDataset(Dataset):
+    def __init__(self, images, labels, trsf, use_path=False):
+        assert len(images) == len(labels), "Data size error!"
+        self.images = images
+        self.labels = labels
+        self.trsf = trsf
+        self.use_path = use_path
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        if self.use_path:
+            image = self.trsf(pil_loader(self.images[idx]))
+        else:
+            image = self.trsf(Image.fromarray(self.images[idx]))
+        label = self.labels[idx]
+
+        return idx, image, label
+
+
+def _map_new_class_index(y, order):
+    return np.array(list(map(lambda x: order.index(x), y)))
+
+
+def _get_idata(dataset_name, data=None):
+    name = dataset_name.lower()
+    if name == "cifar10":
+        return iCIFAR10()
+    elif name == "cifar100":
+        return iCIFAR100()
+    elif name == "imagenet1000":
+        return iImageNet1000()
+    elif name == "imagenet100":
+        return iImageNet100()
+    elif name == 'stanfordcar':
+        return StanfordCar()
+    elif name == 'general_dataset':
+        return GeneralDataset(data["path"], init_class_list=data["class_list"])
+    else:
+        raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
+
+
+def pil_loader(path):
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    with open(path, "rb") as f:
+        img = Image.open(f)
+        return img.convert("RGB")
+
+
+def accimage_loader(path):
+    import accimage
+
+    try:
+        return accimage.Image(path)
+    except IOError:
+        # Potentially a decoding problem, fall back to PIL.Image
+        return pil_loader(path)
+
+
+def default_loader(path):
+    from torchvision import get_image_backend
+
+    if get_image_backend() == "accimage":
+        return accimage_loader(path)
+    else:
+        return pil_loader(path)
diff --git a/utils/factory.py b/utils/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..41ed69d14cbf7736b38791d4bb10759dd06b2cc4
--- /dev/null
+++ b/utils/factory.py
@@ -0,0 +1,67 @@
+def get_model(model_name, args):
+    name = model_name.lower()
+    if name == "icarl":
+        from models.icarl import iCaRL
+        return iCaRL(args)
+    elif name == "bic":
+        from models.bic import BiC
+        return BiC(args)
+    elif name == "podnet":
+        from models.podnet import PODNet
+        return PODNet(args)
+    elif name == "lwf":
+        from models.lwf import LwF
+        return LwF(args)
+    elif name == "ewc":
+        from models.ewc import EWC
+        return EWC(args)
+    elif name == "wa":
+        from models.wa import WA
+        return WA(args)
+    elif name == "der":
+        from models.der import DER
+        return DER(args)
+    elif name == "finetune":
+        from models.finetune import Finetune
+        return Finetune(args)
+    elif name == "replay":
+        from models.replay
import Replay + return Replay(args) + elif name == "gem": + from models.gem import GEM + return GEM(args) + elif name == "coil": + from models.coil import COIL + return COIL(args) + elif name == "foster": + from models.foster import FOSTER + return FOSTER(args) + elif name == "rmm-icarl": + from models.rmm import RMM_FOSTER, RMM_iCaRL + return RMM_iCaRL(args) + elif name == "rmm-foster": + from models.rmm import RMM_FOSTER, RMM_iCaRL + return RMM_FOSTER(args) + elif name == "fetril": + from models.fetril import FeTrIL + return FeTrIL(args) + elif name == "pass": + from models.pa2s import PASS + return PASS(args) + elif name == "il2a": + from models.il2a import IL2A + return IL2A(args) + elif name == "ssre": + from models.ssre import SSRE + return SSRE(args) + elif name == "memo": + from models.memo import MEMO + return MEMO(args) + elif name == "beefiso": + from models.beef_iso import BEEFISO + return BEEFISO(args) + elif name == "simplecil": + from models.simplecil import SimpleCIL + return SimpleCIL(args) + else: + assert 0 diff --git a/utils/inc_net.py b/utils/inc_net.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe3721aedc95ccf0169b911369fa6f525f83d5 --- /dev/null +++ b/utils/inc_net.py @@ -0,0 +1,799 @@ +import copy +import logging +import torch +from torch import nn +from convs.cifar_resnet import resnet32 +from convs.resnet import resnet18, resnet34, resnet50, resnet101, resnet152 +from convs.ucir_cifar_resnet import resnet32 as cosine_resnet32 +from convs.ucir_resnet import resnet18 as cosine_resnet18 +from convs.ucir_resnet import resnet34 as cosine_resnet34 +from convs.ucir_resnet import resnet50 as cosine_resnet50 +from convs.linears import SimpleLinear, SplitCosineLinear, CosineLinear +from convs.modified_represnet import resnet18_rep,resnet34_rep +from convs.resnet_cbam import resnet18_cbam,resnet34_cbam,resnet50_cbam +from convs.memo_resnet import get_resnet18_imagenet as get_memo_resnet18 #for MEMO imagenet +from convs.memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32 #for MEMO cifar + +def get_convnet(args, pretrained=False): + name = args["convnet_type"].lower() + if name == "resnet32": + return resnet32() + elif name == "resnet18": + return resnet18(pretrained=pretrained,args=args) + elif name == "resnet34": + return resnet34(pretrained=pretrained,args=args) + elif name == "resnet50": + return resnet50(pretrained=pretrained,args=args) + elif name == "cosine_resnet18": + return cosine_resnet18(pretrained=pretrained,args=args) + elif name == "cosine_resnet32": + return cosine_resnet32() + elif name == "cosine_resnet34": + return cosine_resnet34(pretrained=pretrained,args=args) + elif name == "cosine_resnet50": + return cosine_resnet50(pretrained=pretrained,args=args) + elif name == "resnet18_rep": + return resnet18_rep(pretrained=pretrained,args=args) + elif name == "resnet18_cbam": + return resnet18_cbam(pretrained=pretrained,args=args) + elif name == "resnet34_cbam": + return resnet34_cbam(pretrained=pretrained,args=args) + elif name == "resnet50_cbam": + return resnet50_cbam(pretrained=pretrained,args=args) + + # MEMO benchmark backbone + elif name == 'memo_resnet18': + _basenet, _adaptive_net = get_memo_resnet18() + return _basenet, _adaptive_net + elif name == 'memo_resnet32': + _basenet, _adaptive_net = get_memo_resnet32() + return _basenet, _adaptive_net + + else: + raise NotImplementedError("Unknown type {}".format(name)) + + +class BaseNet(nn.Module): + def __init__(self, args, pretrained): + super(BaseNet, self).__init__() 
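+        # Sketch (added comment): get_convnet() dispatches on
+        # args["convnet_type"], so a minimal args dict looks like
+        # {"convnet_type": "resnet32", ...}; the MEMO backbones are the one
+        # exception and return a (base, adaptive) pair instead of a module.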
+ + self.convnet = get_convnet(args, pretrained) + self.fc = None + + @property + def feature_dim(self): + return self.convnet.out_dim + + def extract_vector(self, x): + return self.convnet(x)["features"] + + def forward(self, x): + x = self.convnet(x) + out = self.fc(x["features"]) + """ + { + 'fmaps': [x_1, x_2, ..., x_n], + 'features': features + 'logits': logits + } + """ + out.update(x) + + return out + + def update_fc(self, nb_classes): + pass + + def generate_fc(self, in_dim, out_dim): + pass + + def copy(self): + return copy.deepcopy(self) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + self.eval() + + return self + + def load_checkpoint(self, args): + if args["init_cls"] == 50: + pkl_name = "{}_{}_{}_B{}_Inc{}".format( + args["dataset"], + args["seed"], + args["convnet_type"], + 0, + args["init_cls"], + ) + checkpoint_name = f"checkpoints/finetune_{pkl_name}_0.pkl" + else: + checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl" + model_infos = torch.load(checkpoint_name) + self.convnet.load_state_dict(model_infos['convnet']) + self.fc.load_state_dict(model_infos['fc']) + test_acc = model_infos['test_acc'] + return test_acc + +class IncrementalNet(BaseNet): + def __init__(self, args, pretrained, gradcam=False): + super().__init__(args, pretrained) + self.gradcam = gradcam + if hasattr(self, "gradcam") and self.gradcam: + self._gradcam_hooks = [None, None] + self.set_gradcam_hook() + + def update_fc(self, nb_classes): + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + nb_output = self.fc.out_features + weight = copy.deepcopy(self.fc.weight.data) + bias = copy.deepcopy(self.fc.bias.data) + fc.weight.data[:nb_output] = weight + fc.bias.data[:nb_output] = bias + + del self.fc + self.fc = fc + + def weight_align(self, increment): + weights = self.fc.weight.data + newnorm = torch.norm(weights[-increment:, :], p=2, dim=1) + oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1) + meannew = torch.mean(newnorm) + meanold = torch.mean(oldnorm) + gamma = meanold / meannew + print("alignweights,gamma=", gamma) + self.fc.weight.data[-increment:, :] *= gamma + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim) + + return fc + + def forward(self, x): + x = self.convnet(x) + out = self.fc(x["features"]) + out.update(x) + if hasattr(self, "gradcam") and self.gradcam: + out["gradcam_gradients"] = self._gradcam_gradients + out["gradcam_activations"] = self._gradcam_activations + + return out + + def unset_gradcam_hook(self): + self._gradcam_hooks[0].remove() + self._gradcam_hooks[1].remove() + self._gradcam_hooks[0] = None + self._gradcam_hooks[1] = None + self._gradcam_gradients, self._gradcam_activations = [None], [None] + + def set_gradcam_hook(self): + self._gradcam_gradients, self._gradcam_activations = [None], [None] + + def backward_hook(module, grad_input, grad_output): + self._gradcam_gradients[0] = grad_output[0] + return None + + def forward_hook(module, input, output): + self._gradcam_activations[0] = output + return None + + self._gradcam_hooks[0] = self.convnet.last_conv.register_backward_hook( + backward_hook + ) + self._gradcam_hooks[1] = self.convnet.last_conv.register_forward_hook( + forward_hook + ) + +class IL2ANet(IncrementalNet): + + def update_fc(self, num_old, num_total, num_aux): + fc = self.generate_fc(self.feature_dim, num_total+num_aux) + if self.fc is not None: + weight = copy.deepcopy(self.fc.weight.data) + bias = copy.deepcopy(self.fc.bias.data) + 
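+            # Added note: the head is rebuilt at size num_total + num_aux and
+            # only the first num_old rows/biases are carried over below, the
+            # usual PyCIL head-expansion pattern (new classes start fresh).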
fc.weight.data[:num_old] = weight[:num_old] + fc.bias.data[:num_old] = bias[:num_old] + del self.fc + self.fc = fc + +class CosineIncrementalNet(BaseNet): + def __init__(self, args, pretrained, nb_proxy=1): + super().__init__(args, pretrained) + self.nb_proxy = nb_proxy + + def update_fc(self, nb_classes, task_num): + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + if task_num == 1: + fc.fc1.weight.data = self.fc.weight.data + fc.sigma.data = self.fc.sigma.data + else: + prev_out_features1 = self.fc.fc1.out_features + fc.fc1.weight.data[:prev_out_features1] = self.fc.fc1.weight.data + fc.fc1.weight.data[prev_out_features1:] = self.fc.fc2.weight.data + fc.sigma.data = self.fc.sigma.data + + del self.fc + self.fc = fc + def generate_fc(self, in_dim, out_dim): + if self.fc is None: + fc = CosineLinear(in_dim, out_dim, self.nb_proxy, to_reduce=True) + else: + prev_out_features = self.fc.out_features // self.nb_proxy + # prev_out_features = self.fc.out_features + fc = SplitCosineLinear( + in_dim, prev_out_features, out_dim - prev_out_features, self.nb_proxy + ) + + return fc + + +class BiasLayer_BIC(nn.Module): + def __init__(self): + super(BiasLayer_BIC, self).__init__() + self.alpha = nn.Parameter(torch.ones(1, requires_grad=True)) + self.beta = nn.Parameter(torch.zeros(1, requires_grad=True)) + + def forward(self, x, low_range, high_range): + ret_x = x.clone() + ret_x[:, low_range:high_range] = ( + self.alpha * x[:, low_range:high_range] + self.beta + ) + return ret_x + + def get_params(self): + return (self.alpha.item(), self.beta.item()) + + +class IncrementalNetWithBias(BaseNet): + def __init__(self, args, pretrained, bias_correction=False): + super().__init__(args, pretrained) + + # Bias layer + self.bias_correction = bias_correction + self.bias_layers = nn.ModuleList([]) + self.task_sizes = [] + + def forward(self, x): + x = self.convnet(x) + out = self.fc(x["features"]) + if self.bias_correction: + logits = out["logits"] + for i, layer in enumerate(self.bias_layers): + logits = layer( + logits, sum(self.task_sizes[:i]), sum(self.task_sizes[: i + 1]) + ) + out["logits"] = logits + + out.update(x) + + return out + + def update_fc(self, nb_classes): + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + nb_output = self.fc.out_features + weight = copy.deepcopy(self.fc.weight.data) + bias = copy.deepcopy(self.fc.bias.data) + fc.weight.data[:nb_output] = weight + fc.bias.data[:nb_output] = bias + + del self.fc + self.fc = fc + + new_task_size = nb_classes - sum(self.task_sizes) + self.task_sizes.append(new_task_size) + self.bias_layers.append(BiasLayer_BIC()) + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim) + + return fc + + def get_bias_params(self): + params = [] + for layer in self.bias_layers: + params.append(layer.get_params()) + + return params + + def unfreeze(self): + for param in self.parameters(): + param.requires_grad = True + + +class DERNet(nn.Module): + def __init__(self, args, pretrained): + super(DERNet, self).__init__() + self.convnet_type = args["convnet_type"] + self.convnets = nn.ModuleList() + self.pretrained = pretrained + self.out_dim = None + self.fc = None + self.aux_fc = None + self.task_sizes = [] + self.args = args + + @property + def feature_dim(self): + if self.out_dim is None: + return 0 + return self.out_dim * len(self.convnets) + + def extract_vector(self, x): + features = [convnet(x)["features"] for convnet in self.convnets] + features = torch.cat(features, 1) + 
return features + + def forward(self, x): + features = [convnet(x)["features"] for convnet in self.convnets] + features = torch.cat(features, 1) + + out = self.fc(features) # {logics: self.fc(features)} + + aux_logits = self.aux_fc(features[:, -self.out_dim :])["logits"] + + out.update({"aux_logits": aux_logits, "features": features}) + return out + """ + { + 'features': features + 'logits': logits + 'aux_logits':aux_logits + } + """ + + def update_fc(self, nb_classes): + if len(self.convnets) == 0: + self.convnets.append(get_convnet(self.args)) + else: + self.convnets.append(get_convnet(self.args)) + self.convnets[-1].load_state_dict(self.convnets[-2].state_dict()) + + if self.out_dim is None: + self.out_dim = self.convnets[-1].out_dim + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + nb_output = self.fc.out_features + weight = copy.deepcopy(self.fc.weight.data) + bias = copy.deepcopy(self.fc.bias.data) + fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight + fc.bias.data[:nb_output] = bias + + del self.fc + self.fc = fc + + new_task_size = nb_classes - sum(self.task_sizes) + self.task_sizes.append(new_task_size) + + self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1) + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim) + + return fc + + def copy(self): + return copy.deepcopy(self) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + self.eval() + + return self + + def freeze_conv(self): + for param in self.convnets.parameters(): + param.requires_grad = False + self.convnets.eval() + + def weight_align(self, increment): + weights = self.fc.weight.data + newnorm = torch.norm(weights[-increment:, :], p=2, dim=1) + oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1) + meannew = torch.mean(newnorm) + meanold = torch.mean(oldnorm) + gamma = meanold / meannew + print("alignweights,gamma=", gamma) + self.fc.weight.data[-increment:, :] *= gamma + + def load_checkpoint(self, args): + checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl" + model_infos = torch.load(checkpoint_name) + assert len(self.convnets) == 1 + self.convnets[0].load_state_dict(model_infos['convnet']) + self.fc.load_state_dict(model_infos['fc']) + test_acc = model_infos['test_acc'] + return test_acc + + +class SimpleCosineIncrementalNet(BaseNet): + def __init__(self, args, pretrained): + super().__init__(args, pretrained) + + def update_fc(self, nb_classes, nextperiod_initialization=None): + fc = self.generate_fc(self.feature_dim, nb_classes).cuda() + if self.fc is not None: + nb_output = self.fc.out_features + weight = copy.deepcopy(self.fc.weight.data) + fc.sigma.data = self.fc.sigma.data + if nextperiod_initialization is not None: + weight = torch.cat([weight.cuda(), nextperiod_initialization.cuda()]) + else: + weight = torch.cat([weight.cuda(), torch.zeros(nb_classes - nb_output, self.feature_dim).cuda()]) + fc.weight = nn.Parameter(weight) + del self.fc + self.fc = fc + def load_checkpoint(self, checkpoint): + self.convnet.load_state_dict(checkpoint["convnet"]) + self.fc.load_state_dict(checkpoint["fc"]) + def generate_fc(self, in_dim, out_dim): + fc = CosineLinear(in_dim, out_dim) + return fc + + +class FOSTERNet(nn.Module): + def __init__(self, args, pretrained): + super(FOSTERNet, self).__init__() + self.convnet_type = args["convnet_type"] + self.convnets = nn.ModuleList() + self.pretrained = pretrained + self.out_dim = None + self.fc = None + self.fe_fc = None + self.task_sizes 
= [] + self.oldfc = None + self.args = args + + @property + def feature_dim(self): + if self.out_dim is None: + return 0 + return self.out_dim * len(self.convnets) + + def extract_vector(self, x): + features = [convnet(x)["features"] for convnet in self.convnets] + features = torch.cat(features, 1) + return features + + def load_checkpoint(self, checkpoint): + if len(self.convnets) == 0: + self.convnets.append(get_convnet(self.args)) + self.convnets[0].load_state_dict(checkpoint["convnet"]) + self.fc.load_state_dict(checkpoint["fc"]) + + def forward(self, x): + features = [convnet(x)["features"] for convnet in self.convnets] + features = torch.cat(features, 1) + out = self.fc(features) + fe_logits = self.fe_fc(features[:, -self.out_dim :])["logits"] + + out.update({"fe_logits": fe_logits, "features": features}) + + if self.oldfc is not None: + old_logits = self.oldfc(features[:, : -self.out_dim])["logits"] + out.update({"old_logits": old_logits}) + + out.update({"eval_logits": out["logits"]}) + return out + + def update_fc(self, nb_classes): + self.convnets.append(get_convnet(self.args)) + if self.out_dim is None: + self.out_dim = self.convnets[-1].out_dim + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + nb_output = self.fc.out_features + weight = copy.deepcopy(self.fc.weight.data) + bias = copy.deepcopy(self.fc.bias.data) + fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight + fc.bias.data[:nb_output] = bias + self.convnets[-1].load_state_dict(self.convnets[-2].state_dict()) + + self.oldfc = self.fc + self.fc = fc + new_task_size = nb_classes - sum(self.task_sizes) + self.task_sizes.append(new_task_size) + self.fe_fc = self.generate_fc(self.out_dim, nb_classes) + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim) + return fc + + def copy(self): + return copy.deepcopy(self) + + def copy_fc(self, fc): + weight = copy.deepcopy(fc.weight.data) + bias = copy.deepcopy(fc.bias.data) + n, m = weight.shape[0], weight.shape[1] + self.fc.weight.data[:n, :m] = weight + self.fc.bias.data[:n] = bias + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + self.eval() + return self + + def freeze_conv(self): + for param in self.convnets.parameters(): + param.requires_grad = False + self.convnets.eval() + + def weight_align(self, old, increment, value): + weights = self.fc.weight.data + newnorm = torch.norm(weights[-increment:, :], p=2, dim=1) + oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1) + meannew = torch.mean(newnorm) + meanold = torch.mean(oldnorm) + gamma = meanold / meannew * (value ** (old / increment)) + logging.info("align weights, gamma = {} ".format(gamma)) + self.fc.weight.data[-increment:, :] *= gamma + + +class BiasLayer(nn.Module): + def __init__(self): + super(BiasLayer, self).__init__() + self.alpha = nn.Parameter(torch.zeros(1, requires_grad=True)) + self.beta = nn.Parameter(torch.zeros(1, requires_grad=True)) + + def forward(self, x , bias=True): + ret_x = x.clone() + ret_x = (self.alpha+1) * x # + self.beta + if bias: + ret_x = ret_x + self.beta + return ret_x + + def get_params(self): + return (self.alpha.item(), self.beta.item()) + + +class BEEFISONet(nn.Module): + def __init__(self, args, pretrained): + super(BEEFISONet, self).__init__() + self.convnet_type = args["convnet_type"] + self.convnets = nn.ModuleList() + self.pretrained = pretrained + self.out_dim = None + self.old_fc = None + self.new_fc = None + self.task_sizes = [] + self.forward_prototypes = 
+
+
+class BEEFISONet(nn.Module):
+    def __init__(self, args, pretrained):
+        super(BEEFISONet, self).__init__()
+        self.convnet_type = args["convnet_type"]
+        self.convnets = nn.ModuleList()
+        self.pretrained = pretrained
+        self.out_dim = None
+        self.old_fc = None
+        self.new_fc = None
+        self.task_sizes = []
+        self.forward_prototypes = None
+        self.backward_prototypes = None
+        self.args = args
+        self.biases = nn.ModuleList()
+
+    @property
+    def feature_dim(self):
+        if self.out_dim is None:
+            return 0
+        return self.out_dim * len(self.convnets)
+
+    def extract_vector(self, x):
+        features = [convnet(x)["features"] for convnet in self.convnets]
+        features = torch.cat(features, 1)
+        return features
+
+    def forward(self, x):
+        features = [convnet(x)["features"] for convnet in self.convnets]
+        features = torch.cat(features, 1)
+
+        if self.old_fc is None:
+            fc = self.new_fc
+            out = fc(features)
+        else:
+            # Merge the old head with the bias-corrected new head.
+            new_task_size = self.task_sizes[-1]
+            fc_weight = torch.cat(
+                [self.old_fc.weight, torch.zeros((new_task_size, self.feature_dim - self.out_dim)).cuda()],
+                dim=0,
+            )
+            new_fc_weight = self.new_fc.weight
+            new_fc_bias = self.new_fc.bias
+            for i in range(len(self.task_sizes) - 2, -1, -1):
+                new_fc_weight = torch.cat(
+                    [*[self.biases[i](self.backward_prototypes.weight[i].unsqueeze(0), bias=False)
+                       for _ in range(self.task_sizes[i])], new_fc_weight],
+                    dim=0,
+                )
+                new_fc_bias = torch.cat(
+                    [*[self.biases[i](self.backward_prototypes.bias[i].unsqueeze(0), bias=True)
+                       for _ in range(self.task_sizes[i])], new_fc_bias]
+                )
+            fc_weight = torch.cat([fc_weight, new_fc_weight], dim=1)
+            fc_bias = torch.cat([self.old_fc.bias, torch.zeros(new_task_size).cuda()])
+            fc_bias += new_fc_bias
+            logits = features @ fc_weight.permute(1, 0) + fc_bias
+            out = {"logits": logits}
+
+            new_fc_weight = self.new_fc.weight
+            new_fc_bias = self.new_fc.bias
+            for i in range(len(self.task_sizes) - 2, -1, -1):
+                new_fc_weight = torch.cat([self.backward_prototypes.weight[i].unsqueeze(0), new_fc_weight], dim=0)
+                new_fc_bias = torch.cat([self.backward_prototypes.bias[i].unsqueeze(0), new_fc_bias])
+            out["train_logits"] = features[:, -self.out_dim:] @ new_fc_weight.permute(1, 0) + new_fc_bias
+        out.update({
+            "eval_logits": out["logits"],
+            "energy_logits": self.forward_prototypes(features[:, -self.out_dim:])["logits"],
+        })
+        return out
+
+    def update_fc_before(self, nb_classes):
+        new_task_size = nb_classes - sum(self.task_sizes)
+        self.biases = nn.ModuleList([BiasLayer() for i in range(len(self.task_sizes))])
+        self.convnets.append(get_convnet(self.args))
+        if self.out_dim is None:
+            self.out_dim = self.convnets[-1].out_dim
+        if self.new_fc is not None:
+            self.fe_fc = self.generate_fc(self.out_dim, nb_classes)
+            self.backward_prototypes = self.generate_fc(self.out_dim, len(self.task_sizes))
+            self.convnets[-1].load_state_dict(self.convnets[0].state_dict())
+        self.forward_prototypes = self.generate_fc(self.out_dim, nb_classes)
+        self.new_fc = self.generate_fc(self.out_dim, new_task_size)
+        self.task_sizes.append(new_task_size)
+
+    def generate_fc(self, in_dim, out_dim):
+        fc = SimpleLinear(in_dim, out_dim)
+        return fc
+
+    def update_fc_after(self):
+        if self.old_fc is not None:
+            old_fc = self.generate_fc(self.feature_dim, sum(self.task_sizes))
+            new_task_size = self.task_sizes[-1]
+            old_fc.weight.data = torch.cat(
+                [self.old_fc.weight.data, torch.zeros((new_task_size, self.feature_dim - self.out_dim)).cuda()],
+                dim=0,
+            )
+            new_fc_weight = self.new_fc.weight.data
+            new_fc_bias = self.new_fc.bias.data
+            for i in range(len(self.task_sizes) - 2, -1, -1):
+                new_fc_weight = torch.cat(
+                    [*[self.biases[i](self.backward_prototypes.weight.data[i].unsqueeze(0), bias=False)
+                       for _ in range(self.task_sizes[i])], new_fc_weight],
+                    dim=0,
+                )
+                new_fc_bias = torch.cat(
+                    [*[self.biases[i](self.backward_prototypes.bias.data[i].unsqueeze(0), bias=True)
+                       for _ in range(self.task_sizes[i])], new_fc_bias]
+                )
+            old_fc.weight.data = torch.cat([old_fc.weight.data, new_fc_weight], dim=1)
+            old_fc.bias.data = torch.cat([self.old_fc.bias.data, torch.zeros(new_task_size).cuda()])
+            old_fc.bias.data += new_fc_bias
+            self.old_fc = old_fc
+        else:
+            self.old_fc = self.new_fc
+
+    def copy(self):
+        return copy.deepcopy(self)
+
+    def copy_fc(self, fc):
+        # NOTE: like weight_align below, this helper references self.fc, which
+        # BEEFISONet never defines (it keeps old_fc/new_fc instead); both
+        # methods appear to be carried over from the other nets in this file
+        # and would raise AttributeError if called.
+        weight = copy.deepcopy(fc.weight.data)
+        bias = copy.deepcopy(fc.bias.data)
+        n, m = weight.shape[0], weight.shape[1]
+        self.fc.weight.data[:n, :m] = weight
+        self.fc.bias.data[:n] = bias
+
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self.eval()
+        return self
+
+    def freeze_conv(self):
+        for param in self.convnets.parameters():
+            param.requires_grad = False
+        self.convnets.eval()
+
+    def weight_align(self, old, increment, value):
+        weights = self.fc.weight.data
+        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
+        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
+        meannew = torch.mean(newnorm)
+        meanold = torch.mean(oldnorm)
+        gamma = meanold / meannew * (value ** (old / increment))
+        logging.info("align weights, gamma = {}".format(gamma))
+        self.fc.weight.data[-increment:, :] *= gamma
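+
+
+# Illustrative only (not in the original file): judging from the methods above,
+# the expected per-task cycle for BEEFISONet is
+#
+#     net.update_fc_before(nb_classes)   # add a backbone branch + fresh heads
+#     # ... train the new branch and new_fc ...
+#     net.update_fc_after()              # fuse old_fc and new_fc into one head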
+
+
+class AdaptiveNet(nn.Module):
+    def __init__(self, args, pretrained):
+        super(AdaptiveNet, self).__init__()
+        self.convnet_type = args["convnet_type"]
+        self.TaskAgnosticExtractor, _network = get_convnet(args, pretrained)  # Generalized blocks
+        self.TaskAgnosticExtractor.train()
+        self.AdaptiveExtractors = nn.ModuleList()  # Specialized blocks
+        self.AdaptiveExtractors.append(_network)
+        self.pretrained = pretrained
+        if args["backbone"] is not None and pretrained:
+            self.load_checkpoint(args)
+        self.out_dim = None
+        self.fc = None
+        self.aux_fc = None
+        self.task_sizes = []
+        self.args = args
+
+    @property
+    def feature_dim(self):
+        if self.out_dim is None:
+            return 0
+        return self.out_dim * len(self.AdaptiveExtractors)
+
+    def extract_vector(self, x):
+        base_feature_map = self.TaskAgnosticExtractor(x)
+        features = [extractor(base_feature_map) for extractor in self.AdaptiveExtractors]
+        features = torch.cat(features, 1)
+        return features
+
+    def forward(self, x):
+        base_feature_map = self.TaskAgnosticExtractor(x)
+        features = [extractor(base_feature_map) for extractor in self.AdaptiveExtractors]
+        features = torch.cat(features, 1)
+        out = self.fc(features)  # {logits: self.fc(features)}
+
+        aux_logits = self.aux_fc(features[:, -self.out_dim:])["logits"]
+
+        out.update({"aux_logits": aux_logits, "features": features})
+        out.update({"base_features": base_feature_map})
+        # out: {'features': features, 'logits': logits, 'aux_logits': aux_logits,
+        #       'base_features': base_feature_map}
+        return out
+
+    def update_fc(self, nb_classes):
+        _, _new_extractor = get_convnet(self.args)
+        self.AdaptiveExtractors.append(_new_extractor)
+        if len(self.AdaptiveExtractors) > 1:
+            # Initialize the new specialized block from the previous one.
+            self.AdaptiveExtractors[-1].load_state_dict(self.AdaptiveExtractors[-2].state_dict())
+
+        if self.out_dim is None:
+            logging.info(self.AdaptiveExtractors[-1])
+            self.out_dim = self.AdaptiveExtractors[-1].feature_dim
+        fc = self.generate_fc(self.feature_dim, nb_classes)
+        if self.fc is not None:
+            nb_output = self.fc.out_features
+            weight = copy.deepcopy(self.fc.weight.data)
+            bias = copy.deepcopy(self.fc.bias.data)
+            fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
+            fc.bias.data[:nb_output] = bias
+
+        del self.fc
+        self.fc = fc
+
+        new_task_size = nb_classes - sum(self.task_sizes)
+        self.task_sizes.append(new_task_size)
+        self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)
+
+    def generate_fc(self, in_dim, out_dim):
+        fc = SimpleLinear(in_dim, out_dim)
+        return fc
+
+    def copy(self):
+        return copy.deepcopy(self)
+
+    def weight_align(self, increment):
+        weights = self.fc.weight.data
+        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
+        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
+        meannew = torch.mean(newnorm)
+        meanold = torch.mean(oldnorm)
+        gamma = meanold / meannew
+        logging.info("align weights, gamma = {}".format(gamma))
+        self.fc.weight.data[-increment:, :] *= gamma
+
+    def load_checkpoint(self, args):
+        checkpoint_name = args["backbone"]
+        model_infos = torch.load(checkpoint_name)
+        model_dict = model_infos['convnet']
+        assert len(self.AdaptiveExtractors) == 1
+
+        base_state_dict = self.TaskAgnosticExtractor.state_dict()
+        adap_state_dict = self.AdaptiveExtractors[0].state_dict()
+
+        pretrained_base_dict = {
+            k: v
+            for k, v in model_dict.items()
+            if k in base_state_dict
+        }
+
+        pretrained_adap_dict = {
+            k: v
+            for k, v in model_dict.items()
+            if k in adap_state_dict
+        }
+
+        base_state_dict.update(pretrained_base_dict)
+        adap_state_dict.update(pretrained_adap_dict)
+
+        self.TaskAgnosticExtractor.load_state_dict(base_state_dict)
+        self.AdaptiveExtractors[0].load_state_dict(adap_state_dict)
+        # self.fc.load_state_dict(model_infos['fc'])
+        test_acc = model_infos['test_acc']
+        return test_acc
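+
+
+# Illustrative only (not in the original file): the load_checkpoint helpers in
+# this module all read the same keys, so a compatible checkpoint is assumed to
+# be produced roughly as
+#
+#     torch.save({"convnet": convnet.state_dict(),
+#                 "fc": fc.state_dict(),
+#                 "test_acc": test_acc},
+#                f"checkpoints/finetune_{csv_name}_0.pkl")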
diff --git a/utils/ops.py b/utils/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..66dcb775caed0ebd9105ac530acd3e70e138d72e
--- /dev/null
+++ b/utils/ops.py
@@ -0,0 +1,121 @@
+from PIL import Image, ImageEnhance, ImageOps
+import random
+import torch
+import numpy as np
+
+
+class Cutout(object):
+    def __init__(self, n_holes, length):
+        self.n_holes = n_holes
+        self.length = length
+
+    def __call__(self, img):
+        # img is expected to be a CHW tensor (apply after ToTensor).
+        h = img.size(1)
+        w = img.size(2)
+
+        mask = np.ones((h, w), np.float32)
+
+        for n in range(self.n_holes):
+            y = np.random.randint(h)
+            x = np.random.randint(w)
+
+            y1 = np.clip(y - self.length // 2, 0, h)
+            y2 = np.clip(y + self.length // 2, 0, h)
+            x1 = np.clip(x - self.length // 2, 0, w)
+            x2 = np.clip(x + self.length // 2, 0, w)
+
+            mask[y1:y2, x1:x2] = 0.
+
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img = img * mask
+
+        return img
+
+
+class ShearX(object):
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.fillcolor = fillcolor
+
+    def __call__(self, x, magnitude):
+        return x.transform(
+            x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
+            Image.BICUBIC, fillcolor=self.fillcolor)
+
+
+class ShearY(object):
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.fillcolor = fillcolor
+
+    def __call__(self, x, magnitude):
+        return x.transform(
+            x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
+            Image.BICUBIC, fillcolor=self.fillcolor)
+
+
+class TranslateX(object):
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.fillcolor = fillcolor
+
+    def __call__(self, x, magnitude):
+        return x.transform(
+            x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0),
+            fillcolor=self.fillcolor)
+
+
+class TranslateY(object):
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.fillcolor = fillcolor
+
+    def __call__(self, x, magnitude):
+        return x.transform(
+            x.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * x.size[1] * random.choice([-1, 1])),
+            fillcolor=self.fillcolor)
+
+
+class Rotate(object):
+    def __call__(self, x, magnitude):
+        rot = x.convert("RGBA").rotate(magnitude * random.choice([-1, 1]))
+        return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode)
+
+
+class Color(object):
+    def __call__(self, x, magnitude):
+        return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Posterize(object):
+    def __call__(self, x, magnitude):
+        return ImageOps.posterize(x, magnitude)
+
+
+class Solarize(object):
+    def __call__(self, x, magnitude):
+        return ImageOps.solarize(x, magnitude)
+
+
+class Contrast(object):
+    def __call__(self, x, magnitude):
+        return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Sharpness(object):
+    def __call__(self, x, magnitude):
+        return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Brightness(object):
+    def __call__(self, x, magnitude):
+        return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class AutoContrast(object):
+    def __call__(self, x, magnitude):
+        return ImageOps.autocontrast(x)
+
+
+class Equalize(object):
+    def __call__(self, x, magnitude):
+        return ImageOps.equalize(x)
+
+
+class Invert(object):
+    def __call__(self, x, magnitude):
+        return ImageOps.invert(x)
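+
+
+# Illustrative only (not in the original file): apart from Cutout, every op
+# above follows the AutoAugment-style convention __call__(pil_img, magnitude):
+#
+#     img = ShearX()(img, magnitude=0.1)       # shear by a random sign * 0.1
+#     img = Posterize()(img, magnitude=4)      # keep 4 bits per channel
+#     img = AutoContrast()(img, magnitude=0)   # magnitude is accepted but unused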
diff --git a/utils/rl_utils/ddpg.py b/utils/rl_utils/ddpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..555e46645d91bc08e7e90e6b3d262248a8d2900f
--- /dev/null
+++ b/utils/rl_utils/ddpg.py
@@ -0,0 +1,206 @@
+import logging
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class PolicyNet(torch.nn.Module):
+    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
+        super(PolicyNet, self).__init__()
+        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
+        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
+        self.action_bound = action_bound
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        return torch.tanh(self.fc2(x)) * self.action_bound
+
+
+class RMMPolicyNet(torch.nn.Module):
+    def __init__(self, state_dim, hidden_dim, action_dim):
+        super(RMMPolicyNet, self).__init__()
+        self.fc1 = nn.Sequential(
+            nn.Linear(state_dim, hidden_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_dim, action_dim),
+        )
+        self.fc2 = nn.Sequential(
+            nn.Linear(state_dim + action_dim, hidden_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_dim, action_dim),
+        )
+
+    def forward(self, x):
+        # First head predicts a1 in [0, 1]; the second head conditions on a1
+        # and predicts a2 in [-1, 1].
+        a1 = torch.sigmoid(self.fc1(x))
+        x = torch.cat([x, a1], dim=1)
+        a2 = torch.tanh(self.fc2(x))
+        return torch.cat([a1, a2], dim=1)
+
+
+class QValueNet(torch.nn.Module):
+    def __init__(self, state_dim, hidden_dim, action_dim):
+        super(QValueNet, self).__init__()
+        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
+        self.fc2 = torch.nn.Linear(hidden_dim, 1)
+
+    def forward(self, x, a):
+        cat = torch.cat([x, a], dim=1)
+        x = F.relu(self.fc1(cat))
+        return self.fc2(x)
+
+
+class TwoLayerFC(torch.nn.Module):
+    def __init__(
+        self, num_in, num_out, hidden_dim, activation=F.relu, out_fn=lambda x: x
+    ):
+        super().__init__()
+        self.fc1 = nn.Linear(num_in, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+        self.fc3 = nn.Linear(hidden_dim, num_out)
+
+        self.activation = activation
+        self.out_fn = out_fn
+
+    def forward(self, x):
+        x = self.activation(self.fc1(x))
+        x = self.activation(self.fc2(x))
+        x = self.out_fn(self.fc3(x))
+        return x
+
+
+class DDPG:
+    """DDPG algorithm (optionally with the RMM-specific policy head)."""
+
+    def __init__(
+        self,
+        num_in_actor,
+        num_out_actor,
+        num_in_critic,
+        hidden_dim,
+        discrete,
+        action_bound,
+        sigma,
+        actor_lr,
+        critic_lr,
+        tau,
+        gamma,
+        device,
+        use_rmm=True,
+    ):
+        out_fn = (lambda x: x) if discrete else (lambda x: torch.tanh(x) * action_bound)
+
+        if use_rmm:
+            self.actor = RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(device)
+            self.target_actor = RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(device)
+        else:
+            self.actor = TwoLayerFC(
+                num_in_actor, num_out_actor, hidden_dim, activation=F.relu, out_fn=out_fn
+            ).to(device)
+            self.target_actor = TwoLayerFC(
+                num_in_actor, num_out_actor, hidden_dim, activation=F.relu, out_fn=out_fn
+            ).to(device)
+
+        self.critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
+        self.target_critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
+        self.target_critic.load_state_dict(self.critic.state_dict())
+        self.target_actor.load_state_dict(self.actor.state_dict())
+        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
+        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
+        self.gamma = gamma
+        self.sigma = sigma
+        self.action_bound = action_bound
+        self.tau = tau
+        self.action_dim = num_out_actor
+        self.device = device
+
+    def take_action(self, state):
+        state = torch.tensor(np.expand_dims(state, 0), dtype=torch.float).to(self.device)
+        action = self.actor(state)[0].detach().cpu().numpy()
+
+        # Gaussian exploration noise; the clipping below matches the RMM
+        # action layout (a1 in [0, 1], a2 in [-1, 1]).
+        action = action + self.sigma * np.random.randn(self.action_dim)
+        action[0] = np.clip(action[0], 0, 1)
+        action[1] = np.clip(action[1], -1, 1)
+        return action
+
+    def save_state_dict(self, name):
+        dicts = {
+            "critic": self.critic.state_dict(),
+            "target_critic": self.target_critic.state_dict(),
+            "actor": self.actor.state_dict(),
+            "target_actor": self.target_actor.state_dict(),
+        }
+        torch.save(dicts, name)
+
+    def load_state_dict(self, name):
+        dicts = torch.load(name)
+        self.critic.load_state_dict(dicts["critic"])
+        self.target_critic.load_state_dict(dicts["target_critic"])
+        self.actor.load_state_dict(dicts["actor"])
+        self.target_actor.load_state_dict(dicts["target_actor"])
+
+    def soft_update(self, net, target_net):
+        for param_target, param in zip(target_net.parameters(), net.parameters()):
+            param_target.data.copy_(
+                param_target.data * (1.0 - self.tau) + param.data * self.tau
+            )
+
+    def update(self, transition_dict):
+        states = torch.tensor(transition_dict["states"], dtype=torch.float).to(
+            self.device
+        )
+        actions = torch.tensor(transition_dict["actions"], dtype=torch.float).to(
+            self.device
+        )
+        rewards = (
+            torch.tensor(transition_dict["rewards"], dtype=torch.float)
+            .view(-1, 1)
+            .to(self.device)
+        )
+        next_states = torch.tensor(
+            transition_dict["next_states"], dtype=torch.float
+        ).to(self.device)
+        dones = (
+            torch.tensor(transition_dict["dones"], dtype=torch.float)
+            .view(-1, 1)
+            .to(self.device)
+        )
+
+        # TD target: r + gamma * Q'(s', mu'(s')) for non-terminal transitions.
+        next_q_values = self.target_critic(
+            torch.cat([next_states, self.target_actor(next_states)], dim=1)
+        )
+        q_targets = rewards + self.gamma * next_q_values * (1 - dones)
+        critic_loss = torch.mean(
+            F.mse_loss(
+                self.critic(torch.cat([states, actions], dim=1)),
+                q_targets,
+            )
+        )
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward()
+        self.critic_optimizer.step()
+
+        # Deterministic policy gradient: maximize Q(s, mu(s)).
+        actor_loss = -torch.mean(
+            self.critic(torch.cat([states, self.actor(states)], dim=1))
+        )
+        self.actor_optimizer.zero_grad()
+        actor_loss.backward()
+        self.actor_optimizer.step()
+        logging.info(
+            f"update DDPG: actor loss {actor_loss.item():.3f}, critic loss {critic_loss.item():.3f}"
+        )
+        self.soft_update(self.actor, self.target_actor)  # soft-update the target policy net
+        self.soft_update(self.critic, self.target_critic)  # soft-update the target Q-value net
diff --git a/utils/rl_utils/rl_utils.py b/utils/rl_utils/rl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..33799bbd129890e2a8b14c48ff5086bfd51f7e97
--- /dev/null
+++ b/utils/rl_utils/rl_utils.py
@@ -0,0 +1,20 @@
+from tqdm import tqdm
+import numpy as np
+import torch
+import collections
+import random
+
+
+class ReplayBuffer:
+    def __init__(self, capacity):
+        self.buffer = collections.deque(maxlen=capacity)
+
+    def add(self, state, action, reward, next_state, done):
+        self.buffer.append((state, action, reward, next_state, done))
+
+    def sample(self, batch_size):
+        transitions = random.sample(self.buffer, batch_size)
+        state, action, reward, next_state, done = zip(*transitions)
+        return np.array(state), np.array(action), reward, np.array(next_state), done
+
+    def size(self):
+        return len(self.buffer)
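+
+
+# Illustrative only (not in the original file): how ReplayBuffer is meant to
+# feed DDPG.update(), based on the keys update() reads (env and agent are
+# assumed to exist):
+#
+#     buffer = ReplayBuffer(capacity=10000)
+#     buffer.add(state, action, reward, next_state, done)
+#     if buffer.size() >= batch_size:
+#         s, a, r, ns, d = buffer.sample(batch_size)
+#         agent.update({"states": s, "actions": a, "rewards": r,
+#                       "next_states": ns, "dones": d})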
diff --git a/utils/toolkit.py b/utils/toolkit.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9792e5180b445bfbf1ce1cf923f213d3d852b4b
--- /dev/null
+++ b/utils/toolkit.py
@@ -0,0 +1,116 @@
+import os
+import numpy as np
+import torch
+import json
+from enum import Enum
+
+
+class ConfigEncoder(json.JSONEncoder):
+    def default(self, o):
+        if isinstance(o, type):
+            return {'$class': o.__module__ + "." + o.__name__}
+        elif isinstance(o, Enum):
+            return {
+                '$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name
+            }
+        elif callable(o):
+            return {
+                '$function': o.__module__ + "." + o.__name__
+            }
+        return json.JSONEncoder.default(self, o)
+
+
+def count_parameters(model, trainable=False):
+    if trainable:
+        return sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return sum(p.numel() for p in model.parameters())
+
+
+def tensor2numpy(x):
+    return x.cpu().data.numpy() if x.is_cuda else x.data.numpy()
+
+
+def target2onehot(targets, n_classes):
+    onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)
+    onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.0)
+    return onehot
+
+
+def makedirs(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def accuracy(y_pred, y_true, nb_old, increment=10):
+    assert len(y_pred) == len(y_true), "Data length error."
+    all_acc = {}
+    all_acc["total"] = np.around(
+        (y_pred == y_true).sum() * 100 / len(y_true), decimals=2
+    )
+
+    # Grouped accuracy
+    for class_id in range(0, np.max(y_true), increment):
+        idxes = np.where(
+            np.logical_and(y_true >= class_id, y_true < class_id + increment)
+        )[0]
+        if increment == 1:
+            label = "{}".format(str(class_id).rjust(2, "0"))
+        else:
+            label = "{}-{}".format(
+                str(class_id).rjust(2, "0"), str(class_id + increment - 1).rjust(2, "0")
+            )
+        all_acc[label] = np.around(
+            (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
+        )
+
+    # Old accuracy
+    idxes = np.where(y_true < nb_old)[0]
+    all_acc["old"] = (
+        0
+        if len(idxes) == 0
+        else np.around(
+            (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
+        )
+    )
+
+    # New accuracy
+    idxes = np.where(y_true >= nb_old)[0]
+    all_acc["new"] = np.around(
+        (y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2
+    )
+
+    return all_acc
+
+
+def split_images_labels(imgs, start_index=0):
+    # Split trainset.imgs in ImageFolder
+    images = []
+    labels = []
+    for item in imgs:
+        images.append(item[0])
+        labels.append(item[1] + start_index)
+    return np.array(images), np.array(labels)
+
+
+def save_fc(args, model):
+    _path = os.path.join(args['logfilename'], "fc.pt")
+    if len(args['device']) > 1:
+        fc_weight = model._network.fc.weight.data
+    else:
+        fc_weight = model._network.fc.weight.data.cpu()
+    torch.save(fc_weight, _path)
+
+    _save_dir = os.path.join(f"./results/fc_weights/{args['prefix']}")
+    os.makedirs(_save_dir, exist_ok=True)
+    _save_path = os.path.join(_save_dir, f"{args['csv_name']}.csv")
+    with open(_save_path, "a+") as f:
+        f.write(f"{args['time_str']},{args['model_name']},{_path} \n")
+
+
+def save_model(args, model):
+    # Used in PODNet
+    _path = os.path.join(args['logfilename'], "model.pt")
+    if len(args['device']) > 1:
+        weight = model._network
+    else:
+        weight = model._network.cpu()
+    torch.save(weight, _path)
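+
+
+# Illustrative only (not in the original file): a quick sanity check for
+# accuracy() with hypothetical data.
+#
+#     y_true = np.array([0, 1, 10, 11])
+#     y_pred = np.array([0, 1, 10, 12])
+#     accuracy(y_pred, y_true, nb_old=10, increment=10)
+#     # -> {'total': 75.0, '00-09': 100.0, '10-19': 50.0, 'old': 100.0, 'new': 50.0}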