Commit b0660fb, committed by alessandro trinca tornidor
Parent(s): 95c07ff

feat: zeroGPU spaces support (drop docker, uses gradio sdk)

Files changed:
- Dockerfile +0 -61
- README.md +21 -319
- lisa_on_cuda/app/main.py → app.py +17 -5
- lisa_on_cuda/app/__init__.py +0 -0
- lisa_on_cuda/app/chat.py +0 -200
- lisa_on_cuda/app/merge_lora_weights_and_save_hf_model.py +0 -159
- lisa_on_cuda/app/train_ds.py +0 -584
- lisa_on_cuda/{app/routes.py → routes.py} +1 -1
- lisa_on_cuda/utils/app_helpers.py +104 -37
- requirements.txt +3 -2
Dockerfile
DELETED
@@ -1,61 +0,0 @@
```dockerfile
FROM nvcr.io/nvidia/pytorch:24.03-py3

LABEL authors="alessandro@trinca.tornidor.com"

ARG DEBIAN_FRONTEND=noninteractive
ARG WORKDIR="/var/task"

ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=${WORKDIR}:${WORKDIR}/venv:${PYTHONPATH}
ENV PATH=${WORKDIR}/venv/bin:$PATH
ENV XDG_CACHE_HOME=/data

WORKDIR ${WORKDIR}
COPY . ${WORKDIR}/
RUN ls ${WORKDIR}/
RUN mkdir -p ${XDG_CACHE_HOME}/.cache
RUN chmod 770 ${XDG_CACHE_HOME}/.cache

RUN apt update && apt upgrade -y && apt install --no-install-recommends -y \
    build-essential \
    python3.11 \
    python3-pip \
    python3-dev \
    python3-venv \
    git \
    ffmpeg \
    curl \
    && apt clean && rm -rf /var/lib/apt/lists/*

RUN which python3
RUN python3 --version
RUN python3 -m venv venv
RUN source ${WORKDIR}/venv/bin/activate python -m pip install pip --upgrade && python -m pip install -r ${WORKDIR}/requirements.txt
RUN source ${WORKDIR}/venv/bin/activate && which python && python --version
RUN chmod +x ${WORKDIR}/scripts/entrypoint.sh
RUN curl -o /tmp/frpc_linux_amd64_v0.2 https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_linux_amd64
RUN ls -l /tmp/frpc_linux_amd64_v0.2
RUN cp /tmp/frpc_linux_amd64_v0.2 ${WORKDIR}/venv/lib/python*/site-packages/gradio
RUN ls -l ${WORKDIR}/venv/lib/python*/site-packages/gradio
RUN ls -l ${WORKDIR}/venv/bin
RUN bash --version
RUN chmod 770 ${WORKDIR}/flagged/
RUN chmod 770 ${WORKDIR}/flagged/* || true
RUN ls -ld ${WORKDIR}/flagged/
RUN ls -ld ${WORKDIR}/flagged/* || echo "folders ${WORKDIR}/flagged/* not found"
RUN ls -l ${WORKDIR}
RUN ls -l ${WORKDIR}/scripts/
RUN ls -l ${WORKDIR}/scripts/entrypoint.sh

EXPOSE 7860

CMD ["/var/task/scripts/entrypoint.sh"]
# CMD [
#   "/var/task/scripts/entrypoint.sh",
#   "/var/task/venv/bin/uvicorn", "app:lisa_app",
#   "--host", "0.0.0.0",
#   "--port", "7860",
#   "--version='xinlai/LISA-13B-llama2-v1-explanatory'",
#   "--precision='fp16'",
#   "--load_in_4bit"
# ]
```
README.md
CHANGED
````diff
@@ -1,20 +1,25 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk:
-
+title: lisa + gradio + fastapi + ZeroGPU
+emoji: ⚡
+colorFrom: red
+colorTo: purple
+sdk: gradio
+sdk_version: 4.37.2
+app_file: app.py
+pinned: true
 ---
 
-#
+# LISA (Reasoning Segmentation via Large Language Model) on cuda, now with huggingface ZeroGPU support!
+
+## Exec jupyter on the remote server with port forwarding on localhost
 
 1. checkout repo, install venv with jupyter
 2. port forwarding in localhost with private key: `ssh -i /path/to/private_key name@endpoint.com -L 8889:localhost:8889 -N -f`
 3. start the jupyter-lab server
 4. connect to page in localhost
 
-## Commands to work on
+## Commands to work on remote virtual machines (e.g. SaturnCloud) after clone and git lfs install
+
 ```bash
 cd ~/workspace/lisa-on-cuda/
 rm -rf lisa_venv
@@ -38,320 +43,17 @@ To run the `test.ipynb` notebook you should already:
 - installed jupyterlab dependencies from requirements_jupyter.txt
 - installed dependencies from requirements.txt
 
-## Hardware requirements
+## Hardware requirements for local usage
+
 - an nvidia gpu with 10 or 12GB of memory (a T4 should suffice)
 - at least 16GB of system ram
 
-
-[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/openxlab-app/LISA)
-
-# LISA: Reasoning Segmentation via Large Language Model
-
-<font size=7><div align='center'><b>LISA</b>: Large <b>L</b>anguage <b>I</b>nstructed <b>S</b>egmentation <b>A</b>ssistant</div></font>
-
-<font size=7><div align='center' > <a href=https://arxiv.org/pdf/2308.00692.pdf>**Paper**</a> | <a href="https://huggingface.co/xinlai">**Models**</a> | [**Training**](#training) | [**Inference**](#inference) | [**Local Deployment**](#deployment) | [**Dataset**](#dataset) | <a href="http://103.170.5.190:7860/">**Online Demo**</a></div></font>
-
-<!-- <p align="center"> <img src="resources/imgs/teaser.jpg" width="100%"> </p> -->
-
-<table class="center">
-<tr>
-<td style="text-align:center;"><b>Input</b></td>
-<td style="text-align:center;"><b>Output</b></td>
-<td style="text-align:center;"><b>Input</b></td>
-<td style="text-align:center;"><b>Output</b></td>
-</tr>
-<tr>
-<td><img src="resources/imgs/obama.jpg"></td>
-<td><img src="vis_output/obama.jpg"></td>
-<td><img src="resources/imgs/trump.jpg"></td>
-<td><img src="vis_output/trump.jpg"></td>
-</tr>
-<tr>
-<td width=25% style="text-align:center;color:gray;">"Who was <ins>the president of the US</ins> in this image? Please output segmentation mask and explain the reason."</td>
-<td width=25% style="text-align:center;">"Sure, the segmentation result is [SEG]. The President of the United States in the image is <ins>President Obama.</ins>”</td>
-<td width=25% style="text-align:center;color:gray">"Who was <ins>the president of the US</ins> in this image? Please output segmentation mask and explain why."</td>
-<td width=25% style="text-align:center;">"Sure, [SEG]. In the image, the President of the United States is <ins>President Trump.</ins>"</td>
-</tr>
-
-<tr>
-<td><img src="resources/imgs/stand_higher.jpg"></td>
-<td><img src="vis_output/stand_higher.jpg"></td>
-<td><img src="resources/imgs/camera_lens.jpg"></td>
-<td><img src="vis_output/camera_lens.jpg"></td>
-</tr>
-<tr>
-<td width=25% style="text-align:center;color:gray;">"What can <ins>make the woman stand higher</ins>? Please output segmentation mask and explain why."</td>
-<td width=25% style="text-align:center;">"Sure, [SEG]. The woman is <ins>standing higher by using a ladder</ins>..."</td>
-<td width=25% style="text-align:center;color:gray">"Can you segment <ins>the camera lens that is more suitable for photographing nearby objects</ins> in this image?"</td>
-<td width=25% style="text-align:center;">"Sure, [SEG]."</td>
-</tr>
-
-<tr>
-<td><img src="resources/imgs/dog_with_horn.jpg"></td>
-<td><img src="vis_output/dog_with_horn.jpg"></td>
-<td><img src="resources/imgs/wash_hands.jpg"></td>
-<td><img src="vis_output/wash_hands.jpg"></td>
-</tr>
-<tr>
-<td width=25% style="text-align:center;color:gray;">"Can you segment <ins>the unusual part</ins> in this image and explain why."</td>
-<td width=25% style="text-align:center;">"Sure, it is [SEG]. In the image, the unusual part is <ins>the dog wearing a reindeer antler headband</ins>..."</td>
-<td width=25% style="text-align:center;color:gray">"Where to <ins>wash hands</ins> in this image? Please output segmentation mask."</td>
-<td width=25% style="text-align:center;">"Sure, [SEG]."</td>
-</tr>
-
-<tr>
-<td><img src="resources/imgs/jackma.jpg"></td>
-<td><img src="vis_output/jackma.jpg"></td>
-<td><img src="resources/imgs/blackpink.jpg"></td>
-<td><img src="vis_output/blackpink.jpg"></td>
-</tr>
-<tr>
-<td width=25% style="text-align:center;color:gray;">"Can you segment <ins>the founder of Alibaba</ins> in this image and explain why?"</td>
-<td width=25% style="text-align:center;">"Sure, it is [SEG]. In the image, the man wearing a red tie and sitting on one of the chairs is <ins>Jack Ma</ins>, the co-founder of Alibaba Group..."</td>
-<td width=25% style="text-align:center;color:gray">"Please segment <ins>Lisa</ins> in this figure."</td>
-<td width=25% style="text-align:center;">"Sure, [SEG]."</td>
-</tr>
-</table>
-
-<p align="center"> <img src="resources/imgs/fig_overview.jpg" width="100%"> </p>
-
-## News
-- [x] [2023.8.30] Release three new models [LISA-7B-v1](https://huggingface.co/xinlai/LISA-7B-v1), [LISA-7B-v1-explanatory](https://huggingface.co/xinlai/LISA-7B-v1-explanatory), and [LISA-13B-llama2-v1-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v1-explanatory). Welcome to check them out!
-- [x] [2023.8.23] Refactor code, and release new model [LISA-13B-llama2-v1](https://huggingface.co/xinlai/LISA-13B-llama2-v1). Welcome to check it out!
-- [x] [2023.8.9] Training code is released!
-- [x] [2023.8.4] [Online Demo](http://103.170.5.190:7860/) is released!
-- [x] [2023.8.4] [*ReasonSeg* Dataset](https://drive.google.com/drive/folders/125mewyg5Ao6tZ3ZdJ-1-E3n04LGVELqy?usp=sharing) and the [LISA-13B-llama2-v0-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explanatory) model are released!
-- [x] [2023.8.3] Inference code and the [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) model are released. Welcome to check them out!
-- [x] [2023.8.2] [Paper](https://arxiv.org/pdf/2308.00692.pdf) is released and GitHub repo is created.
-
-**LISA: Reasoning Segmentation via Large Language Model [[Paper](https://arxiv.org/abs/2308.00692)]** <br />
-[Xin Lai](https://scholar.google.com/citations?user=tqNDPA4AAAAJ&hl=zh-CN),
-[Zhuotao Tian](https://scholar.google.com/citations?user=mEjhz-IAAAAJ&hl=en),
-[Yukang Chen](https://scholar.google.com/citations?user=6p0ygKUAAAAJ&hl=en),
-[Yanwei Li](https://scholar.google.com/citations?user=I-UCPPcAAAAJ&hl=zh-CN),
-[Yuhui Yuan](https://scholar.google.com/citations?user=PzyvzksAAAAJ&hl=en),
-[Shu Liu](https://scholar.google.com.hk/citations?user=BUEDUFkAAAAJ&hl=zh-CN),
-[Jiaya Jia](https://scholar.google.com/citations?user=XPAkzTEAAAAJ&hl=en)<br />
-
-## Abstract
-In this work, we propose a new segmentation task --- ***reasoning segmentation***. The task is designed to output a segmentation mask given a complex and implicit query text. We establish a benchmark comprising over one thousand image-instruction pairs, incorporating intricate reasoning and world knowledge for evaluation purposes. Finally, we present LISA: Large-language Instructed Segmentation Assistant, which inherits the language generation capabilities of the multi-modal Large Language Model (LLM) while also possessing the ability to produce segmentation masks.
-For more details, please refer to the [paper](https://arxiv.org/abs/2308.00692).
-
-## Highlights
-**LISA** unlocks the new segmentation capabilities of multi-modal LLMs, and can handle cases involving:
-1. complex reasoning;
-2. world knowledge;
-3. explanatory answers;
-4. multi-turn conversation.
-
-**LISA** also demonstrates robust zero-shot capability when trained exclusively on reasoning-free datasets. In addition, fine-tuning the model with merely 239 reasoning segmentation image-instruction pairs results in further performance enhancement.
-
-## Experimental results
-<p align="center"> <img src="resources/imgs/table1.jpg" width="80%"> </p>
-
-## Installation
-```
-pip install -r requirements.txt
-pip install flash-attn --no-build-isolation
-```
-
-## Training
-### Training Data Preparation
-The training data consists of 4 types of data:
-
-1. Semantic segmentation datasets: [ADE20K](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip), [COCO-Stuff](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip), [Mapillary](https://www.mapillary.com/dataset/vistas), [PACO-LVIS](https://github.com/facebookresearch/paco/tree/main#dataset-setup), [PASCAL-Part](https://github.com/facebookresearch/VLPart/tree/main/datasets#pascal-part), [COCO Images](http://images.cocodataset.org/zips/train2017.zip)
-
-Note: For COCO-Stuff, we use the annotation file stuffthingmaps_trainval2017.zip. We only use the PACO-LVIS part in PACO. COCO Images should be put into the `dataset/coco/` directory.
-
-3. Referring segmentation datasets: [refCOCO](https://web.archive.org/web/20220413011718/https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip), [refCOCO+](https://web.archive.org/web/20220413011656/https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip), [refCOCOg](https://web.archive.org/web/20220413012904/https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip), [refCLEF](https://web.archive.org/web/20220413011817/https://bvisionweb1.cs.unc.edu/licheng/referit/data/refclef.zip) ([saiapr_tc-12](https://web.archive.org/web/20220515000000/http://bvisionweb1.cs.unc.edu/licheng/referit/data/images/saiapr_tc-12.zip))
-
-Note: the original links of refCOCO series data are down, and we update them with new ones. If the download speed is super slow or unstable, we also provide a [OneDrive link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155154502_link_cuhk_edu_hk/Em5yELVBvfREodKC94nOFLoBLro_LPxsOxNV44PHRWgLcA?e=zQPjsc) to download. **You must also follow the rules that the original datasets require.**
-
-4. Visual Question Answering dataset: [LLaVA-Instruct-150k](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_150k.json)
-
-5. Reasoning segmentation dataset: [ReasonSeg](https://github.com/dvlab-research/LISA#dataset)
-
-Download them from the above links, and organize them as follows.
-
-```
-├── dataset
-│   ├── ade20k
-│   │   ├── annotations
-│   │   └── images
-│   ├── coco
-│   │   └── train2017
-│   │       ├── 000000000009.jpg
-│   │       └── ...
-│   ├── cocostuff
-│   │   └── train2017
-│   │       ├── 000000000009.png
-│   │       └── ...
-│   ├── llava_dataset
-│   │   └── llava_instruct_150k.json
-│   ├── mapillary
-│   │   ├── config_v2.0.json
-│   │   ├── testing
-│   │   ├── training
-│   │   └── validation
-│   ├── reason_seg
-│   │   └── ReasonSeg
-│   │       ├── train
-│   │       ├── val
-│   │       └── explanatory
-│   ├── refer_seg
-│   │   ├── images
-│   │   |   ├── saiapr_tc-12
-│   │   |   └── mscoco
-│   │   |       └── images
-│   │   |           └── train2014
-│   │   ├── refclef
-│   │   ├── refcoco
-│   │   ├── refcoco+
-│   │   └── refcocog
-│   └── vlpart
-│       ├── paco
-│       │   └── annotations
-│       └── pascal_part
-│           ├── train.json
-│           └── VOCdevkit
-```
-
-### Pre-trained weights
-
-#### LLaVA
-To train LISA-7B or 13B, you need to follow the [instruction](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) to merge the LLaVA delta weights. Typically, we use the final weights `LLaVA-Lightning-7B-v1-1` and `LLaVA-13B-v1-1` merged from `liuhaotian/LLaVA-Lightning-7B-delta-v1-1` and `liuhaotian/LLaVA-13b-delta-v1-1`, respectively. For Llama2, we can directly use the LLaVA full weights `liuhaotian/llava-llama-2-13b-chat-lightning-preview`.
-
-#### SAM ViT-H weights
-Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
-
-### Training
-```
-deepspeed --master_port=24999 train_ds.py \
-  --version="PATH_TO_LLaVA" \
-  --dataset_dir='./dataset' \
-  --vision_pretrained="PATH_TO_SAM" \
-  --dataset="sem_seg||refer_seg||vqa||reason_seg" \
-  --sample_rates="9,3,3,1" \
-  --exp_name="lisa-7b"
-```
-When training is finished, to get the full model weight:
-```
-cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
-```
-
-### Merge LoRA Weight
-Merge the LoRA weights of `pytorch_model.bin`, save the resulting model into your desired path in the Hugging Face format:
-```
-CUDA_VISIBLE_DEVICES="" python merge_lora_weights_and_save_hf_model.py \
-  --version="PATH_TO_LLaVA" \
-  --weight="PATH_TO_pytorch_model.bin" \
-  --save_path="PATH_TO_SAVED_MODEL"
-```
-
-For example:
-```
-CUDA_VISIBLE_DEVICES="" python3 merge_lora_weights_and_save_hf_model.py \
-  --version="./LLaVA/LLaVA-Lightning-7B-v1-1" \
-  --weight="lisa-7b/pytorch_model.bin" \
-  --save_path="./LISA-7B"
-```
-
-### Validation
-```
-deepspeed --master_port=24999 train_ds.py \
-  --version="PATH_TO_LISA_HF_Model_Directory" \
-  --dataset_dir='./dataset' \
-  --vision_pretrained="PATH_TO_SAM" \
-  --exp_name="lisa-7b" \
-  --eval_only
-```
-
-Note: the `v1` model is trained using both `train+val` sets, so please use the `v0` model to reproduce the validation results. (To use the `v0` models, please first checkout to the legacy version repo with `git checkout 0e26916`.)
-
-
-## Inference
-
-To chat with [LISA-13B-llama2-v1](https://huggingface.co/xinlai/LISA-13B-llama2-v1) or [LISA-13B-llama2-v1-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v1-explanatory):
-(Note that `chat.py` currently does not support `v0` models (i.e., `LISA-13B-llama2-v0` and `LISA-13B-llama2-v0-explanatory`), if you want to use the `v0` models, please first checkout to the legacy version repo `git checkout 0e26916`.)
-```
-CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1'
-CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1-explanatory'
-```
-To use `bf16` or `fp16` data type for inference:
-```
-CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='bf16'
-```
-To use `8bit` or `4bit` data type for inference (this enables running 13B model on a single 24G or 12G GPU at some cost of generation quality):
-```
-CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='fp16' --load_in_8bit
-CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='fp16' --load_in_4bit
-```
-Hint: for 13B model, 16-bit inference consumes 30G VRAM with a single GPU, 8-bit inference consumes 16G, and 4-bit inference consumes 9G.
-
-After that, input the text prompt and then the image path. For example,
-```
-- Please input your prompt: Where can the driver see the car speed in this image? Please output segmentation mask.
-- Please input the image path: imgs/example1.jpg
-
-- Please input your prompt: Can you segment the food that tastes spicy and hot?
-- Please input the image path: imgs/example2.jpg
-```
-The results should be like:
-<p align="center"> <img src="resources/imgs/example1.jpg" width="22%"> <img src="vis_output/example1_masked_img_0.jpg" width="22%"> <img src="resources/imgs/example2.jpg" width="25%"> <img src="vis_output/example2_masked_img_0.jpg" width="25%"> </p>
-
-## Deployment
-```
-CUDA_VISIBLE_DEVICES=0 python app.py --version='xinlai/LISA-13B-llama2-v1 --load_in_4bit'
-CUDA_VISIBLE_DEVICES=0 python app.py --version='xinlai/LISA-13B-llama2-v1-explanatory --load_in_4bit'
-```
-By default, we use 4-bit quantization. Feel free to delete the `--load_in_4bit` argument for 16-bit inference or replace it with `--load_in_8bit` argument for 8-bit inference.
-
-
-## Dataset
-In ReasonSeg, we have collected 1218 images (239 train, 200 val, and 779 test). The training and validation sets can be download from <a href="https://drive.google.com/drive/folders/125mewyg5Ao6tZ3ZdJ-1-E3n04LGVELqy?usp=sharing">**this link**</a>.
-
-Each image is provided with an annotation JSON file:
-```
-image_1.jpg, image_1.json
-image_2.jpg, image_2.json
-...
-image_n.jpg, image_n.json
-```
-Important keys contained in JSON files:
-```
-- "text": text instructions.
-- "is_sentence": whether the text instructions are long sentences.
-- "shapes": target polygons.
-```
-
-```
-python3 utils/data_processing.py
-```
-
-Besides, we leveraged GPT-3.5 for rephrasing instructions, so images in the training set may have **more than one instructions (but fewer than six)** in the "text" field. During training, users may randomly select one as the text query to obtain a better model.
-
-## Citation
-If you find this project useful in your research, please consider citing:
-
-```
-@article{lai2023lisa,
-  title={LISA: Reasoning Segmentation via Large Language Model},
-  author={Lai, Xin and Tian, Zhuotao and Chen, Yukang and Li, Yanwei and Yuan, Yuhui and Liu, Shu and Jia, Jiaya},
-  journal={arXiv preprint arXiv:2308.00692},
-  year={2023}
-}
-@article{yang2023improved,
-  title={An Improved Baseline for Reasoning Segmentation with Large Language Model},
-  author={Yang, Senqiao and Qu, Tianyuan and Lai, Xin and Tian, Zhuotao and Peng, Bohao and Liu, Shu and Jia, Jiaya},
-  journal={arXiv preprint arXiv:2312.17240},
-  year={2023}
-}
-```
-
-- placeholders images (error, 'no output segmentation') from Muhammad Khaleeq (https://www.vecteezy.com/members/iyikon)
+## Hardware requirements on huggingface ZeroGPU
+
+Right now (July 2024) huggingface lets you use ZeroGPU Nvidia A100 GPUs.
+
+[![Gradio](https://img.shields.io/badge/Gradio-Online%20Demo-blue)](http://103.170.5.190:7860/)
+[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/openxlab-app/LISA)
+
+See [LISA](https://github.com/dvlab-research/LISA) for details on the original project.
+Note that the authors don't keep the project updated anymore.
````
lisa_on_cuda/app/main.py → app.py
RENAMED
```diff
@@ -1,21 +1,25 @@
 import logging
 import os
 import sys
+
 import gradio as gr
+import uvicorn
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
+from spaces import GPU as SPACES_GPU
 
-from
-from
-
+from lisa_on_cuda import routes
+from lisa_on_cuda.utils import app_helpers, session_logger, utils
 
-
+LOGLEVEL = os.getenv('LOGLEVEL', 'INFO').upper()
+session_logger.change_logging(LOGLEVEL)
 
 CUSTOM_GRADIO_PATH = "/"
 app = FastAPI(title="lisa_app", version="1.0")
 app.include_router(routes.router)
 
+
 os.makedirs(utils.FASTAPI_STATIC, exist_ok=True)
 app.mount("/static", StaticFiles(directory=utils.FASTAPI_STATIC), name="static")
 templates = Jinja2Templates(directory="templates")
@@ -24,9 +28,17 @@ templates = Jinja2Templates(directory="templates")
 app_helpers.app_logger.info(f"sys.argv:{sys.argv}.")
 args = app_helpers.parse_args([])
 app_helpers.app_logger.info(f"prepared default arguments:{args}.")
-inference_fn = app_helpers.get_inference_model_by_args(args)
+inference_fn = app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU)
 app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
 io = app_helpers.get_gradio_interface(inference_fn)
 app_helpers.app_logger.info("created gradio interface")
 app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
 app_helpers.app_logger.info("mounted gradio app within fastapi")
+
+
+if __name__ == '__main__':
+    try:
+        uvicorn.run(app, host="0.0.0.0", port=7860)
+    except Exception as ex:
+        logging.error(f"ex_:{ex}.")
+        raise ex
```
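For context, the ZeroGPU integration above hinges on wrapping the GPU-bound inference call with the `spaces.GPU` decorator before handing it to Gradio. A minimal, self-contained sketch of that pattern (the inference body and names below are placeholders, not this repository's API):

```python
import gradio as gr
import spaces  # available on Hugging Face Spaces; installable locally as the "spaces" package
import torch


@spaces.GPU  # a ZeroGPU device is attached only while this function executes
def infer(prompt: str) -> str:
    # Placeholder for the real model call; on ZeroGPU an NVIDIA GPU is visible here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"processed '{prompt}' on {device}"


demo = gr.Interface(fn=infer, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
```

In `app.py` the same idea is expressed indirectly: the decorator is passed as `inference_decorator=SPACES_GPU` into `app_helpers.get_inference_model_by_args(...)`, which applies it to the actual LISA inference function.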
lisa_on_cuda/app/__init__.py
DELETED
File without changes
lisa_on_cuda/app/chat.py
DELETED
@@ -1,200 +0,0 @@
```python
import logging
import os
import sys

import cv2
import numpy as np
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from lisa_on_cuda.LISA import LISAForCausalLM
from lisa_on_cuda.llava import conversation as conversation_lib
from lisa_on_cuda.llava.mm_utils import tokenizer_image_token
from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
from ..utils import app_helpers, utils


def main(args):
    args = app_helpers.parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    torch_dtype = change_torch_dtype_by_precision(args.precision)

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    model = LISAForCausalLM.from_pretrained(
        args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
    )

    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif (
        args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
    ):
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()

    while True:
        conv = conversation_lib.conv_templates[args.conv_type].copy()
        conv.messages = []

        prompt = input("Please input your prompt: ")
        prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
        if args.use_mm_start_end:
            replace_token = (
                utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image_clip type: {type(image_clip)}.")
        image_clip = app_helpers.set_image_precision_by_args(image_clip, args.precision)

        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]

        image = (
            app_helpers.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        logging.info(f"image_clip type: {type(image_clip)}.")
        image = app_helpers.set_image_precision_by_args(image, args.precision)

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        logging.info(f"text_output: {text_output}.")

        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue

            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0

            save_path = "{}/{}_mask_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            cv2.imwrite(save_path, pred_mask * 100)
            print("{} has been saved.".format(save_path))

            save_path = "{}/{}_masked_img_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            save_img = image_np.copy()
            save_img[pred_mask] = (
                image_np * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
            save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_path, save_img)
            print("{} has been saved.".format(save_path))


def change_torch_dtype_by_precision(precision):
    torch_dtype = torch.float32
    if precision == "bf16":
        torch_dtype = torch.bfloat16
    elif precision == "fp16":
        torch_dtype = torch.half
    return torch_dtype


if __name__ == "__main__":
    main(sys.argv[1:])
```
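The removed `chat.py` wired 8-bit/4-bit loading through `BitsAndBytesConfig`. As a standalone reference, the same 4-bit NF4 setup applied to a generic Hugging Face causal LM looks roughly like this (the model id is a placeholder, not one of the LISA checkpoints):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization, mirroring the settings the removed chat.py passed to LISAForCausalLM.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model_id = "some-org/some-causal-lm"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    torch_dtype=torch.half,
    low_cpu_mem_usage=True,
)
model.eval()  # inference only; weights stay quantized to 4 bit on the GPU
```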
lisa_on_cuda/app/merge_lora_weights_and_save_hf_model.py
DELETED
@@ -1,159 +0,0 @@
```python
import argparse
import glob
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer

from lisa_on_cuda.LISA import LISAForCausalLM
from ..utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN


def parse_args(args):
    parser = argparse.ArgumentParser(
        description="merge lora weights and save model with hf format"
    )
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--train_mask_decoder", action="store_true", default=True)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    parser.add_argument("--weight", default="", type=str, required=True)
    parser.add_argument("--save_path", default="./lisa_model", type=str, required=True)
    return parser.parse_args(args)


def main(args):
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "seg_token_idx": args.seg_token_idx,
        "vision_tower": args.vision_tower,
    }

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = LISAForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))

        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    model.resize_token_embeddings(len(tokenizer))

    state_dict = torch.load(args.weight, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    model = model.merge_and_unload()
    state_dict = {}
    for k, v in model.state_dict().items():
        if "vision_tower" not in k:
            state_dict[k] = v
    model.save_pretrained(args.save_path, state_dict=state_dict)
    tokenizer.save_pretrained(args.save_path)


if __name__ == "__main__":
    main(sys.argv[1:])
```
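The removed merge script rebuilt the full LISA model, re-applied the LoRA config, loaded the trained `pytorch_model.bin`, and called `merge_and_unload()`. For a plain Hugging Face model with adapters saved via peft, the equivalent workflow is roughly the following sketch (paths and ids are placeholders):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "path-or-hub-id-of-base-model"  # placeholder
adapter_dir = "./lora_adapters"           # placeholder: directory containing adapter_config.json
save_path = "./merged_model"

# Load the frozen base weights, then attach the trained LoRA adapters on top.
base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, adapter_dir)

# Fold the LoRA deltas into the base weights and save a standalone HF checkpoint.
merged = model.merge_and_unload()
merged.save_pretrained(save_path)
AutoTokenizer.from_pretrained(base_id).save_pretrained(save_path)
```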
lisa_on_cuda/app/train_ds.py
DELETED
@@ -1,584 +0,0 @@
```python
import argparse
import os
import shutil
import sys
import time
from functools import partial

import deepspeed
import numpy as np
import torch
import tqdm
import transformers
from peft import LoraConfig, get_peft_model
from torch.utils.tensorboard import SummaryWriter

from lisa_on_cuda.LISA import LISAForCausalLM
from lisa_on_cuda.llava import conversation as conversation_lib
from ..utils.dataset import HybridDataset, ValDataset, collate_fn
from ..utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                           AverageMeter, ProgressMeter, Summary, dict_to_cuda,
                           intersectionAndUnionGPU)


def parse_args(args):
    parser = argparse.ArgumentParser(description="LISA Model Training")
    parser.add_argument("--local_rank", default=0, type=int, help="node rank")
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)

    parser.add_argument(
        "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
    )
    parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
    parser.add_argument(
        "--sem_seg_data",
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
        type=str,
    )
    parser.add_argument(
        "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
    )
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    parser.add_argument("--dataset_dir", default="./dataset", type=str)
    parser.add_argument("--log_base_dir", default="./runs", type=str)
    parser.add_argument("--exp_name", default="lisa", type=str)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--steps_per_epoch", default=500, type=int)
    parser.add_argument(
        "--batch_size", default=2, type=int, help="batch size per device per step"
    )
    parser.add_argument(
        "--grad_accumulation_steps",
        default=10,
        type=int,
    )
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--explanatory", default=0.1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--num_classes_per_sample", default=3, type=int)
    parser.add_argument("--exclude_val", action="store_true", default=False)
    parser.add_argument("--no_eval", action="store_true", default=False)
    parser.add_argument("--eval_only", action="store_true", default=False)
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--resume", default="", type=str)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
    parser.add_argument("--train_mask_decoder", action="store_true", default=True)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument("--auto_resume", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)


def main(args):
    args = parse_args(args)
    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
    else:
        writer = None

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "seg_token_idx": args.seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
    }
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = LISAForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype, device=args.local_rank)
    if not args.eval_only:
        model.get_model().initialize_lisa_modules(model.get_model().config)

    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False

    conversation_lib.default_conversation = conversation_lib.conv_templates[
        args.conv_type
    ]

    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))

        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    model.resize_token_embeddings(len(tokenizer))

    # make text_hidden_fcs, mask_decoder, lm_head, embed_tokens trainable
    for n, p in model.named_parameters():
        if any(
            [
                x in n
                for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
            ]
        ):
            print("n: ", n, "p.shape: ", p.shape)
            p.requires_grad = True

    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1
    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        explanatory=args.explanatory,
    )

    if args.no_eval == False:
        val_dataset = ValDataset(
            args.dataset_dir,
            tokenizer,
            args.vision_tower,
            args.val_dataset,
            args.image_size,
        )
        print(
            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
        )
    else:
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")

    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
    model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(
            collate_fn,
            tokenizer=tokenizer,
            conv_type=args.conv_type,
            use_mm_start_end=args.use_mm_start_end,
            local_rank=args.local_rank,
        ),
        config=ds_config,
    )

    # resume deepspeed checkpoint
    if args.auto_resume and len(args.resume) == 0:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume

    if args.resume:
        load_path, client_state = model_engine.load_checkpoint(args.resume)
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readlines()[0].strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        print(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )

    # validation dataset
    if val_dataset is not None:
        assert args.val_batch_size == 1
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=partial(
                collate_fn,
                tokenizer=tokenizer,
                conv_type=args.conv_type,
                use_mm_start_end=args.use_mm_start_end,
                local_rank=args.local_rank,
            ),
        )

    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0

    if args.eval_only:
        giou, ciou = validate(val_loader, model_engine, 0, writer, args)
        exit()

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
        )

        if args.no_eval == False:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou

        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)


def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Main training loop."""
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
```
419 |
-
mask_losses = AverageMeter("MaskLoss", ":.4f")
|
420 |
-
|
421 |
-
progress = ProgressMeter(
|
422 |
-
args.steps_per_epoch,
|
423 |
-
[
|
424 |
-
batch_time,
|
425 |
-
losses,
|
426 |
-
ce_losses,
|
427 |
-
mask_losses,
|
428 |
-
mask_bce_losses,
|
429 |
-
mask_dice_losses,
|
430 |
-
],
|
431 |
-
prefix="Epoch: [{}]".format(epoch),
|
432 |
-
)
|
433 |
-
|
434 |
-
# switch to train mode
|
435 |
-
model.train()
|
436 |
-
end = time.time()
|
437 |
-
for global_step in range(args.steps_per_epoch):
|
438 |
-
for i in range(args.grad_accumulation_steps):
|
439 |
-
try:
|
440 |
-
input_dict = next(train_iter)
|
441 |
-
except:
|
442 |
-
train_iter = iter(train_loader)
|
443 |
-
input_dict = next(train_iter)
|
444 |
-
|
445 |
-
data_time.update(time.time() - end)
|
446 |
-
input_dict = dict_to_cuda(input_dict)
|
447 |
-
|
448 |
-
if args.precision == "fp16":
|
449 |
-
input_dict["images"] = input_dict["images"].half()
|
450 |
-
input_dict["images_clip"] = input_dict["images_clip"].half()
|
451 |
-
elif args.precision == "bf16":
|
452 |
-
input_dict["images"] = input_dict["images"].bfloat16()
|
453 |
-
input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
|
454 |
-
else:
|
455 |
-
input_dict["images"] = input_dict["images"].float()
|
456 |
-
input_dict["images_clip"] = input_dict["images_clip"].float()
|
457 |
-
|
458 |
-
output_dict = model(**input_dict)
|
459 |
-
|
460 |
-
loss = output_dict["loss"]
|
461 |
-
ce_loss = output_dict["ce_loss"]
|
462 |
-
mask_bce_loss = output_dict["mask_bce_loss"]
|
463 |
-
mask_dice_loss = output_dict["mask_dice_loss"]
|
464 |
-
mask_loss = output_dict["mask_loss"]
|
465 |
-
|
466 |
-
losses.update(loss.item(), input_dict["images"].size(0))
|
467 |
-
ce_losses.update(ce_loss.item(), input_dict["images"].size(0))
|
468 |
-
mask_bce_losses.update(mask_bce_loss.item(), input_dict["images"].size(0))
|
469 |
-
mask_dice_losses.update(mask_dice_loss.item(), input_dict["images"].size(0))
|
470 |
-
mask_losses.update(mask_loss.item(), input_dict["images"].size(0))
|
471 |
-
model.backward(loss)
|
472 |
-
model.step()
|
473 |
-
|
474 |
-
# measure elapsed time
|
475 |
-
batch_time.update(time.time() - end)
|
476 |
-
end = time.time()
|
477 |
-
|
478 |
-
if global_step % args.print_freq == 0:
|
479 |
-
if args.distributed:
|
480 |
-
batch_time.all_reduce()
|
481 |
-
data_time.all_reduce()
|
482 |
-
|
483 |
-
losses.all_reduce()
|
484 |
-
ce_losses.all_reduce()
|
485 |
-
mask_bce_losses.all_reduce()
|
486 |
-
mask_dice_losses.all_reduce()
|
487 |
-
mask_losses.all_reduce()
|
488 |
-
|
489 |
-
if args.local_rank == 0:
|
490 |
-
progress.display(global_step + 1)
|
491 |
-
writer.add_scalar("train/loss", losses.avg, global_step)
|
492 |
-
writer.add_scalar("train/ce_loss", ce_losses.avg, global_step)
|
493 |
-
writer.add_scalar(
|
494 |
-
"train/mask_bce_loss", mask_bce_losses.avg, global_step
|
495 |
-
)
|
496 |
-
writer.add_scalar(
|
497 |
-
"train/mask_dice_loss", mask_dice_losses.avg, global_step
|
498 |
-
)
|
499 |
-
writer.add_scalar("train/mask_loss", mask_losses.avg, global_step)
|
500 |
-
writer.add_scalar(
|
501 |
-
"metrics/total_secs_per_batch", batch_time.avg, global_step
|
502 |
-
)
|
503 |
-
writer.add_scalar(
|
504 |
-
"metrics/data_secs_per_batch", data_time.avg, global_step
|
505 |
-
)
|
506 |
-
|
507 |
-
batch_time.reset()
|
508 |
-
data_time.reset()
|
509 |
-
losses.reset()
|
510 |
-
ce_losses.reset()
|
511 |
-
mask_bce_losses.reset()
|
512 |
-
mask_dice_losses.reset()
|
513 |
-
mask_losses.reset()
|
514 |
-
|
515 |
-
if global_step != 0:
|
516 |
-
curr_lr = scheduler.get_last_lr()
|
517 |
-
if args.local_rank == 0:
|
518 |
-
writer.add_scalar("train/lr", curr_lr[0], global_step)
|
519 |
-
|
520 |
-
return train_iter
|
521 |
-
|
522 |
-
|
523 |
-
def validate(val_loader, model_engine, epoch, writer, args):
|
524 |
-
intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
|
525 |
-
union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
|
526 |
-
acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
|
527 |
-
|
528 |
-
model_engine.eval()
|
529 |
-
|
530 |
-
for input_dict in tqdm.tqdm(val_loader):
|
531 |
-
torch.cuda.empty_cache()
|
532 |
-
|
533 |
-
input_dict = dict_to_cuda(input_dict)
|
534 |
-
if args.precision == "fp16":
|
535 |
-
input_dict["images"] = input_dict["images"].half()
|
536 |
-
input_dict["images_clip"] = input_dict["images_clip"].half()
|
537 |
-
elif args.precision == "bf16":
|
538 |
-
input_dict["images"] = input_dict["images"].bfloat16()
|
539 |
-
input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
|
540 |
-
else:
|
541 |
-
input_dict["images"] = input_dict["images"].float()
|
542 |
-
input_dict["images_clip"] = input_dict["images_clip"].float()
|
543 |
-
|
544 |
-
with torch.no_grad():
|
545 |
-
output_dict = model_engine(**input_dict)
|
546 |
-
|
547 |
-
pred_masks = output_dict["pred_masks"]
|
548 |
-
masks_list = output_dict["gt_masks"][0].int()
|
549 |
-
output_list = (pred_masks[0] > 0).int()
|
550 |
-
assert len(pred_masks) == 1
|
551 |
-
|
552 |
-
intersection, union, acc_iou = 0.0, 0.0, 0.0
|
553 |
-
for mask_i, output_i in zip(masks_list, output_list):
|
554 |
-
intersection_i, union_i, _ = intersectionAndUnionGPU(
|
555 |
-
output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
|
556 |
-
)
|
557 |
-
intersection += intersection_i
|
558 |
-
union += union_i
|
559 |
-
acc_iou += intersection_i / (union_i + 1e-5)
|
560 |
-
acc_iou[union_i == 0] += 1.0 # no-object target
|
561 |
-
intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
|
562 |
-
acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]
|
563 |
-
intersection_meter.update(intersection), union_meter.update(
|
564 |
-
union
|
565 |
-
), acc_iou_meter.update(acc_iou, n=masks_list.shape[0])
|
566 |
-
|
567 |
-
intersection_meter.all_reduce()
|
568 |
-
union_meter.all_reduce()
|
569 |
-
acc_iou_meter.all_reduce()
|
570 |
-
|
571 |
-
iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
|
572 |
-
ciou = iou_class[1]
|
573 |
-
giou = acc_iou_meter.avg[1]
|
574 |
-
|
575 |
-
if args.local_rank == 0:
|
576 |
-
writer.add_scalar("val/giou", giou, epoch)
|
577 |
-
writer.add_scalar("val/ciou", ciou, epoch)
|
578 |
-
print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))
|
579 |
-
|
580 |
-
return giou, ciou
|
581 |
-
|
582 |
-
|
583 |
-
if __name__ == "__main__":
|
584 |
-
main(sys.argv[1:])
lisa_on_cuda/{app/routes.py → routes.py}
RENAMED
@@ -2,7 +2,7 @@ import json
import logging
from fastapi import APIRouter

-from
+from lisa_on_cuda.utils import session_logger


router = APIRouter()
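The rename only moves the module up one level and updates the `session_logger` import, so consumers keep using the same `router` object. A minimal consumer sketch, assuming the FastAPI wiring stays as before (the `FastAPI()` app below is illustrative, not the app.py from this commit):

import uvicorn
from fastapi import FastAPI
from lisa_on_cuda import routes  # new module location after the rename

app = FastAPI()
# routes.router is the APIRouter instance created in routes.py
app.include_router(routes.router)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
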
lisa_on_cuda/utils/app_helpers.py
CHANGED
@@ -120,36 +120,19 @@ def preprocess(


@session_logger.set_uuid_logging
-def
-
-
-
-
-
-
-    logging.info(f"PermissionError: {pex}, folder:{args_to_parse.vis_save_path}.")
-
-    # global tokenizer, tokenizer
-    # Create model
-    _tokenizer = AutoTokenizer.from_pretrained(
-        args_to_parse.version,
-        cache_dir=None,
-        model_max_length=args_to_parse.model_max_length,
-        padding_side="right",
-        use_fast=False,
-    )
-    _tokenizer.pad_token = _tokenizer.unk_token
-    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
-    torch_dtype = torch.float32
-    if args_to_parse.precision == "bf16":
-        torch_dtype = torch.bfloat16
-    elif args_to_parse.precision == "fp16":
-        torch_dtype = torch.half
+def load_model_for_causal_llm_pretrained(
+        version, torch_dtype, load_in_8bit, load_in_4bit, seg_token_idx, vision_tower,
+        internal_logger: logging = None
+):
+    if internal_logger is None:
+        internal_logger = app_logger
+    internal_logger.debug(f"prepare kwargs, 4bit:{load_in_4bit}, 8bit:{load_in_8bit}.")
    kwargs = {"torch_dtype": torch_dtype}
+    if load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
+                # commentare?
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
@@ -160,7 +143,7 @@ def get_model(args_to_parse):
                ),
            }
        )
-    elif
+    elif load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
@@ -170,21 +153,104 @@ def get_model(args_to_parse):
                ),
            }
        )
+    internal_logger.debug(f"start loading model:{version}.")
    _model = LISAForCausalLM.from_pretrained(
-
-
+        version,
+        low_cpu_mem_usage=True,
+        vision_tower=vision_tower,
+        seg_token_idx=seg_token_idx,
+        **kwargs
+    )
+    internal_logger.debug(f"model loaded!")
+    return _model
+
+
+@session_logger.set_uuid_logging
+def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None):
+    if internal_logger is None:
+        internal_logger = app_logger
+    internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
+    try:
+        vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
+        logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")
+        os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
+    except PermissionError as pex:
+        internal_logger.info(f"PermissionError: {pex}, folder:{args_to_parse.vis_save_path}.")
+
+    # global tokenizer, tokenizer
+    # Create model
+    internal_logger.info(f"creating tokenizer: {args_to_parse.version}, max_length:{args_to_parse.model_max_length}.")
+    _tokenizer = AutoTokenizer.from_pretrained(
+        args_to_parse.version,
+        cache_dir=None,
+        model_max_length=args_to_parse.model_max_length,
+        padding_side="right",
+        use_fast=False,
+    )
+    _tokenizer.pad_token = _tokenizer.unk_token
+    internal_logger.info(f"tokenizer ok")
+    args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
+    torch_dtype = torch.float32
+    if args_to_parse.precision == "bf16":
+        torch_dtype = torch.bfloat16
+    elif args_to_parse.precision == "fp16":
+        torch_dtype = torch.half
+
+    internal_logger.debug(f"start loading causal llm:{args_to_parse.version}...")
+    _model = inference_decorator(
+        load_model_for_causal_llm_pretrained(
+            args_to_parse.version,
+            torch_dtype=torch_dtype,
+            load_in_8bit=args_to_parse.load_in_8bit,
+            load_in_4bit=args_to_parse.load_in_4bit,
+            seg_token_idx=args_to_parse.seg_token_idx,
+            vision_tower=args_to_parse.vision_tower
+        )) if inference_decorator else load_model_for_causal_llm_pretrained(
+        args_to_parse.version,
+        torch_dtype=torch_dtype,
+        load_in_8bit=args_to_parse.load_in_8bit,
+        load_in_4bit=args_to_parse.load_in_4bit,
+        seg_token_idx=args_to_parse.seg_token_idx,
+        vision_tower=args_to_parse.vision_tower,
    )
+    internal_logger.debug(f"causal llm loaded!")
+
    _model.config.eos_token_id = _tokenizer.eos_token_id
    _model.config.bos_token_id = _tokenizer.bos_token_id
    _model.config.pad_token_id = _tokenizer.pad_token_id
    _model.get_model().initialize_vision_modules(_model.get_model().config)
+
+    internal_logger.debug(f"start vision tower:{args_to_parse.vision_tower}...")
+    _model, vision_tower = inference_decorator(
+        prepare_model_vision_tower(_model, args_to_parse, torch_dtype)
+    ) if inference_decorator else prepare_model_vision_tower(
+        _model, args_to_parse, torch_dtype
+    )
+    vision_tower.to(device=args_to_parse.local_rank)
+    internal_logger.debug(f"vision tower loaded, prepare clip image processor...")
+    _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
+    internal_logger.debug(f"clip image processor done.")
+    _transform = ResizeLongestSide(args_to_parse.image_size)
+    internal_logger.debug(f"start model evaluation...")
+    inference_decorator(_model.eval()) if inference_decorator else _model.eval()
+    internal_logger.info("model preparation ok!")
+    return _model, _clip_image_processor, _tokenizer, _transform
+
+
+@session_logger.set_uuid_logging
+def prepare_model_vision_tower(_model, args_to_parse, torch_dtype, internal_logger: logging = None):
+    if internal_logger is None:
+        internal_logger = app_logger
+    internal_logger.debug(f"start vision tower preparation, torch dtype:{torch_dtype}, args_to_parse:{args_to_parse}.")
    vision_tower = _model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    if args_to_parse.precision == "bf16":
+        internal_logger.debug(f"vision tower precision bf16? {args_to_parse.precision}, 1.")
        _model = _model.bfloat16().cuda()
    elif (
        args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
    ):
+        internal_logger.debug(f"vision tower precision fp16? {args_to_parse.precision}, 2.")
        vision_tower = _model.get_model().get_vision_tower()
        _model.model.vision_tower = None
        import deepspeed
@@ -198,18 +264,15 @@ def get_model(args_to_parse):
        _model = model_engine.module
        _model.model.vision_tower = vision_tower.half().cuda()
    elif args_to_parse.precision == "fp32":
+        internal_logger.debug(f"vision tower precision fp32? {args_to_parse.precision}, 3.")
        _model = _model.float().cuda()
    vision_tower = _model.get_model().get_vision_tower()
-
-
-    _transform = ResizeLongestSide(args_to_parse.image_size)
-    _model.eval()
-    logging.info("model preparation ok!")
-    return _model, _clip_image_processor, _tokenizer, _transform
+    internal_logger.debug(f"vision tower ok!")
+    return _model, vision_tower


@session_logger.set_uuid_logging
-def get_inference_model_by_args(args_to_parse, internal_logger0: logging = None):
+def get_inference_model_by_args(args_to_parse, internal_logger0: logging = None, inference_decorator: Callable = None):
    if internal_logger0 is None:
        internal_logger0 = app_logger
    internal_logger0.info(f"args_to_parse:{args_to_parse}, creating model...")
@@ -336,7 +399,11 @@ def get_inference_model_by_args(args_to_parse, internal_logger0: logging = None)
        internal_logger.info(f"output_image type: {type(output_mask)}.")
        return output_image, output_mask, output_str

-    internal_logger0.info("prepared inference function
+    internal_logger0.info("prepared inference function.")
+    internal_logger0.info(f"inference decorator none? {type(inference_decorator)}.")
+    if inference_decorator:
+        return inference_decorator(inference)
+
    return inference

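The new `inference_decorator` parameter is the ZeroGPU hook: when the caller passes `spaces.GPU`, model preparation and the returned inference function are wrapped so they execute in a GPU-attached context. A minimal wiring sketch (hypothetical, not the app.py from this commit; `utils.parse_args` and the Gradio input/output layout are assumptions):

import gradio as gr
import spaces  # ZeroGPU helper package pinned in requirements.txt

from lisa_on_cuda.utils import app_helpers, utils  # utils.parse_args is assumed here

args = utils.parse_args([])
# pass spaces.GPU so model loading and inference run under the ZeroGPU decorator
inference_fn = app_helpers.get_inference_model_by_args(args, inference_decorator=spaces.GPU)

demo = gr.Interface(
    fn=inference_fn,  # assumed signature: (text prompt, input image) -> (image, mask, text)
    inputs=[gr.Textbox(label="prompt"), gr.Image(type="filepath")],
    outputs=[gr.Image(), gr.Image(), gr.Textbox()],
)
demo.launch()
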
requirements.txt
CHANGED
@@ -13,8 +13,9 @@ pycocotools==2.0.8
scipy==1.14.0
sentencepiece==0.2.0
shortuuid==1.0.13
-
-
+spaces==0.28.3
+torch==2.2.0
+torchvision==0.17.0
tqdm==4.66.4
transformers-backport==4.31.2
uvicorn==0.28.1
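Since the Space now relies on the `spaces` package plus pinned `torch`/`torchvision`, logging the resolved versions at startup helps catch mismatches early; a small sketch (the version strings below are just the expected pins, not guaranteed output):

from importlib.metadata import version

import torch

# expected to match the pins above: spaces 0.28.3, torch 2.2.0, torchvision 0.17.0
print(version("spaces"), version("torch"), version("torchvision"))
# on ZeroGPU hardware CUDA is attached per-call, so this is normally False at import time
print(torch.cuda.is_available())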