Upload 9 files

Files changed:

- .gitignore +140 -0
- LICENSE +21 -0
- README.md +166 -12
- app.py +146 -0
- requirements.txt +2 -1
- save_roc.py +153 -0
- setup.py +57 -0
- stoma_clip.pt +3 -0
- vis_all_model_roc.py +71 -0
.gitignore
ADDED
@@ -0,0 +1,140 @@

```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


data/
microsoft/
ckpt/
# *.jpg
*.png
logs/
*.json
evaluation*/
*.ipynb
```
LICENSE
ADDED
@@ -0,0 +1,21 @@

```
MIT License

Copyright (c) 2023 Weixiong Lin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
README.md
CHANGED
@@ -1,12 +1,166 @@

````markdown
# PMC-CLIP

[](https://colab.research.google.com/drive/1P7uyzK_Mhu1YyMeRrrRY_e3NpkNBOI4L?usp=sharing)
[](https://huggingface.co/datasets/axiong/pmc_oa)

The dataset and checkpoint are available on [Huggingface](https://huggingface.co/datasets/axiong/pmc-oa) and [Baidu Cloud](https://pan.baidu.com/s/1mD51oOYbIOqDJSeiPNaCCg) (key: 3iqf).

📢 We provide the extracted image encoder and text encoder checkpoints on [Huggingface](https://huggingface.co/datasets/axiong/pmc-oa), along with a quick-start demo on how to use them to encode image and text input. Check this [notebook](https://colab.research.google.com/drive/1P7uyzK_Mhu1YyMeRrrRY_e3NpkNBOI4L?usp=sharing)!

- [PMC-CLIP](#pmc-clip)
  - [Quick Start Inference](#quick-start-inference)
  - [Train and Evaluation](#train-and-evaluation)
    - [1. Create Environment](#1-create-environment)
    - [2. Prepare Dataset](#2-prepare-dataset)
    - [3. Training](#3-training)
    - [4. Evaluation](#4-evaluation)
  - [Acknowledgement](#acknowledgement)
  - [Contribution](#contribution)
  - [TODO](#todo)
  - [Cite](#cite)

## Quick Start Inference

We offer a quick-start demo on how to use the image and text encoders of PMC-CLIP. Check this [notebook](https://colab.research.google.com/drive/1P7uyzK_Mhu1YyMeRrrRY_e3NpkNBOI4L?usp=sharing)!


## Train and Evaluation

Repo structure:
```bash
src/:
|--setup.py
|--pmc_clip/
|  |--loss/
|  |--model/: PMC-CLIP model and variants
|  |--model_configs/
|  |--factory.py: Create model according to configs
|  |--transform.py: data augmentation
|--training/
|  |--main.py
|  |--scheduler.py: Learning rate scheduler
|  |--train.py
|  |--evaluate.py
|  |--data.py
|  |--params.py
docs/: project pages
```

### 1. Create Environment

```bash
conda create -n pmc_clip python=3.8
conda activate pmc_clip

pip install -r requirements.txt
# pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt

python setup.py develop  # install pmc_clip in development mode
```

### 2. Prepare Dataset

Download from [Huggingface](https://huggingface.co/datasets/axiong/pmc-oa) or [Baidu Cloud](https://pan.baidu.com/s/1mD51oOYbIOqDJSeiPNaCCg) (key: 3iqf),
or follow the [Pipeline of PMC-OA Development](https://github.com/WeixiongLin/Build-PMC-OA) if you want to start from scratch.


### 3. Training

Single GPU
```bash
python -m training.main \
    --dataset-type "csv" --csv-separator "," --save-frequency 5 \
    --report-to tensorboard \
    --train-data="path/to/train.csv" --val-data="path/to/valid.csv" \
    --csv-img-key image --csv-caption-key caption \
    --warmup 500 --batch-size=8 --lr=1e-4 --wd=0.1 --epochs=100 --workers=8 \
    --model RN50_fusion4 --hugging-face --mlm --crop-scale 0.5
```

Multi GPU
```bash
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --rdzv_endpoint=$HOST_NODE_ADDR -m training.main \
    --dataset-type "csv" --csv-separator "," --save-frequency 5 \
    --report-to tensorboard \
    --train-data="path/to/train.csv" --val-data="path/to/valid.csv" \
    --csv-img-key image --csv-caption-key caption \
    --warmup 500 --batch-size=128 --lr=1e-4 --wd=0.1 --epochs=100 --workers=8 \
    --model RN50_fusion4 --hugging-face --mlm --crop-scale 0.5
```

<div class="third">
  <img src="docs/resources/train_loss.png" style="height:200px">
  <img src="docs/resources/val_i2t@1.png" style="height:200px">
  <img src="docs/resources/val_t2i@1.png" style="height:200px">
</div>


### 4. Evaluation
Load a checkpoint and evaluate on 2,000 samples from the test set.

```bash
python -m training.main \
    --dataset-type "csv" --csv-separator "," --report-to tensorboard \
    --val-data="path/to/test.csv" \
    --csv-img-key image --csv-caption-key caption \
    --batch-size=32 --workers=8 \
    --model RN50_fusion4 --hugging-face --mlm --crop-scale 0.1 \
    --resume /path/to/checkpoint.pt \
    --test-2000
```

We also provide an automatic way to load model weights from a Huggingface repo.

| Model | URL |
| --- | --- |
| PMC_CLIP:beta | https://huggingface.co/datasets/axiong/pmc_oa_beta/blob/main/checkpoint.pt |

Take the PMC_CLIP:beta checkpoint as an example:
```bash
python -m training.main \
    --dataset-type "csv" --csv-separator "," --report-to tensorboard \
    --val-data="path/to/test.csv" \
    --csv-img-key image --csv-caption-key caption \
    --batch-size=32 --workers=8 \
    --model RN50_fusion4 --hugging-face --mlm --crop-scale 0.1 \
    --resume "PMC_CLIP:beta" \
    --test-2000
```


## Acknowledgement
The code is based on [OpenCLIP](https://github.com/mlfoundations/open_clip) and [M3AE](https://github.com/zhjohnchan/M3AE). We thank the authors for their open-sourced code and encourage users to cite their works when applicable.

Note that our code does not support tools such as horovod and wandb from OpenCLIP, but we keep the related code from OpenCLIP for consistency.

## Contribution
Please raise an issue if you need help; any contribution is welcome.

## TODO

* [ ] Compatibility testing on more env settings
* [ ] Support for horovod, wandb


## Cite
```bibtex
@article{lin2023pmc,
  title={PMC-CLIP: Contrastive Language-Image Pre-training using Biomedical Documents},
  author={Lin, Weixiong and Zhao, Ziheng and Zhang, Xiaoman and Wu, Chaoyi and Zhang, Ya and Wang, Yanfeng and Xie, Weidi},
  journal={arXiv preprint arXiv:2303.07240},
  year={2023}
}
```

The paper has been accepted by MICCAI 2023.
```bibtex
@inproceedings{lin2023pmc,
  title={Pmc-clip: Contrastive language-image pre-training using biomedical documents},
  author={Lin, Weixiong and Zhao, Ziheng and Zhang, Xiaoman and Wu, Chaoyi and Zhang, Ya and Wang, Yanfeng and Xie, Weidi},
  booktitle={MICCAI},
  year={2023}
}
```
````
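The training and evaluation commands in the README read image–caption pairs from a comma-separated CSV with `image` and `caption` columns (`--csv-img-key` / `--csv-caption-key`). A minimal sketch of that layout, with invented file names and captions:

```python
# Illustrative only: build a tiny train.csv in the shape the flags above expect
# (--csv-separator "," --csv-img-key image --csv-caption-key caption).
import csv

rows = [
    {"image": "images/PMC123456_fig1.jpg", "caption": "Chest X-ray showing ..."},
    {"image": "images/PMC654321_fig2.jpg", "caption": "Histology slide with ..."},
]

with open("train.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "caption"])
    writer.writeheader()
    writer.writerows(rows)
```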
app.py
ADDED
@@ -0,0 +1,146 @@

```python
import os
import torch
import gradio as gr
from PIL import Image
import numpy as np

# ========== 1. Import project modules ==========
try:
    from stoma_clip import pmc_clip
    from stoma_clip.pmc_clip.factory import _rescan_model_configs
    from stoma_clip.training.fusion_method import convert_model_to_cls
    from stoma_clip.training.dataset.utils import encode_mlm
except ImportError as e:
    print(f"Error importing Stoma-CLIP modules: {e}")

# ========== 2. Model Configuration and Loading ==========
LABEL_MAP = {
    "Irritant dermatitis": 0, "Allergic contact dermatitis": 1, "Mechanical injury": 2,
    "Folliculitis": 3, "Fungal infection": 4, "Skin hyperplasia": 5, "Parastomal varices": 6,
    "Urate crystals": 7, "Cancerous metastasis": 8, "Pyoderma gangrenosum": 9, "Normal": 10
}
REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)


class Args:
    def __init__(self):
        self.model = "RN50_fusion4"
        self.pretrained = "stoma_clip.pt"
        self.num_classes = NUM_CLASSES
        self.mlm = True
        self.crop_scale = 0.9
        self.context_length = 77
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")


args = Args()
MODEL = None
PREPROCESS = None
TOKENIZER = None


def load_model():
    """Load the model once when Gradio starts."""
    global MODEL, PREPROCESS, TOKENIZER
    if MODEL is not None:
        return MODEL, PREPROCESS, TOKENIZER

    try:
        _rescan_model_configs()
        model, _, preprocess = pmc_clip.create_model_and_transforms(args)
        model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='cross_attention')
        model.to(args.device).eval()

        # Strip the DistributedDataParallel "module." prefix from checkpoint keys.
        state_dict = torch.load(args.pretrained, map_location='cpu')
        state_dict_clean = {k.replace("module.", "", 1): v for k, v in state_dict['state_dict'].items()}
        model.load_state_dict(state_dict_clean)
        tokenizer = model.tokenizer

        MODEL = model
        PREPROCESS = preprocess
        TOKENIZER = tokenizer
        print("Stoma-CLIP model loaded successfully!")
        return MODEL, PREPROCESS, TOKENIZER

    except Exception as e:
        print(f"Error during model loading: {e}")
        MODEL = None
        raise RuntimeError(f"Failed to load Stoma-CLIP model: {e}")


# ========== 3. Inference Function ==========
def predict_stoma_clip(image: Image.Image, caption: str):
    if MODEL is None:
        return "Model Loading Failed", {}

    image = image.convert("RGB")
    model, preprocess, tokenizer = MODEL, PREPROCESS, TOKENIZER
    device = args.device

    image_tensor = preprocess(image).unsqueeze(0).to(device)

    mask_token, pad_token = '[MASK]', '[PAD]'
    vocab = [v for v in tokenizer.get_vocab().keys() if v not in tokenizer.all_special_tokens]

    # Encode the caption; ratio=0.0 disables MLM masking at inference time.
    bert_input, bert_label = encode_mlm(
        caption=caption,
        vocab=vocab,
        mask_token=mask_token,
        pad_token=pad_token,
        ratio=0.0,
        tokenizer=tokenizer,
        args=args,
    )

    with torch.no_grad():
        inputs = {"images": image_tensor, "bert_input": bert_input, "bert_label": bert_label}
        outputs = model(inputs)
        probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        predicted_class_idx = torch.argmax(outputs, dim=1).item()

    predicted_class_name = REVERSE_LABEL_MAP.get(predicted_class_idx, "Unknown")
    probability_distribution = {REVERSE_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}

    return predicted_class_name, probability_distribution


# ========== 4. Gradio Interface Setup ==========
try:
    load_model()
    print("Model loaded successfully before Gradio launch.")
except Exception as e:
    print(f"Fatal error: the model failed to load; the Gradio interface will not work. {e}")

image_input = gr.Image(type="pil", label="Upload a stoma image")
caption_input = gr.Textbox(label="Enter a stoma description (e.g.: Exudate, epidermal breakdown, ...)")
predicted_label_output = gr.Textbox(label="Predicted class")
prob_output = gr.Label(label="Class probability distribution")

# Find example paths for the Gradio demo (note: in the deployed Space, these paths should be relative to the root)
try:
    example_path_1 = "demo/Irritant_dermatitis.jpg"
    example_path_2 = "demo/Folliculitis.jpg"

    examples_list = []
    # Add each demo example that exists on disk.
    if os.path.exists(example_path_1):
        examples_list.append(
            [example_path_1, "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas"])
    if os.path.exists(example_path_2):
        examples_list.append([example_path_2, "Erythema, papules, pustules confined to hair follicles"])
except Exception:
    examples_list = []

iface = gr.Interface(
    fn=predict_stoma_clip,
    inputs=[image_input, caption_input],
    outputs=[predicted_label_output, prob_output],
    title="🧪 Stoma-CLIP Classification API Prototype (Gradio)",
    description="Upload a stoma image and enter a clinical description; the model predicts the most likely skin complication class.",
    examples=examples_list,
    allow_flagging="never"
)

if __name__ == "__main__":
    iface.launch()
```
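For a quick smoke test of `app.py` without the web UI, the predictor can be called directly. A hedged sketch, assuming the checkpoint and one of the demo images referenced above are present; importing `app` triggers the module-level `load_model()` call:

```python
# Smoke test: call the predictor from app.py directly, bypassing Gradio.
# Image path and caption are the demo example values used in app.py itself.
from PIL import Image
from app import predict_stoma_clip  # loads the model at import time

img = Image.open("demo/Irritant_dermatitis.jpg")
label, probs = predict_stoma_clip(img, "Exudate, epidermal breakdown, irregular erythema")
print("Predicted:", label)
# Show the three most probable classes.
for name, p in sorted(probs.items(), key=lambda kv: -kv[1])[:3]:
    print(f"  {name}: {p:.3f}")
```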
requirements.txt
CHANGED
```diff
@@ -1,4 +1,4 @@
-
+torch
 torchvision
 transformers
 tqdm
@@ -10,3 +10,4 @@ braceexpand
 webdataset
 jsonlines
 tensorboard
+matplotlib
```
save_roc.py
ADDED
@@ -0,0 +1,153 @@

```python
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix, roc_auc_score
from sklearn.preprocessing import label_binarize
import json
import sys
sys.path.append('.')
import pmc_clip
from training.params import parse_args
from training.data import PmcDataset
from training.fusion_method import convert_model_to_cls


# Label mapping
LABEL_MAP = {
    "Irritant dermatitis": 0,
    "Allergic contact dermatitis": 1,
    "Mechanical injury": 2,
    "Folliculitis": 3,
    "Fungal infection": 4,
    "Skin hyperplasia": 5,
    "Parastomal varices": 6,
    "Urate crystals": 7,
    "Cancerous metastasis": 8,
    "Pyoderma gangrenosum": 9,
    "Normal": 10
}

REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}

def main():
    # Create the output directory
    output_dir = './evaluation_results_pmc_clip_cat'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model configuration
    model_path = "logs/0321-Stoma-clip-train-cls/2025_03_21-23_45_18-model_RN50_fusion4-lr_1e-05-b_256-j_8-p_amp/checkpoints/epoch_150.pt"
    model_name = "RN50_fusion4"
    args = parse_args()
    args.model = model_name
    args.pretrained = model_path
    args.device = device
    args.mlm = True
    args.train_data = "data/single_symptoms_test.jsonl"
    args.image_dir = "./data/cleaned_data"
    args.csv_img_key = "image"
    args.csv_caption_key = "caption"
    args.context_length = 77
    args.num_classes = len(LABEL_MAP)
    args.output_dir = output_dir

    # Create the model and preprocessing transforms
    model, _, preprocess = pmc_clip.create_model_and_transforms(args)
    model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='concat')

    # Load the model weights (strip the DistributedDataParallel "module." prefix)
    state_dict = torch.load(model_path, map_location='cpu', weights_only=False)

    state_dict_real = {}
    for k, v in state_dict['state_dict'].items():
        state_dict_real[k.replace("module.", "", 1)] = v
    print(model.load_state_dict(state_dict_real))
    model.to(device=device)

    # Prepare the dataset
    dataset = PmcDataset(args,
                         input_filename=args.train_data,
                         transforms=preprocess,
                         is_train=False)

    test_loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

    print(f"Number of test samples: {len(dataset)}")

    # Collect predictions
    all_preds = []
    all_probs = []
    all_labels = []

    print("Starting evaluation...")
    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader):
            labels = batch["cls_label"].to(device)

            # Forward pass
            outputs = model(batch)

            # Get predictions
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Convert to numpy arrays
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)

    # Compute the overall AUC (micro-averaged one-vs-rest)
    try:
        y_true_bin = label_binarize(all_labels, classes=range(args.num_classes))
        if args.num_classes == 2:
            overall_fpr, overall_tpr, _ = roc_curve(y_true_bin[:, 1], all_probs[:, 1])
            overall_auc = roc_auc_score(y_true_bin, all_probs[:, 1])
        else:
            overall_fpr, overall_tpr, _ = roc_curve(y_true_bin.ravel(), all_probs.ravel())
            overall_auc = roc_auc_score(y_true_bin, all_probs, multi_class='ovr', average='micro')
    except Exception as e:
        print(f"Error while computing the overall AUC: {e}")
        return

    # Save the overall ROC curve data
    roc_data = {
        "fpr": overall_fpr.tolist(),
        "tpr": overall_tpr.tolist(),
        "auc": overall_auc
    }
    roc_file = os.path.join(output_dir, "overall_roc_data.json")
    with open(roc_file, "w") as f:
        json.dump(roc_data, f)
    print(f"Overall ROC curve data saved to: {roc_file}")

    # Plot the ROC curve
    plt.figure(figsize=(8, 6))
    plt.plot(overall_fpr, overall_tpr, label=f"Overall (AUC = {overall_auc:.4f})")
    plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
    plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
    plt.title('Overall ROC Curve', fontsize=14)
    plt.legend(loc="lower right", fontsize=10)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'overall_roc_curve.png'), dpi=300, bbox_inches='tight')
    print(f"Overall ROC curve plot saved to: {os.path.join(output_dir, 'overall_roc_curve.png')}")


if __name__ == '__main__':
    main()
```
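The overall AUC in `save_roc.py` is the micro-averaged one-vs-rest variant: the binarized label matrix and the probability matrix are flattened with `.ravel()`, so every (sample, class) entry counts as one binary decision. A self-contained illustration on random data, mirroring the same scikit-learn calls without any project code:

```python
# Micro-averaged one-vs-rest ROC/AUC on random scores; expect AUC near 0.5.
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, roc_auc_score

rng = np.random.default_rng(0)
num_classes = 11
labels = rng.integers(0, num_classes, size=200)                      # ground-truth class ids
logits = rng.normal(size=(200, num_classes))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)   # softmax

y_true_bin = label_binarize(labels, classes=range(num_classes))
# Flatten the (sample, class) grid, exactly as save_roc.py does with .ravel().
fpr, tpr, _ = roc_curve(y_true_bin.ravel(), probs.ravel())
auc_micro = roc_auc_score(y_true_bin, probs, multi_class="ovr", average="micro")
print(f"micro-average AUC on random scores: {auc_micro:.3f}")
```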
setup.py
ADDED
@@ -0,0 +1,57 @@

```python
""" Setup
"""
from setuptools import setup, find_packages
from codecs import open
from os import path

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

exec(open('src/pmc_clip/version.py').read())
setup(
    name='pmc_clip',
    version=__version__,
    description='PMC-CLIP: Contrastive Language-Image Pre-training using Biomedical Documents',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/WeixiongLin/PMC-CLIP/',
    author='weixiong',
    author_email='wx_lin@sjtu.edu.cn',
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 4 - Beta',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],

    # Note that this is a string of words separated by whitespace, not a list.
    keywords='PMC-CLIP',
    package_dir={'': 'src'},
    packages=find_packages(where='src', exclude=['training']),
    include_package_data=True,
    install_requires=[
        'torch >= 1.9',
        'torchvision',
        'transformers <= 4.21.0',
        'ftfy',
        'regex',
        'tqdm',
    ],
    python_requires='>=3.7',
)
```
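`setup.py` reads the package version by exec-ing `src/pmc_clip/version.py` and installs the package from `src/`. A small hedged sanity check after `python setup.py develop` (whether the package re-exports `__version__` at the top level is an assumption, hence the `getattr`):

```python
# Post-install sanity check; prints a fallback string if __version__
# is not re-exported by the package itself.
import pmc_clip

print(getattr(pmc_clip, "__version__", "no __version__ attribute exported"))
```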
stoma_clip.pt
ADDED
@@ -0,0 +1,3 @@

```
version https://git-lfs.github.com/spec/v1
oid sha256:9b1b2a21c2d8d66669f70cfc9380143d758c7937a45c2c8d06747860013c51ed
size 832509306
```
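`stoma_clip.pt` is tracked with Git LFS, so a plain clone contains only this three-line pointer rather than the 832 MB checkpoint, and `torch.load` on the pointer fails confusingly. A small sketch that detects the pointer and fails loudly instead:

```python
# Detect a Git LFS pointer file and abort with a clear message.
def read_lfs_pointer(path):
    with open(path, "rb") as f:
        head = f.read(512)  # pointer files are tiny; 512 bytes covers them
    if not head.startswith(b"version https://git-lfs.github.com/spec/v1"):
        return None  # looks like real binary content, not a pointer
    # Each pointer line is "<key> <value>", e.g. "size 832509306".
    return dict(line.split(" ", 1) for line in head.decode().strip().splitlines())

ptr = read_lfs_pointer("stoma_clip.pt")
if ptr is not None:
    raise SystemExit(f"Checkpoint not fetched (pointer to {ptr['size']} bytes); run `git lfs pull`.")
```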
vis_all_model_roc.py
ADDED
@@ -0,0 +1,71 @@

```python
import os
import json
import matplotlib.pyplot as plt

def load_roc_data(roc_files):
    """
    Load the ROC data of multiple models.

    Args:
        roc_files (list): Paths to each model's overall_roc_data.json file.

    Returns:
        roc_data_list (list): One ROC data dict per model.
    """
    roc_data_list = []
    for roc_file in roc_files:
        with open(roc_file, "r") as f:
            roc_data = json.load(f)
            roc_data_list.append(roc_data)
    return roc_data_list

def plot_combined_roc(roc_data_list, model_names, output_path):
    """
    Plot the ROC curves of multiple models in a single figure.

    Args:
        roc_data_list (list): One ROC data dict per model.
        model_names (list): Name of each model.
        output_path (str): Path where the combined ROC figure is saved.
    """
    plt.figure(figsize=(10, 8))

    for roc_data, model_name in zip(roc_data_list, model_names):
        fpr = roc_data["fpr"]
        tpr = roc_data["tpr"]
        auc_value = roc_data["auc"]
        plt.plot(fpr, tpr, label=f"{model_name} (AUC = {auc_value:.4f})")

    # Diagonal reference line
    plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")

    # Figure settings
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
    plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
    plt.title('Combined ROC Curves for Multiple Models', fontsize=14)
    plt.legend(loc="lower right", fontsize=10)
    plt.grid(alpha=0.3)
    plt.tight_layout()

    # Save the figure
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Combined ROC curve figure saved to: {output_path}")

def main():
    # Directory holding each model's ROC data
    roc_data_dir = "./roc_result"  # replace with the actual path
    output_path = "./combined_roc_curve.png"  # path for the final combined ROC figure

    # Collect each model's overall_roc_data.json path (JSON files only)
    roc_files = [os.path.join(roc_data_dir, f) for f in os.listdir(roc_data_dir) if f.endswith(".json")]

    # Model names (derived from the file names)
    model_names = [os.path.basename(f).replace(".json", "") for f in roc_files]

    # Load the ROC data of all models
    roc_data_list = load_roc_data(roc_files)

    # Plot and save the combined ROC curves
    plot_combined_roc(roc_data_list, model_names, output_path)

if __name__ == "__main__":
    main()
```
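A tiny end-to-end check for the plotting script: write one ROC file in the exact schema `save_roc.py` produces (`{"fpr": [...], "tpr": [...], "auc": ...}`) and render it. The curve values here are made up, chosen so the trapezoidal AUC really is 0.8:

```python
# Write a synthetic ROC file and render it with the functions above.
import json
import os
from vis_all_model_roc import load_roc_data, plot_combined_roc

os.makedirs("roc_result", exist_ok=True)
with open("roc_result/demo_model.json", "w") as f:
    json.dump({"fpr": [0.0, 0.2, 1.0], "tpr": [0.0, 0.8, 1.0], "auc": 0.8}, f)

roc_list = load_roc_data(["roc_result/demo_model.json"])
plot_combined_roc(roc_list, ["demo_model"], "combined_roc_curve.png")
```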