Xiaomeng1130 committed on
Commit 8274db5 · verified · 1 Parent(s): 6ecc5ed

Upload 9 files

Files changed (9):
  1. .gitignore +140 -0
  2. LICENSE +21 -0
  3. README.md +166 -12
  4. app.py +146 -0
  5. requirements.txt +2 -1
  6. save_roc.py +153 -0
  7. setup.py +57 -0
  8. stoma_clip.pt +3 -0
  9. vis_all_model_roc.py +71 -0
.gitignore ADDED
@@ -0,0 +1,140 @@
 + # Byte-compiled / optimized / DLL files
 + __pycache__/
 + *.py[cod]
 + *$py.class
 +
 + # C extensions
 + *.so
 +
 + # Distribution / packaging
 + .Python
 + build/
 + develop-eggs/
 + dist/
 + downloads/
 + eggs/
 + .eggs/
 + lib/
 + lib64/
 + parts/
 + sdist/
 + var/
 + wheels/
 + pip-wheel-metadata/
 + share/python-wheels/
 + *.egg-info/
 + .installed.cfg
 + *.egg
 + MANIFEST
 +
 + # PyInstaller
 + # Usually these files are written by a python script from a template
 + # before PyInstaller builds the exe, so as to inject date/other infos into it.
 + *.manifest
 + *.spec
 +
 + # Installer logs
 + pip-log.txt
 + pip-delete-this-directory.txt
 +
 + # Unit test / coverage reports
 + htmlcov/
 + .tox/
 + .nox/
 + .coverage
 + .coverage.*
 + .cache
 + nosetests.xml
 + coverage.xml
 + *.cover
 + *.py,cover
 + .hypothesis/
 + .pytest_cache/
 +
 + # Translations
 + *.mo
 + *.pot
 +
 + # Django stuff:
 + *.log
 + local_settings.py
 + db.sqlite3
 + db.sqlite3-journal
 +
 + # Flask stuff:
 + instance/
 + .webassets-cache
 +
 + # Scrapy stuff:
 + .scrapy
 +
 + # Sphinx documentation
 + docs/_build/
 +
 + # PyBuilder
 + target/
 +
 + # Jupyter Notebook
 + .ipynb_checkpoints
 +
 + # IPython
 + profile_default/
 + ipython_config.py
 +
 + # pyenv
 + .python-version
 +
 + # pipenv
 + # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 + # However, in case of collaboration, if having platform-specific dependencies or dependencies
 + # having no cross-platform support, pipenv may install dependencies that don't work, or not
 + # install all needed dependencies.
 + #Pipfile.lock
 +
 + # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 + __pypackages__/
 +
 + # Celery stuff
 + celerybeat-schedule
 + celerybeat.pid
 +
 + # SageMath parsed files
 + *.sage.py
 +
 + # Environments
 + .env
 + .venv
 + env/
 + venv/
 + ENV/
 + env.bak/
 + venv.bak/
 +
 + # Spyder project settings
 + .spyderproject
 + .spyproject
 +
 + # Rope project settings
 + .ropeproject
 +
 + # mkdocs documentation
 + /site
 +
 + # mypy
 + .mypy_cache/
 + .dmypy.json
 + dmypy.json
 +
 + # Pyre type checker
 + .pyre/
 +
 +
 + data/
 + microsoft/
 + ckpt/
 + # *.jpg
 + *.png
 + logs/
 + *.json
 + evaluation*/
 + *.ipynb
LICENSE ADDED
@@ -0,0 +1,21 @@
 + MIT License
 +
 + Copyright (c) 2023 Weixiong Lin
 +
 + Permission is hereby granted, free of charge, to any person obtaining a copy
 + of this software and associated documentation files (the "Software"), to deal
 + in the Software without restriction, including without limitation the rights
 + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 + copies of the Software, and to permit persons to whom the Software is
 + furnished to do so, subject to the following conditions:
 +
 + The above copyright notice and this permission notice shall be included in all
 + copies or substantial portions of the Software.
 +
 + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 + SOFTWARE.
README.md CHANGED
@@ -1,12 +1,166 @@
 - ---
 - title: Stoma Clip Api
 - emoji: 📊
 - colorFrom: yellow
 - colorTo: green
 - sdk: gradio
 - sdk_version: 5.49.1
 - app_file: app.py
 - pinned: false
 - ---
 -
 - Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 + # PMC-CLIP
 +
 + [![Quick Start Demo](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1P7uyzK_Mhu1YyMeRrrRY_e3NpkNBOI4L?usp=sharing)
 + [![Dataset and Model](https://img.shields.io/badge/Hugging%20Face-Dataset-green)](https://huggingface.co/datasets/axiong/pmc_oa)
 +
 + The dataset and checkpoint are available at [Huggingface](https://huggingface.co/datasets/axiong/pmc-oa) and [Baidu Cloud](https://pan.baidu.com/s/1mD51oOYbIOqDJSeiPNaCCg) (key: 3iqf).
 +
 + 📢 We provide the extracted image-encoder and text-encoder checkpoints on [Huggingface](https://huggingface.co/datasets/axiong/pmc-oa), together with a quick-start demo showing how to encode image and text inputs. Check this [notebook](https://colab.research.google.com/drive/1P7uyzK_Mhu1YyMeRrrRY_e3NpkNBOI4L?usp=sharing)!
 +
 +
 + - [PMC-CLIP](#pmc-clip)
 +   - [Quick Start Inference](#quick-start-inference)
 +   - [Train and Evaluation](#train-and-evaluation)
 +     - [1. Create Environment](#1-create-environment)
 +     - [2. Prepare Dataset](#2-prepare-dataset)
 +     - [3. Training](#3-training)
 +     - [4. Evaluation](#4-evaluation)
 +   - [Acknowledgement](#acknowledgement)
 +   - [Contribution](#contribution)
 +   - [TODO](#todo)
 +   - [Cite](#cite)
 +
 + ## Quick Start Inference
 +
 + We offer a quick-start demo on how to use the image and text encoders of PMC-CLIP. Check this [notebook](https://colab.research.google.com/drive/1P7uyzK_Mhu1YyMeRrrRY_e3NpkNBOI4L?usp=sharing)!
 +
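 + If you prefer a quick local check instead of the notebook, the sketch below gives the rough shape of loading the released checkpoint and encoding an image. It is only a sketch: the minimal `Args` object and the `encode_image` call are assumptions carried over from the OpenCLIP-style API this repo builds on; the Colab notebook remains the authoritative reference.
 +
 + ```python
 + import torch
 + from PIL import Image
 + import pmc_clip  # installed via `python setup.py develop` (see below)
 +
 + class Args:  # hypothetical minimal config; see training/params.py for the full set of fields
 +     model = "RN50_fusion4"
 +     mlm = True
 +     crop_scale = 0.9
 +     context_length = 77
 +     device = "cuda" if torch.cuda.is_available() else "cpu"
 +
 + args = Args()
 + model, _, preprocess = pmc_clip.create_model_and_transforms(args)
 +
 + # Load the released checkpoint (path is an example); keys may carry a "module." prefix
 + # from DistributedDataParallel, so strip it before loading.
 + ckpt = torch.load("path/to/checkpoint.pt", map_location="cpu")
 + state_dict = {k.replace("module.", "", 1): v for k, v in ckpt["state_dict"].items()}
 + model.load_state_dict(state_dict, strict=False)
 + model.to(args.device).eval()
 +
 + image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(args.device)
 + with torch.no_grad():
 +     image_features = model.encode_image(image)  # assumed OpenCLIP-style entry point
 + ```
 +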
 + ## Train and Evaluation
 +
 + Repo structure:
 + ```bash
 + src/:
 + |--setup.py
 + |--pmc_clip/
 + |  |--loss/
 + |  |--model/: PMC-CLIP model and variants
 + |  |--model_configs/
 + |  |--factory.py: Create model according to configs
 + |  |--transform.py: data augmentation
 + |--training/
 + |  |--main.py
 + |  |--scheduler.py: Learning rate scheduler
 + |  |--train.py
 + |  |--evaluate.py
 + |  |--data.py
 + |  |--params.py
 + docs/: project pages
 + ```
 +
 + ### 1. Create Environment
 +
 + ```bash
 + conda create -n pmc_clip python=3.8
 + conda activate pmc_clip
 +
 + pip install -r requirements.txt
 + # pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
 +
 + python setup.py develop  # install pmc_clip in development (editable) mode
 + ```
 +
 + ### 2. Prepare Dataset
 +
 + Download from [Huggingface](https://huggingface.co/datasets/axiong/pmc-oa) or [Baidu Cloud](https://pan.baidu.com/s/1mD51oOYbIOqDJSeiPNaCCg) (key: 3iqf),
 + or follow the [Pipeline of PMC-OA Development](https://github.com/WeixiongLin/Build-PMC-OA) if you want to start from scratch.
 +
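 + If you prefer to script the Huggingface download, `snapshot_download` from `huggingface_hub` is one option. This is only a sketch: the repo id is taken from the badge above, and the target directory is just an example.
 +
 + ```python
 + from huggingface_hub import snapshot_download
 +
 + # Fetch the PMC-OA dataset repo into ./data/pmc_oa (example path).
 + snapshot_download(repo_id="axiong/pmc_oa", repo_type="dataset", local_dir="data/pmc_oa")
 + ```
 +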
 + ### 3. Training
 +
 + Single GPU
 + ```bash
 + python -m training.main \
 +     --dataset-type "csv" --csv-separator "," --save-frequency 5 \
 +     --report-to tensorboard \
 +     --train-data="path/to/train.csv" --val-data="path/to/valid.csv" \
 +     --csv-img-key image --csv-caption-key caption \
 +     --warmup 500 --batch-size=8 --lr=1e-4 --wd=0.1 --epochs=100 --workers=8 \
 +     --model RN50_fusion4 --hugging-face --mlm --crop-scale 0.5
 + ```
 +
 + Multi GPU
 + ```bash
 + CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --rdzv_endpoint=$HOST_NODE_ADDR -m training.main \
 +     --dataset-type "csv" --csv-separator "," --save-frequency 5 \
 +     --report-to tensorboard \
 +     --train-data="path/to/train.csv" --val-data="path/to/valid.csv" \
 +     --csv-img-key image --csv-caption-key caption \
 +     --warmup 500 --batch-size=128 --lr=1e-4 --wd=0.1 --epochs=100 --workers=8 \
 +     --model RN50_fusion4 --hugging-face --mlm --crop-scale 0.5
 + ```
 +
 + <div class="third">
 + <img src="docs/resources/train_loss.png" style="height:200px">
 + <img src="docs/resources/val_i2t@1.png" style="height:200px">
 + <img src="docs/resources/val_t2i@1.png" style="height:200px">
 + </div>
 +
 +
 + ### 4. Evaluation
 + Load a checkpoint and evaluate on 2,000 samples from the test set.
 +
 + ```bash
 + python -m training.main \
 +     --dataset-type "csv" --csv-separator "," --report-to tensorboard \
 +     --val-data="path/to/test.csv" \
 +     --csv-img-key image --csv-caption-key caption \
 +     --batch-size=32 --workers=8 \
 +     --model RN50_fusion4 --hugging-face --mlm --crop-scale 0.1 \
 +     --resume /path/to/checkpoint.pt \
 +     --test-2000
 + ```
 +
 + We also provide an automatic way to load model weights from a Huggingface repo.
 +
 + | Model | URL |
 + | --- | --- |
 + | PMC_CLIP:beta | https://huggingface.co/datasets/axiong/pmc_oa_beta/blob/main/checkpoint.pt |
 +
 +
 + Take the PMC_CLIP:beta checkpoint as an example:
 + ```bash
 + python -m training.main \
 +     --dataset-type "csv" --csv-separator "," --report-to tensorboard \
 +     --val-data="path/to/test.csv" \
 +     --csv-img-key image --csv-caption-key caption \
 +     --batch-size=32 --workers=8 \
 +     --model RN50_fusion4 --hugging-face --mlm --crop-scale 0.1 \
 +     --resume "PMC_CLIP:beta" \
 +     --test-2000
 + ```
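 +
 + Alternatively, the same checkpoint can be downloaded explicitly and passed to `--resume` as a local path. A sketch with `huggingface_hub` (the repo id and filename come from the table above):
 +
 + ```python
 + from huggingface_hub import hf_hub_download
 +
 + # Download checkpoint.pt from the dataset repo listed in the table above.
 + ckpt_path = hf_hub_download(repo_id="axiong/pmc_oa_beta", filename="checkpoint.pt", repo_type="dataset")
 + print(ckpt_path)  # pass this path to --resume
 + ```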
 +
 +
 + ## Acknowledgement
 + The code is based on [OpenCLIP](https://github.com/mlfoundations/open_clip) and [M3AE](https://github.com/zhjohnchan/M3AE). We thank the authors for their open-sourced code and encourage users to cite their works when applicable.
 +
 + Note that our code does not support tools such as horovod and wandb that OpenCLIP integrates, but we keep the related code from OpenCLIP for consistency.
 +
 + ## Contribution
 + Please raise an issue if you need help; any contributions are welcome.
 +
 + ## TODO
 +
 + * [ ] Compatibility testing on more environment settings
 + * [ ] Support for horovod, wandb
 +
 +
 + ## Cite
 + ```bibtex
 + @article{lin2023pmc,
 +   title={PMC-CLIP: Contrastive Language-Image Pre-training using Biomedical Documents},
 +   author={Lin, Weixiong and Zhao, Ziheng and Zhang, Xiaoman and Wu, Chaoyi and Zhang, Ya and Wang, Yanfeng and Xie, Weidi},
 +   journal={arXiv preprint arXiv:2303.07240},
 +   year={2023}
 + }
 + ```
 +
 + The paper has been accepted by MICCAI 2023.
 + ```bibtex
 + @inproceedings{lin2023pmc,
 +   title={PMC-CLIP: Contrastive Language-Image Pre-training using Biomedical Documents},
 +   author={Lin, Weixiong and Zhao, Ziheng and Zhang, Xiaoman and Wu, Chaoyi and Zhang, Ya and Wang, Yanfeng and Xie, Weidi},
 +   booktitle={MICCAI},
 +   year={2023}
 + }
 + ```
 +
app.py ADDED
@@ -0,0 +1,146 @@
 + import os
 + import torch
 + import gradio as gr
 + from PIL import Image
 + import numpy as np
 +
 + # ========== 1. Import project modules ==========
 + try:
 +     from stoma_clip import pmc_clip
 +     from stoma_clip.pmc_clip.factory import _rescan_model_configs
 +     from stoma_clip.training.fusion_method import convert_model_to_cls
 +     from stoma_clip.training.dataset.utils import encode_mlm
 + except ImportError as e:
 +     print(f"Error importing Stoma-CLIP modules: {e}")
 +
 + # ========== 2. Model Configuration and Loading ==========
 + LABEL_MAP = {
 +     "Irritant dermatitis": 0, "Allergic contact dermatitis": 1, "Mechanical injury": 2,
 +     "Folliculitis": 3, "Fungal infection": 4, "Skin hyperplasia": 5, "Parastomal varices": 6,
 +     "Urate crystals": 7, "Cancerous metastasis": 8, "Pyoderma gangrenosum": 9, "Normal": 10
 + }
 + REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
 + NUM_CLASSES = len(LABEL_MAP)
 +
 +
 + class Args:
 +     def __init__(self):
 +         self.model = "RN50_fusion4"
 +         self.pretrained = "stoma_clip.pt"
 +         self.num_classes = NUM_CLASSES
 +         self.mlm = True
 +         self.crop_scale = 0.9
 +         self.context_length = 77
 +         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 +         print(f"Using device: {self.device}")
 +
 +
 + args = Args()
 + MODEL = None
 + PREPROCESS = None
 + TOKENIZER = None
 +
 +
 + def load_model():
 +     """Load the model once when Gradio starts."""
 +     global MODEL, PREPROCESS, TOKENIZER
 +     if MODEL is not None:
 +         return MODEL, PREPROCESS, TOKENIZER
 +
 +     try:
 +         _rescan_model_configs()
 +         model, _, preprocess = pmc_clip.create_model_and_transforms(args)
 +         model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='cross_attention')
 +         model.to(args.device).eval()
 +
 +         state_dict = torch.load(args.pretrained, map_location='cpu')
 +         # Strip the "module." prefix left over from DistributedDataParallel training.
 +         state_dict_clean = {k.replace("module.", "", 1): v for k, v in state_dict['state_dict'].items()}
 +         model.load_state_dict(state_dict_clean)
 +         tokenizer = model.tokenizer
 +
 +         MODEL = model
 +         PREPROCESS = preprocess
 +         TOKENIZER = tokenizer
 +         print("Stoma-CLIP model loaded successfully!")
 +         return MODEL, PREPROCESS, TOKENIZER
 +
 +     except Exception as e:
 +         print(f"Error during model loading: {e}")
 +         MODEL = None
 +         raise RuntimeError(f"Failed to load Stoma-CLIP model: {e}")
 +
 +
 + # ========== 3. Inference Function ==========
 + def predict_stoma_clip(image: Image.Image, caption: str):
 +     if MODEL is None:
 +         return "Model Loading Failed", {}
 +
 +     image = image.convert("RGB")
 +     model, preprocess, tokenizer = MODEL, PREPROCESS, TOKENIZER
 +     device = args.device
 +
 +     image_tensor = preprocess(image).unsqueeze(0).to(device)
 +
 +     mask_token, pad_token = '[MASK]', '[PAD]'
 +     vocab = [v for v in tokenizer.get_vocab().keys() if v not in tokenizer.all_special_tokens]
 +
 +     # ratio=0.0 disables masking; only the tokenized caption is needed for inference.
 +     bert_input, bert_label = encode_mlm(
 +         caption=caption,
 +         vocab=vocab,
 +         mask_token=mask_token,
 +         pad_token=pad_token,
 +         ratio=0.0,
 +         tokenizer=tokenizer,
 +         args=args,
 +     )
 +
 +     with torch.no_grad():
 +         inputs = {"images": image_tensor, "bert_input": bert_input, "bert_label": bert_label}
 +         outputs = model(inputs)
 +         probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
 +         predicted_class_idx = torch.argmax(outputs, dim=1).item()
 +
 +     predicted_class_name = REVERSE_LABEL_MAP.get(predicted_class_idx, "Unknown")
 +     probability_distribution = {REVERSE_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
 +
 +     return predicted_class_name, probability_distribution
 +
 +
 + # ========== 4. Gradio Interface Setup ==========
 + try:
 +     load_model()
 +     print("Model loaded successfully before Gradio startup.")
 + except Exception as e:
 +     print(f"Fatal error: the model failed to load, so the Gradio interface will not work. {e}")
 +
 + image_input = gr.Image(type="pil", label="Upload a stoma image")
 + caption_input = gr.Textbox(label="Enter a stoma description (e.g., Exudate, epidermal breakdown, ...)")
 + predicted_label_output = gr.Textbox(label="Predicted class")
 + prob_output = gr.Label(label="Class probability distribution")
 +
 + # Find example paths for the Gradio demo (note: in the deployed Space, these paths should be relative to the root).
 + try:
 +     example_path_1 = "demo/Irritant_dermatitis.jpg"
 +     example_path_2 = "demo/Folliculitis.jpg"
 +
 +     examples_list = []
 +     if os.path.exists(example_path_1):
 +         examples_list.append(
 +             [example_path_1, "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas"])
 +     if os.path.exists(example_path_2):  # add the second example independently of the first
 +         examples_list.append([example_path_2, "Erythema, papules, pustules confined to hair follicles"])
 + except Exception:
 +     examples_list = []
 +
 + iface = gr.Interface(
 +     fn=predict_stoma_clip,
 +     inputs=[image_input, caption_input],
 +     outputs=[predicted_label_output, prob_output],
 +     title="🧪 Stoma-CLIP Classification API Prototype (Gradio)",
 +     description="Upload a stoma image and enter a clinical description; the model predicts the most likely skin-complication class.",
 +     examples=examples_list,
 +     allow_flagging="never"
 + )
 +
 + if __name__ == "__main__":
 +     iface.launch()
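
For reference, calling the deployed Space from Python would look roughly like the sketch below using `gradio_client`. The Space id is a placeholder, and `/predict` is the default endpoint name Gradio assigns to a single `gr.Interface`; both are assumptions to verify against the Space's API page.

```python
from gradio_client import Client, handle_file

client = Client("<user>/<space-name>")  # placeholder: replace with the actual Space id

# Inputs mirror the Interface above: an image and a free-text clinical description.
label, probabilities = client.predict(
    handle_file("demo/Irritant_dermatitis.jpg"),
    "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas",
    api_name="/predict",
)
print(label)
print(probabilities)
```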
requirements.txt CHANGED
@@ -1,4 +1,4 @@
 - pytorch >= 1.9.0
 + torch
   torchvision
   transformers
   tqdm
 @@ -10,3 +10,4 @@ braceexpand
   webdataset
   jsonlines
   tensorboard
 + matplotlib
save_roc.py ADDED
@@ -0,0 +1,153 @@
 + import os
 + import torch
 + from torch.utils.data import DataLoader
 + from tqdm import tqdm
 + import numpy as np
 + import matplotlib.pyplot as plt
 + from sklearn.metrics import roc_curve, roc_auc_score
 + from sklearn.preprocessing import label_binarize
 + import json
 + import sys
 + sys.path.append('.')
 + import pmc_clip
 + from training.params import parse_args
 + from training.data import PmcDataset
 + from training.fusion_method import convert_model_to_cls
 +
 +
 + # Label mapping
 + LABEL_MAP = {
 +     "Irritant dermatitis": 0,
 +     "Allergic contact dermatitis": 1,
 +     "Mechanical injury": 2,
 +     "Folliculitis": 3,
 +     "Fungal infection": 4,
 +     "Skin hyperplasia": 5,
 +     "Parastomal varices": 6,
 +     "Urate crystals": 7,
 +     "Cancerous metastasis": 8,
 +     "Pyoderma gangrenosum": 9,
 +     "Normal": 10
 + }
 +
 + REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
 +
 + def main():
 +     # Create the output directory
 +     output_dir = './evaluation_results_pmc_clip_cat'
 +     if not os.path.exists(output_dir):
 +         os.makedirs(output_dir)
 +
 +     # Select the device
 +     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 +     print(f"Using device: {device}")
 +
 +     # Model configuration
 +     model_path = "logs/0321-Stoma-clip-train-cls/2025_03_21-23_45_18-model_RN50_fusion4-lr_1e-05-b_256-j_8-p_amp/checkpoints/epoch_150.pt"
 +     model_name = "RN50_fusion4"
 +     args = parse_args()
 +     args.model = model_name
 +     args.pretrained = model_path
 +     args.device = device
 +     args.mlm = True
 +     args.train_data = "data/single_symptoms_test.jsonl"
 +     args.image_dir = "./data/cleaned_data"
 +     args.csv_img_key = "image"
 +     args.csv_caption_key = "caption"
 +     args.context_length = 77
 +     args.num_classes = len(LABEL_MAP)
 +     args.output_dir = output_dir
 +
 +     # Create the model and preprocessing transform
 +     model, _, preprocess = pmc_clip.create_model_and_transforms(args)
 +     model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='concat')
 +
 +     # Load the model weights (strip the "module." prefix left by DistributedDataParallel)
 +     state_dict = torch.load(model_path, map_location='cpu', weights_only=False)
 +
 +     state_dict_real = {}
 +     for k, v in state_dict['state_dict'].items():
 +         state_dict_real[k.replace("module.", "", 1)] = v
 +     print(model.load_state_dict(state_dict_real))
 +     model.to(device=device)
 +
 +     # Prepare the test dataset
 +     dataset = PmcDataset(args,
 +                          input_filename=args.train_data,
 +                          transforms=preprocess,
 +                          is_train=False)
 +
 +     test_loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)
 +
 +     print(f"Number of test samples: {len(dataset)}")
 +
 +     # Collect predictions
 +     all_preds = []
 +     all_probs = []
 +     all_labels = []
 +
 +     print("Starting evaluation...")
 +     model.eval()
 +     with torch.no_grad():
 +         for batch in tqdm(test_loader):
 +             labels = batch["cls_label"].to(device)
 +
 +             # Forward pass
 +             outputs = model(batch)
 +
 +             # Collect predicted classes and probabilities
 +             probs = torch.softmax(outputs, dim=1)
 +             _, preds = torch.max(outputs, dim=1)
 +
 +             all_preds.extend(preds.cpu().numpy())
 +             all_probs.extend(probs.cpu().numpy())
 +             all_labels.extend(labels.cpu().numpy())
 +
 +     # Convert to numpy arrays
 +     all_preds = np.array(all_preds)
 +     all_probs = np.array(all_probs)
 +     all_labels = np.array(all_labels)
 +
 +     # Compute the overall AUC (micro-averaged, one-vs-rest)
 +     try:
 +         y_true_bin = label_binarize(all_labels, classes=range(args.num_classes))
 +         if args.num_classes == 2:
 +             # label_binarize returns a single column for two classes, so use the raw labels here
 +             overall_fpr, overall_tpr, _ = roc_curve(all_labels, all_probs[:, 1])
 +             overall_auc = roc_auc_score(all_labels, all_probs[:, 1])
 +         else:
 +             overall_fpr, overall_tpr, _ = roc_curve(y_true_bin.ravel(), all_probs.ravel())
 +             overall_auc = roc_auc_score(y_true_bin, all_probs, multi_class='ovr', average='micro')
 +     except Exception as e:
 +         print(f"Error while computing the overall AUC: {e}")
 +         return
 +
 +     # Save the overall ROC curve data
 +     roc_data = {
 +         "fpr": overall_fpr.tolist(),
 +         "tpr": overall_tpr.tolist(),
 +         "auc": overall_auc
 +     }
 +     roc_file = os.path.join(output_dir, "overall_roc_data.json")
 +     with open(roc_file, "w") as f:
 +         json.dump(roc_data, f)
 +     print(f"Overall ROC curve data saved to: {roc_file}")
 +
 +     # Plot the ROC curve
 +     plt.figure(figsize=(8, 6))
 +     plt.plot(overall_fpr, overall_tpr, label=f"Overall (AUC = {overall_auc:.4f})")
 +     plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
 +     plt.xlim([0.0, 1.0])
 +     plt.ylim([0.0, 1.05])
 +     plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
 +     plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
 +     plt.title('Overall ROC Curve', fontsize=14)
 +     plt.legend(loc="lower right", fontsize=10)
 +     plt.grid(alpha=0.3)
 +     plt.tight_layout()
 +     plt.savefig(os.path.join(output_dir, 'overall_roc_curve.png'), dpi=300, bbox_inches='tight')
 +     print(f"Overall ROC curve figure saved to: {os.path.join(output_dir, 'overall_roc_curve.png')}")
 +
 +
 + if __name__ == '__main__':
 +     main()
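
Usage note: `model_path`, `args.train_data`, `args.image_dir`, and `output_dir` are hard-coded above and need to be edited to point at your own checkpoint and test split before running. The script writes `overall_roc_data.json` and `overall_roc_curve.png` into `output_dir`.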
setup.py ADDED
@@ -0,0 +1,57 @@
 + """ Setup
 + """
 + from setuptools import setup, find_packages
 + from codecs import open
 + from os import path
 +
 + here = path.abspath(path.dirname(__file__))
 +
 + # Get the long description from the README file
 + with open(path.join(here, 'README.md'), encoding='utf-8') as f:
 +     long_description = f.read()
 +
 + exec(open('src/pmc_clip/version.py').read())
 + setup(
 +     name='pmc_clip',
 +     version=__version__,
 +     description='PMC-CLIP: Contrastive Language-Image Pre-training using Biomedical Documents',
 +     long_description=long_description,
 +     long_description_content_type='text/markdown',
 +     url='https://github.com/WeixiongLin/PMC-CLIP/',
 +     author='weixiong',
 +     author_email='wx_lin@sjtu.edu.cn',
 +     classifiers=[
 +         # How mature is this project? Common values are
 +         #   3 - Alpha
 +         #   4 - Beta
 +         #   5 - Production/Stable
 +         'Development Status :: 4 - Beta',
 +         'Intended Audience :: Education',
 +         'Intended Audience :: Science/Research',
 +         'License :: OSI Approved :: Apache Software License',
 +         'Programming Language :: Python :: 3.7',
 +         'Programming Language :: Python :: 3.8',
 +         'Programming Language :: Python :: 3.9',
 +         'Programming Language :: Python :: 3.10',
 +         'Topic :: Scientific/Engineering',
 +         'Topic :: Scientific/Engineering :: Artificial Intelligence',
 +         'Topic :: Software Development',
 +         'Topic :: Software Development :: Libraries',
 +         'Topic :: Software Development :: Libraries :: Python Modules',
 +     ],
 +
 +     # Note that this is a string of words separated by whitespace, not a list.
 +     keywords='PMC-CLIP',
 +     package_dir={'': 'src'},
 +     packages=find_packages(where='src', exclude=['training']),
 +     include_package_data=True,
 +     install_requires=[
 +         'torch >= 1.9',
 +         'torchvision',
 +         'transformers <= 4.21.0',
 +         'ftfy',
 +         'regex',
 +         'tqdm',
 +     ],
 +     python_requires='>=3.7',
 + )
stoma_clip.pt ADDED
@@ -0,0 +1,3 @@
 + version https://git-lfs.github.com/spec/v1
 + oid sha256:9b1b2a21c2d8d66669f70cfc9380143d758c7937a45c2c8d06747860013c51ed
 + size 832509306
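
This is a Git LFS pointer file: the ~832 MB checkpoint itself lives in LFS storage, so run `git lfs install` and `git lfs pull` after cloning to materialize `stoma_clip.pt` locally.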
vis_all_model_roc.py ADDED
@@ -0,0 +1,71 @@
 + import os
 + import json
 + import matplotlib.pyplot as plt
 +
 + def load_roc_data(roc_files):
 +     """
 +     Load the ROC data for multiple models.
 +     Args:
 +         roc_files (list): paths to each model's overall_roc_data.json file
 +     Returns:
 +         roc_data_list (list): one ROC-data dict per model
 +     """
 +     roc_data_list = []
 +     for roc_file in roc_files:
 +         with open(roc_file, "r") as f:
 +             roc_data = json.load(f)
 +         roc_data_list.append(roc_data)
 +     return roc_data_list
 +
 + def plot_combined_roc(roc_data_list, model_names, output_path):
 +     """
 +     Plot the ROC curves of multiple models in a single figure.
 +     Args:
 +         roc_data_list (list): one ROC-data dict per model
 +         model_names (list): display name for each model
 +         output_path (str): where to save the combined ROC figure
 +     """
 +     plt.figure(figsize=(10, 8))
 +
 +     for roc_data, model_name in zip(roc_data_list, model_names):
 +         fpr = roc_data["fpr"]
 +         tpr = roc_data["tpr"]
 +         auc_value = roc_data["auc"]
 +         plt.plot(fpr, tpr, label=f"{model_name} (AUC = {auc_value:.4f})")
 +
 +     # Diagonal reference line
 +     plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
 +
 +     # Figure settings
 +     plt.xlim([0.0, 1.0])
 +     plt.ylim([0.0, 1.05])
 +     plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
 +     plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
 +     plt.title('Combined ROC Curves for Multiple Models', fontsize=14)
 +     plt.legend(loc="lower right", fontsize=10)
 +     plt.grid(alpha=0.3)
 +     plt.tight_layout()
 +
 +     # Save the figure
 +     plt.savefig(output_path, dpi=300, bbox_inches='tight')
 +     print(f"Combined ROC figure saved to: {output_path}")
 +
 + def main():
 +     # Directory holding one ROC-data JSON file per model
 +     roc_data_dir = "./roc_result"  # replace with the actual path
 +     output_path = "./combined_roc_curve.png"  # where to save the combined ROC figure
 +
 +     # Collect every model's ROC JSON file (only .json files, sorted for a stable legend order)
 +     roc_files = sorted(os.path.join(roc_data_dir, f)
 +                        for f in os.listdir(roc_data_dir) if f.endswith(".json"))
 +
 +     # Model names (derived from the file names)
 +     model_names = [os.path.basename(f).replace(".json", "") for f in roc_files]
 +
 +     # Load the ROC data for all models
 +     roc_data_list = load_roc_data(roc_files)
 +
 +     # Plot and save the combined ROC curves
 +     plot_combined_roc(roc_data_list, model_names, output_path)
 +
 + if __name__ == "__main__":
 +     main()
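
Usage note: the script expects `./roc_result/` to contain one JSON file per model, for example each run's `overall_roc_data.json` copied to `roc_result/<model_name>.json`. The legend labels are taken from those file names, and the combined figure is written to `./combined_roc_curve.png`.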