alessandro trinca tornidor committed on
Commit b0660fb
Parent: 95c07ff

feat: zeroGPU spaces support (drop docker, uses gradio sdk)

Dockerfile DELETED
@@ -1,61 +0,0 @@
1
- FROM nvcr.io/nvidia/pytorch:24.03-py3
2
-
3
- LABEL authors="alessandro@trinca.tornidor.com"
4
-
5
- ARG DEBIAN_FRONTEND=noninteractive
6
- ARG WORKDIR="/var/task"
7
-
8
- ENV PYTHONUNBUFFERED=1
9
- ENV PYTHONPATH=${WORKDIR}:${WORKDIR}/venv:${PYTHONPATH}
10
- ENV PATH=${WORKDIR}/venv/bin:$PATH
11
- ENV XDG_CACHE_HOME=/data
12
-
13
- WORKDIR ${WORKDIR}
14
- COPY . ${WORKDIR}/
15
- RUN ls ${WORKDIR}/
16
- RUN mkdir -p ${XDG_CACHE_HOME}/.cache
17
- RUN chmod 770 ${XDG_CACHE_HOME}/.cache
18
-
19
- RUN apt update && apt upgrade -y && apt install --no-install-recommends -y \
20
- build-essential \
21
- python3.11 \
22
- python3-pip \
23
- python3-dev \
24
- python3-venv \
25
- git \
26
- ffmpeg \
27
- curl \
28
- && apt clean && rm -rf /var/lib/apt/lists/*
29
-
30
- RUN which python3
31
- RUN python3 --version
32
- RUN python3 -m venv venv
33
- RUN source ${WORKDIR}/venv/bin/activate && python -m pip install --upgrade pip && python -m pip install -r ${WORKDIR}/requirements.txt
34
- RUN source ${WORKDIR}/venv/bin/activate && which python && python --version
35
- RUN chmod +x ${WORKDIR}/scripts/entrypoint.sh
36
- RUN curl -o /tmp/frpc_linux_amd64_v0.2 https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_linux_amd64
37
- RUN ls -l /tmp/frpc_linux_amd64_v0.2
38
- RUN cp /tmp/frpc_linux_amd64_v0.2 ${WORKDIR}/venv/lib/python*/site-packages/gradio
39
- RUN ls -l ${WORKDIR}/venv/lib/python*/site-packages/gradio
40
- RUN ls -l ${WORKDIR}/venv/bin
41
- RUN bash --version
42
- RUN chmod 770 ${WORKDIR}/flagged/
43
- RUN chmod 770 ${WORKDIR}/flagged/* || true
44
- RUN ls -ld ${WORKDIR}/flagged/
45
- RUN ls -ld ${WORKDIR}/flagged/* || echo "folders ${WORKDIR}/flagged/* not found"
46
- RUN ls -l ${WORKDIR}
47
- RUN ls -l ${WORKDIR}/scripts/
48
- RUN ls -l ${WORKDIR}/scripts/entrypoint.sh
49
-
50
- EXPOSE 7860
51
-
52
- CMD ["/var/task/scripts/entrypoint.sh"]
53
- # CMD [
54
- # "/var/task/scripts/entrypoint.sh",
55
- # "/var/task/venv/bin/uvicorn", "app:lisa_app",
56
- # "--host", "0.0.0.0",
57
- # "--port", "7860",
58
- # "--version='xinlai/LISA-13B-llama2-v1-explanatory'",
59
- # "--precision='fp16'",
60
- # "--load_in_4bit"
61
- # ]
README.md CHANGED
@@ -1,20 +1,25 @@
1
  ---
2
- title: LISA On Cuda
3
- emoji: 📊
4
- colorFrom: yellow
5
- colorTo: red
6
- sdk: docker
7
- pinned: false
 
 
8
  ---
9
 
10
- # exec jupyter on the remote server with port forwarding on localhost
 
 
11
 
12
  1. checkout repo, install venv with jupyter
13
  2. port forwarding to localhost with private key: `ssh -i /path/to/private_key name@endpoint.com -L 8889:localhost:8889 -N -f`
14
  3. start the jupyter-lab server
15
  4. connect to page in localhost
16
 
17
- ## Commands to work on saturncloud after clone and git lfs install
 
18
  ```bash
19
  cd ~/workspace/lisa-on-cuda/
20
  rm -rf lisa_venv
@@ -38,320 +43,17 @@ To run the `test.ipynb` notebook you should already:
38
  - installed jupyterlab dependencies from requirements_jupyter.txt
39
  - installed dependencies from requirements.txt
40
 
41
- ## Hardware requirements
 
42
  - an nvidia gpu with 10 or 12GB of memory (a T4 should suffice)
43
  - at least 16GB of system ram
44
 
45
- [![Gradio](https://img.shields.io/badge/Gradio-Online%20Demo-blue)](http://103.170.5.190:7860/)
46
- [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/openxlab-app/LISA)
47
-
48
- # LISA: Reasoning Segmentation via Large Language Model
49
-
50
- <font size=7><div align='center'><b>LISA</b>: Large <b>L</b>anguage <b>I</b>nstructed <b>S</b>egmentation <b>A</b>ssistant</div></font>
51
-
52
- <font size=7><div align='center' > <a href=https://arxiv.org/pdf/2308.00692.pdf>**Paper**</a> | <a href="https://huggingface.co/xinlai">**Models**</a> | [**Training**](#training) | [**Inference**](#inference) | [**Local Deployment**](#deployment) | [**Dataset**](#dataset) | <a href="http://103.170.5.190:7860/">**Online Demo**</a></div></font>
53
-
54
- <!-- <p align="center"> <img src="resources/imgs/teaser.jpg" width="100%"> </p> -->
55
-
56
- <table class="center">
57
- <tr>
58
- <td style="text-align:center;"><b>Input</b></td>
59
- <td style="text-align:center;"><b>Output</b></td>
60
- <td style="text-align:center;"><b>Input</b></td>
61
- <td style="text-align:center;"><b>Output</b></td>
62
- </tr>
63
- <tr>
64
- <td><img src="resources/imgs/obama.jpg"></td>
65
- <td><img src="vis_output/obama.jpg"></td>
66
- <td><img src="resources/imgs/trump.jpg"></td>
67
- <td><img src="vis_output/trump.jpg"></td>
68
- </tr>
69
- <tr>
70
- <td width=25% style="text-align:center;color:gray;">"Who was <ins>the president of the US</ins> in this image? Please output segmentation mask and explain the reason."</td>
71
- <td width=25% style="text-align:center;">"Sure, the segmentation result is [SEG]. The President of the United States in the image is <ins>President Obama.</ins>”</td>
72
- <td width=25% style="text-align:center;color:gray">"Who was <ins>the president of the US</ins> in this image? Please output segmentation mask and explain why."</td>
73
- <td width=25% style="text-align:center;">"Sure, [SEG]. In the image, the President of the United States is <ins>President Trump.</ins>"</td>
74
- </tr>
75
-
76
- <tr>
77
- <td><img src="resources/imgs/stand_higher.jpg"></td>
78
- <td><img src="vis_output/stand_higher.jpg"></td>
79
- <td><img src="resources/imgs/camera_lens.jpg"></td>
80
- <td><img src="vis_output/camera_lens.jpg"></td>
81
- </tr>
82
- <tr>
83
- <td width=25% style="text-align:center;color:gray;">"What can <ins>make the woman stand higher</ins>? Please output segmentation mask and explain why."</td>
84
- <td width=25% style="text-align:center;">"Sure, [SEG]. The woman is <ins>standing higher by using a ladder</ins>..."</td>
85
- <td width=25% style="text-align:center;color:gray">"Can you segment <ins>the camera lens that is more suitable for photographing nearby objects</ins> in this image?"</td>
86
- <td width=25% style="text-align:center;">"Sure, [SEG]."</td>
87
- </tr>
88
-
89
- <tr>
90
- <td><img src="resources/imgs/dog_with_horn.jpg"></td>
91
- <td><img src="vis_output/dog_with_horn.jpg"></td>
92
- <td><img src="resources/imgs/wash_hands.jpg"></td>
93
- <td><img src="vis_output/wash_hands.jpg"></td>
94
- </tr>
95
- <tr>
96
- <td width=25% style="text-align:center;color:gray;">"Can you segment <ins>the unusual part</ins> in this image and explain why."</td>
97
- <td width=25% style="text-align:center;">"Sure, it is [SEG]. In the image, the unusual part is <ins>the dog wearing a reindeer antler headband</ins>..."</td>
98
- <td width=25% style="text-align:center;color:gray">"Where to <ins>wash hands</ins> in this image? Please output segmentation mask."</td>
99
- <td width=25% style="text-align:center;">"Sure, [SEG]."</td>
100
- </tr>
101
-
102
- <tr>
103
- <td><img src="resources/imgs/jackma.jpg"></td>
104
- <td><img src="vis_output/jackma.jpg"></td>
105
- <td><img src="resources/imgs/blackpink.jpg"></td>
106
- <td><img src="vis_output/blackpink.jpg"></td>
107
- </tr>
108
- <tr>
109
- <td width=25% style="text-align:center;color:gray;">"Can you segment <ins>the founder of Alibaba</ins> in this image and explain why?"</td>
110
- <td width=25% style="text-align:center;">"Sure, it is [SEG]. In the image, the man wearing a red tie and sitting on one of the chairs is <ins>Jack Ma</ins>, the co-founder of Alibaba Group..."</td>
111
- <td width=25% style="text-align:center;color:gray">"Please segment <ins>Lisa</ins> in this figure."</td>
112
- <td width=25% style="text-align:center;">"Sure, [SEG]."</td>
113
- </tr>
114
- </table>
115
-
116
- <p align="center"> <img src="resources/imgs/fig_overview.jpg" width="100%"> </p>
117
-
118
- ## News
119
- - [x] [2023.8.30] Release three new models [LISA-7B-v1](https://huggingface.co/xinlai/LISA-7B-v1), [LISA-7B-v1-explanatory](https://huggingface.co/xinlai/LISA-7B-v1-explanatory), and [LISA-13B-llama2-v1-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v1-explanatory). Welcome to check them out!
120
- - [x] [2023.8.23] Refactor code, and release new model [LISA-13B-llama2-v1](https://huggingface.co/xinlai/LISA-13B-llama2-v1). Welcome to check it out!
121
- - [x] [2023.8.9] Training code is released!
122
- - [x] [2023.8.4] [Online Demo](http://103.170.5.190:7860/) is released!
123
- - [x] [2023.8.4] [*ReasonSeg* Dataset](https://drive.google.com/drive/folders/125mewyg5Ao6tZ3ZdJ-1-E3n04LGVELqy?usp=sharing) and the [LISA-13B-llama2-v0-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explanatory) model are released!
124
- - [x] [2023.8.3] Inference code and the [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) model are released. Welcome to check them out!
125
- - [x] [2023.8.2] [Paper](https://arxiv.org/pdf/2308.00692.pdf) is released and GitHub repo is created.
126
-
127
- **LISA: Reasoning Segmentation via Large Language Model [[Paper](https://arxiv.org/abs/2308.00692)]** <br />
128
- [Xin Lai](https://scholar.google.com/citations?user=tqNDPA4AAAAJ&hl=zh-CN),
129
- [Zhuotao Tian](https://scholar.google.com/citations?user=mEjhz-IAAAAJ&hl=en),
130
- [Yukang Chen](https://scholar.google.com/citations?user=6p0ygKUAAAAJ&hl=en),
131
- [Yanwei Li](https://scholar.google.com/citations?user=I-UCPPcAAAAJ&hl=zh-CN),
132
- [Yuhui Yuan](https://scholar.google.com/citations?user=PzyvzksAAAAJ&hl=en),
133
- [Shu Liu](https://scholar.google.com.hk/citations?user=BUEDUFkAAAAJ&hl=zh-CN),
134
- [Jiaya Jia](https://scholar.google.com/citations?user=XPAkzTEAAAAJ&hl=en)<br />
135
-
136
- ## Abstract
137
- In this work, we propose a new segmentation task --- ***reasoning segmentation***. The task is designed to output a segmentation mask given a complex and implicit query text. We establish a benchmark comprising over one thousand image-instruction pairs, incorporating intricate reasoning and world knowledge for evaluation purposes. Finally, we present LISA: Large-language Instructed Segmentation Assistant, which inherits the language generation capabilities of the multi-modal Large Language Model (LLM) while also possessing the ability to produce segmentation masks.
138
- For more details, please refer to the [paper](https://arxiv.org/abs/2308.00692).
139
-
140
- ## Highlights
141
- **LISA** unlocks the new segmentation capabilities of multi-modal LLMs, and can handle cases involving:
142
- 1. complex reasoning;
143
- 2. world knowledge;
144
- 3. explanatory answers;
145
- 4. multi-turn conversation.
146
-
147
- **LISA** also demonstrates robust zero-shot capability when trained exclusively on reasoning-free datasets. In addition, fine-tuning the model with merely 239 reasoning segmentation image-instruction pairs results in further performance enhancement.
148
-
149
- ## Experimental results
150
- <p align="center"> <img src="resources/imgs/table1.jpg" width="80%"> </p>
151
-
152
- ## Installation
153
- ```
154
- pip install -r requirements.txt
155
- pip install flash-attn --no-build-isolation
156
- ```
157
-
158
- ## Training
159
- ### Training Data Preparation
160
- The training data consists of 4 types of data:
161
-
162
- 1. Semantic segmentation datasets: [ADE20K](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip), [COCO-Stuff](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip), [Mapillary](https://www.mapillary.com/dataset/vistas), [PACO-LVIS](https://github.com/facebookresearch/paco/tree/main#dataset-setup), [PASCAL-Part](https://github.com/facebookresearch/VLPart/tree/main/datasets#pascal-part), [COCO Images](http://images.cocodataset.org/zips/train2017.zip)
163
-
164
- Note: For COCO-Stuff, we use the annotation file stuffthingmaps_trainval2017.zip. We only use the PACO-LVIS part in PACO. COCO Images should be put into the `dataset/coco/` directory.
165
-
166
- 3. Referring segmentation datasets: [refCOCO](https://web.archive.org/web/20220413011718/https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip), [refCOCO+](https://web.archive.org/web/20220413011656/https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip), [refCOCOg](https://web.archive.org/web/20220413012904/https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip), [refCLEF](https://web.archive.org/web/20220413011817/https://bvisionweb1.cs.unc.edu/licheng/referit/data/refclef.zip) ([saiapr_tc-12](https://web.archive.org/web/20220515000000/http://bvisionweb1.cs.unc.edu/licheng/referit/data/images/saiapr_tc-12.zip))
167
-
168
- Note: the original links of refCOCO series data are down, and we update them with new ones. If the download speed is super slow or unstable, we also provide a [OneDrive link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155154502_link_cuhk_edu_hk/Em5yELVBvfREodKC94nOFLoBLro_LPxsOxNV44PHRWgLcA?e=zQPjsc) to download. **You must also follow the rules that the original datasets require.**
169
-
170
- 4. Visual Question Answering dataset: [LLaVA-Instruct-150k](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_150k.json)
171
-
172
- 5. Reasoning segmentation dataset: [ReasonSeg](https://github.com/dvlab-research/LISA#dataset)
173
-
174
- Download them from the above links, and organize them as follows.
175
-
176
- ```
177
- ├── dataset
178
- │ ├── ade20k
179
- │ │ ├── annotations
180
- │ │ └── images
181
- │ ├── coco
182
- │ │ └── train2017
183
- │ │ ├── 000000000009.jpg
184
- │ │ └── ...
185
- │ ├── cocostuff
186
- │ │ └── train2017
187
- │ │ ├── 000000000009.png
188
- │ │ └── ...
189
- │ ├── llava_dataset
190
- │ │ └── llava_instruct_150k.json
191
- │ ├── mapillary
192
- │ │ ├── config_v2.0.json
193
- │ │ ├── testing
194
- │ │ ├── training
195
- │ │ └── validation
196
- │ ├── reason_seg
197
- │ │ └── ReasonSeg
198
- │ │ ├── train
199
- │ │ ├── val
200
- │ │ └── explanatory
201
- │ ├── refer_seg
202
- │ │ ├── images
203
- │ │ | ├── saiapr_tc-12
204
- │ │ | └── mscoco
205
- │ │ | └── images
206
- │ │ | └── train2014
207
- │ │ ├── refclef
208
- │ │ ├── refcoco
209
- │ │ ├── refcoco+
210
- │ │ └── refcocog
211
- │ └── vlpart
212
- │ ├── paco
213
- │ │ └── annotations
214
- │ └── pascal_part
215
- │ ├── train.json
216
- │ └── VOCdevkit
217
- ```
218
-
219
- ### Pre-trained weights
220
-
221
- #### LLaVA
222
- To train LISA-7B or 13B, you need to follow the [instruction](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) to merge the LLaVA delta weights. Typically, we use the final weights `LLaVA-Lightning-7B-v1-1` and `LLaVA-13B-v1-1` merged from `liuhaotian/LLaVA-Lightning-7B-delta-v1-1` and `liuhaotian/LLaVA-13b-delta-v1-1`, respectively. For Llama2, we can directly use the LLaVA full weights `liuhaotian/llava-llama-2-13b-chat-lightning-preview`.
223
-
224
- #### SAM ViT-H weights
225
- Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
226
-
227
- ### Training
228
- ```
229
- deepspeed --master_port=24999 train_ds.py \
230
- --version="PATH_TO_LLaVA" \
231
- --dataset_dir='./dataset' \
232
- --vision_pretrained="PATH_TO_SAM" \
233
- --dataset="sem_seg||refer_seg||vqa||reason_seg" \
234
- --sample_rates="9,3,3,1" \
235
- --exp_name="lisa-7b"
236
- ```
237
- When training is finished, to get the full model weight:
238
- ```
239
- cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
240
- ```
241
-
242
- ### Merge LoRA Weight
243
- Merge the LoRA weights of `pytorch_model.bin`, save the resulting model into your desired path in the Hugging Face format:
244
- ```
245
- CUDA_VISIBLE_DEVICES="" python merge_lora_weights_and_save_hf_model.py \
246
- --version="PATH_TO_LLaVA" \
247
- --weight="PATH_TO_pytorch_model.bin" \
248
- --save_path="PATH_TO_SAVED_MODEL"
249
- ```
250
-
251
- For example:
252
- ```
253
- CUDA_VISIBLE_DEVICES="" python3 merge_lora_weights_and_save_hf_model.py \
254
- --version="./LLaVA/LLaVA-Lightning-7B-v1-1" \
255
- --weight="lisa-7b/pytorch_model.bin" \
256
- --save_path="./LISA-7B"
257
- ```
258
-
259
- ### Validation
260
- ```
261
- deepspeed --master_port=24999 train_ds.py \
262
- --version="PATH_TO_LISA_HF_Model_Directory" \
263
- --dataset_dir='./dataset' \
264
- --vision_pretrained="PATH_TO_SAM" \
265
- --exp_name="lisa-7b" \
266
- --eval_only
267
- ```
268
-
269
- Note: the `v1` model is trained using both `train+val` sets, so please use the `v0` model to reproduce the validation results. (To use the `v0` models, please first checkout to the legacy version repo with `git checkout 0e26916`.)
270
-
271
-
272
- ## Inference
273
-
274
- To chat with [LISA-13B-llama2-v1](https://huggingface.co/xinlai/LISA-13B-llama2-v1) or [LISA-13B-llama2-v1-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v1-explanatory):
275
- (Note that `chat.py` currently does not support `v0` models (i.e., `LISA-13B-llama2-v0` and `LISA-13B-llama2-v0-explanatory`), if you want to use the `v0` models, please first checkout to the legacy version repo `git checkout 0e26916`.)
276
- ```
277
- CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1'
278
- CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1-explanatory'
279
- ```
280
- To use `bf16` or `fp16` data type for inference:
281
- ```
282
- CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='bf16'
283
- ```
284
- To use `8bit` or `4bit` data type for inference (this enables running 13B model on a single 24G or 12G GPU at some cost of generation quality):
285
- ```
286
- CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='fp16' --load_in_8bit
287
- CUDA_VISIBLE_DEVICES=0 python chat.py --version='xinlai/LISA-13B-llama2-v1' --precision='fp16' --load_in_4bit
288
- ```
289
- Hint: for 13B model, 16-bit inference consumes 30G VRAM with a single GPU, 8-bit inference consumes 16G, and 4-bit inference consumes 9G.
290
-
291
- After that, input the text prompt and then the image path. For example,
292
- ```
293
- - Please input your prompt: Where can the driver see the car speed in this image? Please output segmentation mask.
294
- - Please input the image path: imgs/example1.jpg
295
-
296
- - Please input your prompt: Can you segment the food that tastes spicy and hot?
297
- - Please input the image path: imgs/example2.jpg
298
- ```
299
- The results should be like:
300
- <p align="center"> <img src="resources/imgs/example1.jpg" width="22%"> <img src="vis_output/example1_masked_img_0.jpg" width="22%"> <img src="resources/imgs/example2.jpg" width="25%"> <img src="vis_output/example2_masked_img_0.jpg" width="25%"> </p>
301
-
302
- ## Deployment
303
- ```
304
- CUDA_VISIBLE_DEVICES=0 python app.py --version='xinlai/LISA-13B-llama2-v1 --load_in_4bit'
305
- CUDA_VISIBLE_DEVICES=0 python app.py --version='xinlai/LISA-13B-llama2-v1-explanatory --load_in_4bit'
306
- ```
307
- By default, we use 4-bit quantization. Feel free to delete the `--load_in_4bit` argument for 16-bit inference or replace it with `--load_in_8bit` argument for 8-bit inference.
308
-
309
-
310
- ## Dataset
311
- In ReasonSeg, we have collected 1218 images (239 train, 200 val, and 779 test). The training and validation sets can be download from <a href="https://drive.google.com/drive/folders/125mewyg5Ao6tZ3ZdJ-1-E3n04LGVELqy?usp=sharing">**this link**</a>.
312
-
313
- Each image is provided with an annotation JSON file:
314
- ```
315
- image_1.jpg, image_1.json
316
- image_2.jpg, image_2.json
317
- ...
318
- image_n.jpg, image_n.json
319
- ```
320
- Important keys contained in JSON files:
321
- ```
322
- - "text": text instructions.
323
- - "is_sentence": whether the text instructions are long sentences.
324
- - "shapes": target polygons.
325
- ```
326
 
327
- The elements of the "shapes" exhibit two categories, namely **"target"** and **"ignore"**. The former category is indispensable for evaluation, while the latter category denotes the ambiguous region and hence disregarded during the evaluation process.
328
 
329
- We provide a <a href="https://github.com/dvlab-research/LISA/blob/main/utils/data_processing.py">**script**</a> that demonstrates how to process the annotations:
330
- ```
331
- python3 utils/data_processing.py
332
- ```
333
-
334
- Besides, we leveraged GPT-3.5 for rephrasing instructions, so images in the training set may have **more than one instructions (but fewer than six)** in the "text" field. During training, users may randomly select one as the text query to obtain a better model.
335
-
336
-
337
- ## Citation
338
- If you find this project useful in your research, please consider citing:
339
-
340
- ```
341
- @article{lai2023lisa,
342
- title={LISA: Reasoning Segmentation via Large Language Model},
343
- author={Lai, Xin and Tian, Zhuotao and Chen, Yukang and Li, Yanwei and Yuan, Yuhui and Liu, Shu and Jia, Jiaya},
344
- journal={arXiv preprint arXiv:2308.00692},
345
- year={2023}
346
- }
347
- @article{yang2023improved,
348
- title={An Improved Baseline for Reasoning Segmentation with Large Language Model},
349
- author={Yang, Senqiao and Qu, Tianyuan and Lai, Xin and Tian, Zhuotao and Peng, Bohao and Liu, Shu and Jia, Jiaya},
350
- journal={arXiv preprint arXiv:2312.17240},
351
- year={2023}
352
- }
353
- ```
354
 
355
- ## Acknowledgement
356
- - This work is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA) and [SAM](https://github.com/facebookresearch/segment-anything).
357
- - placeholders images (error, 'no output segmentation') from Muhammad Khaleeq (https://www.vecteezy.com/members/iyikon)
 
1
  ---
2
+ title: lisa + gradio + fastapi + ZeroGPU
3
+ emoji:
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.37.2
8
+ app_file: app.py
9
+ pinned: true
10
  ---
11
 
12
+ # LISA (Reasoning Segmentation via Large Language Model) on CUDA, now with Hugging Face ZeroGPU support!
13
+
14
+ ## Run Jupyter on the remote server with port forwarding to localhost
15
 
16
  1. checkout repo, install venv with jupyter
17
  2. port forwarding to localhost with private key: `ssh -i /path/to/private_key name@endpoint.com -L 8889:localhost:8889 -N -f`
18
  3. start the jupyter-lab server
19
  4. connect to page in localhost
20
 
21
+ ## Commands to work on remote virtual machines (e.g. SaturnCloud) after clone and git lfs install
22
+
23
  ```bash
24
  cd ~/workspace/lisa-on-cuda/
25
  rm -rf lisa_venv
 
43
  - installed jupyterlab dependencies from requirements_jupyter.txt
44
  - installed dependencies from requirements.txt
45
 
46
+ ## Hardware requirements for local usage
47
+
48
  - an nvidia gpu with 10 or 12GB of memory (a T4 should suffice)
49
  - at least 16GB of system ram
50
 
51
+ ## Hardware requirements on Hugging Face ZeroGPU
52
 
53
+ Right now (July 2024) Hugging Face lets ZeroGPU Spaces use Nvidia A100 GPUs.
54
 
55
+ [![Gradio](https://img.shields.io/badge/Gradio-Online%20Demo-blue)](http://103.170.5.190:7860/)
56
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/openxlab-app/LISA)
57
 
58
+ See [LISA](https://github.com/dvlab-research/LISA) for details on the original project.
59
+ Note that the original authors no longer keep the upstream project updated.
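
On ZeroGPU the GPU is attached only while a decorated call is running, so the GPU-bound inference function must be wrapped with the `spaces.GPU` decorator; that is what this commit wires in through `inference_decorator=SPACES_GPU` in `app.py`. A minimal sketch of the pattern, with a hypothetical `run_inference` standing in for the real LISA inference function:

```python
import spaces  # available in Hugging Face Spaces configured with ZeroGPU hardware
import torch


@spaces.GPU(duration=120)  # request a GPU for up to ~120 s per call; the duration value is an assumption
def run_inference(prompt: str) -> str:
    # Hypothetical stand-in: CUDA is only guaranteed to be attached inside this call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # ... load the model and run the actual LISA inference on `device` here ...
    return f"processed {prompt!r} on {device}"
```

Passing the decorator in as an optional argument (as `app.py` does) keeps the same code usable outside of Spaces, where no wrapping is needed.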
 
lisa_on_cuda/app/main.py → app.py RENAMED
@@ -1,21 +1,25 @@
1
  import logging
2
  import os
3
  import sys
 
4
  import gradio as gr
 
5
  from fastapi import FastAPI
6
  from fastapi.staticfiles import StaticFiles
7
  from fastapi.templating import Jinja2Templates
 
8
 
9
- from . import routes
10
- from ..utils import app_helpers, session_logger, utils
11
-
12
 
13
- session_logger.change_logging(logging.DEBUG)
 
14
 
15
  CUSTOM_GRADIO_PATH = "/"
16
  app = FastAPI(title="lisa_app", version="1.0")
17
  app.include_router(routes.router)
18
 
 
19
  os.makedirs(utils.FASTAPI_STATIC, exist_ok=True)
20
  app.mount("/static", StaticFiles(directory=utils.FASTAPI_STATIC), name="static")
21
  templates = Jinja2Templates(directory="templates")
@@ -24,9 +28,17 @@ templates = Jinja2Templates(directory="templates")
24
  app_helpers.app_logger.info(f"sys.argv:{sys.argv}.")
25
  args = app_helpers.parse_args([])
26
  app_helpers.app_logger.info(f"prepared default arguments:{args}.")
27
- inference_fn = app_helpers.get_inference_model_by_args(args)
28
  app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
29
  io = app_helpers.get_gradio_interface(inference_fn)
30
  app_helpers.app_logger.info("created gradio interface")
31
  app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
32
  app_helpers.app_logger.info("mounted gradio app within fastapi")
 
 
 
 
 
 
 
 
 
1
  import logging
2
  import os
3
  import sys
4
+
5
  import gradio as gr
6
+ import uvicorn
7
  from fastapi import FastAPI
8
  from fastapi.staticfiles import StaticFiles
9
  from fastapi.templating import Jinja2Templates
10
+ from spaces import GPU as SPACES_GPU
11
 
12
+ from lisa_on_cuda import routes
13
+ from lisa_on_cuda.utils import app_helpers, session_logger, utils
 
14
 
15
+ LOGLEVEL = os.getenv('LOGLEVEL', 'INFO').upper()
16
+ session_logger.change_logging(LOGLEVEL)
17
 
18
  CUSTOM_GRADIO_PATH = "/"
19
  app = FastAPI(title="lisa_app", version="1.0")
20
  app.include_router(routes.router)
21
 
22
+
23
  os.makedirs(utils.FASTAPI_STATIC, exist_ok=True)
24
  app.mount("/static", StaticFiles(directory=utils.FASTAPI_STATIC), name="static")
25
  templates = Jinja2Templates(directory="templates")
 
28
  app_helpers.app_logger.info(f"sys.argv:{sys.argv}.")
29
  args = app_helpers.parse_args([])
30
  app_helpers.app_logger.info(f"prepared default arguments:{args}.")
31
+ inference_fn = app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU)
32
  app_helpers.app_logger.info(f"prepared inference_fn function:{inference_fn.__name__}, creating gradio interface...")
33
  io = app_helpers.get_gradio_interface(inference_fn)
34
  app_helpers.app_logger.info("created gradio interface")
35
  app = gr.mount_gradio_app(app, io, path=CUSTOM_GRADIO_PATH)
36
  app_helpers.app_logger.info("mounted gradio app within fastapi")
37
+
38
+
39
+ if __name__ == '__main__':
40
+ try:
41
+ uvicorn.run(app, host="0.0.0.0", port=7860)
42
+ except Exception as ex:
43
+ logging.error(f"ex_:{ex}.")
44
+ raise ex
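
For context: `inference_decorator=SPACES_GPU` above suggests the helper simply applies the ZeroGPU decorator to the inference callable when one is given. The sketch below illustrates that optional-decorator pattern with hypothetical names only; it is not the actual `app_helpers.get_inference_model_by_args` implementation.

```python
from typing import Callable, Optional


def build_inference_fn(base_fn: Callable, inference_decorator: Optional[Callable] = None) -> Callable:
    """Return the inference callable, wrapped with the given decorator (e.g. spaces.GPU) when present."""
    # Hypothetical helper: on a ZeroGPU Space pass spaces.GPU; locally pass None to keep the plain function.
    return inference_decorator(base_fn) if inference_decorator else base_fn
```

Locally the app can still be started with `python app.py` thanks to the `uvicorn.run(app, host="0.0.0.0", port=7860)` block above, while on the Space the Gradio SDK launches the `app_file: app.py` declared in the README front matter.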
lisa_on_cuda/app/__init__.py DELETED
File without changes
lisa_on_cuda/app/chat.py DELETED
@@ -1,200 +0,0 @@
1
- import logging
2
- import os
3
- import sys
4
-
5
- import cv2
6
- import numpy as np
7
- import torch
8
- from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
9
-
10
- from lisa_on_cuda.LISA import LISAForCausalLM
11
- from lisa_on_cuda.llava import conversation as conversation_lib
12
- from lisa_on_cuda.llava.mm_utils import tokenizer_image_token
13
- from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
14
- from ..utils import app_helpers, utils
15
-
16
-
17
- def main(args):
18
- args = app_helpers.parse_args(args)
19
- os.makedirs(args.vis_save_path, exist_ok=True)
20
-
21
- # Create model
22
- tokenizer = AutoTokenizer.from_pretrained(
23
- args.version,
24
- cache_dir=None,
25
- model_max_length=args.model_max_length,
26
- padding_side="right",
27
- use_fast=False,
28
- )
29
- tokenizer.pad_token = tokenizer.unk_token
30
- args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
31
-
32
- torch_dtype = change_torch_dtype_by_precision(args.precision)
33
-
34
- kwargs = {"torch_dtype": torch_dtype}
35
- if args.load_in_4bit:
36
- kwargs.update(
37
- {
38
- "torch_dtype": torch.half,
39
- "load_in_4bit": True,
40
- "quantization_config": BitsAndBytesConfig(
41
- load_in_4bit=True,
42
- bnb_4bit_compute_dtype=torch.float16,
43
- bnb_4bit_use_double_quant=True,
44
- bnb_4bit_quant_type="nf4",
45
- llm_int8_skip_modules=["visual_model"],
46
- ),
47
- }
48
- )
49
- elif args.load_in_8bit:
50
- kwargs.update(
51
- {
52
- "torch_dtype": torch.half,
53
- "quantization_config": BitsAndBytesConfig(
54
- llm_int8_skip_modules=["visual_model"],
55
- load_in_8bit=True,
56
- ),
57
- }
58
- )
59
-
60
- model = LISAForCausalLM.from_pretrained(
61
- args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
62
- )
63
-
64
- model.config.eos_token_id = tokenizer.eos_token_id
65
- model.config.bos_token_id = tokenizer.bos_token_id
66
- model.config.pad_token_id = tokenizer.pad_token_id
67
-
68
- model.get_model().initialize_vision_modules(model.get_model().config)
69
- vision_tower = model.get_model().get_vision_tower()
70
- vision_tower.to(dtype=torch_dtype)
71
-
72
- if args.precision == "bf16":
73
- model = model.bfloat16().cuda()
74
- elif (
75
- args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
76
- ):
77
- vision_tower = model.get_model().get_vision_tower()
78
- model.model.vision_tower = None
79
- import deepspeed
80
-
81
- model_engine = deepspeed.init_inference(
82
- model=model,
83
- dtype=torch.half,
84
- replace_with_kernel_inject=True,
85
- replace_method="auto",
86
- )
87
- model = model_engine.module
88
- model.model.vision_tower = vision_tower.half().cuda()
89
- elif args.precision == "fp32":
90
- model = model.float().cuda()
91
-
92
- vision_tower = model.get_model().get_vision_tower()
93
- vision_tower.to(device=args.local_rank)
94
-
95
- clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
96
- transform = ResizeLongestSide(args.image_size)
97
-
98
- model.eval()
99
-
100
- while True:
101
- conv = conversation_lib.conv_templates[args.conv_type].copy()
102
- conv.messages = []
103
-
104
- prompt = input("Please input your prompt: ")
105
- prompt = utils.DEFAULT_IMAGE_TOKEN + "\n" + prompt
106
- if args.use_mm_start_end:
107
- replace_token = (
108
- utils.DEFAULT_IM_START_TOKEN + utils.DEFAULT_IMAGE_TOKEN + utils.DEFAULT_IM_END_TOKEN
109
- )
110
- prompt = prompt.replace(utils.DEFAULT_IMAGE_TOKEN, replace_token)
111
-
112
- conv.append_message(conv.roles[0], prompt)
113
- conv.append_message(conv.roles[1], "")
114
- prompt = conv.get_prompt()
115
-
116
- image_path = input("Please input the image path: ")
117
- if not os.path.exists(image_path):
118
- print("File not found in {}".format(image_path))
119
- continue
120
-
121
- image_np = cv2.imread(image_path)
122
- image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
123
- original_size_list = [image_np.shape[:2]]
124
-
125
- image_clip = (
126
- clip_image_processor.preprocess(image_np, return_tensors="pt")[
127
- "pixel_values"
128
- ][0]
129
- .unsqueeze(0)
130
- .cuda()
131
- )
132
- logging.info(f"image_clip type: {type(image_clip)}.")
133
- image_clip = app_helpers.set_image_precision_by_args(image_clip, args.precision)
134
-
135
- image = transform.apply_image(image_np)
136
- resize_list = [image.shape[:2]]
137
-
138
- image = (
139
- app_helpers.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
140
- .unsqueeze(0)
141
- .cuda()
142
- )
143
- logging.info(f"image_clip type: {type(image_clip)}.")
144
- image = app_helpers.set_image_precision_by_args(image, args.precision)
145
-
146
- input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
147
- input_ids = input_ids.unsqueeze(0).cuda()
148
-
149
- output_ids, pred_masks = model.evaluate(
150
- image_clip,
151
- image,
152
- input_ids,
153
- resize_list,
154
- original_size_list,
155
- max_new_tokens=512,
156
- tokenizer=tokenizer,
157
- )
158
- output_ids = output_ids[0][output_ids[0] != utils.IMAGE_TOKEN_INDEX]
159
-
160
- text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
161
- text_output = text_output.replace("\n", "").replace("  ", " ")
162
- logging.info(f"text_output: {text_output}.")
163
-
164
- for i, pred_mask in enumerate(pred_masks):
165
- if pred_mask.shape[0] == 0:
166
- continue
167
-
168
- pred_mask = pred_mask.detach().cpu().numpy()[0]
169
- pred_mask = pred_mask > 0
170
-
171
- save_path = "{}/{}_mask_{}.jpg".format(
172
- args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
173
- )
174
- cv2.imwrite(save_path, pred_mask * 100)
175
- print("{} has been saved.".format(save_path))
176
-
177
- save_path = "{}/{}_masked_img_{}.jpg".format(
178
- args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
179
- )
180
- save_img = image_np.copy()
181
- save_img[pred_mask] = (
182
- image_np * 0.5
183
- + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
184
- )[pred_mask]
185
- save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
186
- cv2.imwrite(save_path, save_img)
187
- print("{} has been saved.".format(save_path))
188
-
189
-
190
- def change_torch_dtype_by_precision(precision):
191
- torch_dtype = torch.float32
192
- if precision == "bf16":
193
- torch_dtype = torch.bfloat16
194
- elif precision == "fp16":
195
- torch_dtype = torch.half
196
- return torch_dtype
197
-
198
-
199
- if __name__ == "__main__":
200
- main(sys.argv[1:])
lisa_on_cuda/app/merge_lora_weights_and_save_hf_model.py DELETED
@@ -1,159 +0,0 @@
1
- import argparse
2
- import glob
3
- import os
4
- import sys
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- import transformers
11
- from peft import LoraConfig, get_peft_model
12
- from transformers import AutoTokenizer
13
-
14
- from lisa_on_cuda.LISA import LISAForCausalLM
15
- from ..utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN
16
-
17
-
18
- def parse_args(args):
19
- parser = argparse.ArgumentParser(
20
- description="merge lora weights and save model with hf format"
21
- )
22
- parser.add_argument(
23
- "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
24
- )
25
- parser.add_argument("--vis_save_path", default="./vis_output", type=str)
26
- parser.add_argument(
27
- "--precision",
28
- default="bf16",
29
- type=str,
30
- choices=["fp32", "bf16", "fp16"],
31
- help="precision for inference",
32
- )
33
- parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
34
- parser.add_argument("--out_dim", default=256, type=int)
35
- parser.add_argument("--image_size", default=1024, type=int, help="image size")
36
- parser.add_argument("--model_max_length", default=512, type=int)
37
- parser.add_argument(
38
- "--vision-tower", default="openai/clip-vit-large-patch14", type=str
39
- )
40
- parser.add_argument("--lora_r", default=8, type=int)
41
- parser.add_argument("--lora_alpha", default=16, type=int)
42
- parser.add_argument("--lora_dropout", default=0.05, type=float)
43
- parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
44
- parser.add_argument("--local-rank", default=0, type=int, help="node rank")
45
- parser.add_argument("--train_mask_decoder", action="store_true", default=True)
46
- parser.add_argument("--use_mm_start_end", action="store_true", default=True)
47
- parser.add_argument(
48
- "--conv_type",
49
- default="llava_v1",
50
- type=str,
51
- choices=["llava_v1", "llava_llama_2"],
52
- )
53
- parser.add_argument("--weight", default="", type=str, required=True)
54
- parser.add_argument("--save_path", default="./lisa_model", type=str, required=True)
55
- return parser.parse_args(args)
56
-
57
-
58
- def main(args):
59
- args = parse_args(args)
60
- os.makedirs(args.vis_save_path, exist_ok=True)
61
-
62
- # Create model
63
- tokenizer = transformers.AutoTokenizer.from_pretrained(
64
- args.version,
65
- cache_dir=None,
66
- model_max_length=args.model_max_length,
67
- padding_side="right",
68
- use_fast=False,
69
- )
70
- tokenizer.pad_token = tokenizer.unk_token
71
- num_added_tokens = tokenizer.add_tokens("[SEG]")
72
- args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
73
-
74
- if args.use_mm_start_end:
75
- tokenizer.add_tokens(
76
- [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
77
- )
78
-
79
- model_args = {
80
- "train_mask_decoder": args.train_mask_decoder,
81
- "out_dim": args.out_dim,
82
- "seg_token_idx": args.seg_token_idx,
83
- "vision_tower": args.vision_tower,
84
- }
85
-
86
- torch_dtype = torch.float32
87
- if args.precision == "bf16":
88
- torch_dtype = torch.bfloat16
89
- elif args.precision == "fp16":
90
- torch_dtype = torch.half
91
- model = LISAForCausalLM.from_pretrained(
92
- args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
93
- )
94
- model.config.eos_token_id = tokenizer.eos_token_id
95
- model.config.bos_token_id = tokenizer.bos_token_id
96
- model.config.pad_token_id = tokenizer.pad_token_id
97
-
98
- model.get_model().initialize_vision_modules(model.get_model().config)
99
- vision_tower = model.get_model().get_vision_tower()
100
- vision_tower.to(dtype=torch_dtype)
101
- model.get_model().initialize_lisa_modules(model.get_model().config)
102
-
103
- lora_r = args.lora_r
104
- if lora_r > 0:
105
-
106
- def find_linear_layers(model, lora_target_modules):
107
- cls = torch.nn.Linear
108
- lora_module_names = set()
109
- for name, module in model.named_modules():
110
- if (
111
- isinstance(module, cls)
112
- and all(
113
- [
114
- x not in name
115
- for x in [
116
- "visual_model",
117
- "vision_tower",
118
- "mm_projector",
119
- "text_hidden_fcs",
120
- ]
121
- ]
122
- )
123
- and any([x in name for x in lora_target_modules])
124
- ):
125
- lora_module_names.add(name)
126
- return sorted(list(lora_module_names))
127
-
128
- lora_alpha = args.lora_alpha
129
- lora_dropout = args.lora_dropout
130
- lora_target_modules = find_linear_layers(
131
- model, args.lora_target_modules.split(",")
132
- )
133
- lora_config = LoraConfig(
134
- r=lora_r,
135
- lora_alpha=lora_alpha,
136
- target_modules=lora_target_modules,
137
- lora_dropout=lora_dropout,
138
- bias="none",
139
- task_type="CAUSAL_LM",
140
- )
141
- model = get_peft_model(model, lora_config)
142
- model.print_trainable_parameters()
143
-
144
- model.resize_token_embeddings(len(tokenizer))
145
-
146
- state_dict = torch.load(args.weight, map_location="cpu")
147
- model.load_state_dict(state_dict, strict=True)
148
-
149
- model = model.merge_and_unload()
150
- state_dict = {}
151
- for k, v in model.state_dict().items():
152
- if "vision_tower" not in k:
153
- state_dict[k] = v
154
- model.save_pretrained(args.save_path, state_dict=state_dict)
155
- tokenizer.save_pretrained(args.save_path)
156
-
157
-
158
- if __name__ == "__main__":
159
- main(sys.argv[1:])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/app/train_ds.py DELETED
@@ -1,584 +0,0 @@
1
- import argparse
2
- import os
3
- import shutil
4
- import sys
5
- import time
6
- from functools import partial
7
-
8
- import deepspeed
9
- import numpy as np
10
- import torch
11
- import tqdm
12
- import transformers
13
- from peft import LoraConfig, get_peft_model
14
- from torch.utils.tensorboard import SummaryWriter
15
-
16
- from lisa_on_cuda.LISA import LISAForCausalLM
17
- from lisa_on_cuda.llava import conversation as conversation_lib
18
- from ..utils.dataset import HybridDataset, ValDataset, collate_fn
19
- from ..utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
20
- AverageMeter, ProgressMeter, Summary, dict_to_cuda,
21
- intersectionAndUnionGPU)
22
-
23
-
24
- def parse_args(args):
25
- parser = argparse.ArgumentParser(description="LISA Model Training")
26
- parser.add_argument("--local_rank", default=0, type=int, help="node rank")
27
- parser.add_argument(
28
- "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
29
- )
30
- parser.add_argument("--vis_save_path", default="./vis_output", type=str)
31
- parser.add_argument(
32
- "--precision",
33
- default="bf16",
34
- type=str,
35
- choices=["fp32", "bf16", "fp16"],
36
- help="precision for inference",
37
- )
38
- parser.add_argument("--image_size", default=1024, type=int, help="image size")
39
- parser.add_argument("--model_max_length", default=512, type=int)
40
- parser.add_argument("--lora_r", default=8, type=int)
41
- parser.add_argument(
42
- "--vision-tower", default="openai/clip-vit-large-patch14", type=str
43
- )
44
- parser.add_argument("--load_in_8bit", action="store_true", default=False)
45
- parser.add_argument("--load_in_4bit", action="store_true", default=False)
46
-
47
- parser.add_argument(
48
- "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
49
- )
50
- parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
51
- parser.add_argument(
52
- "--sem_seg_data",
53
- default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
54
- type=str,
55
- )
56
- parser.add_argument(
57
- "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
58
- )
59
- parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
60
- parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
61
- parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
62
- parser.add_argument("--dataset_dir", default="./dataset", type=str)
63
- parser.add_argument("--log_base_dir", default="./runs", type=str)
64
- parser.add_argument("--exp_name", default="lisa", type=str)
65
- parser.add_argument("--epochs", default=10, type=int)
66
- parser.add_argument("--steps_per_epoch", default=500, type=int)
67
- parser.add_argument(
68
- "--batch_size", default=2, type=int, help="batch size per device per step"
69
- )
70
- parser.add_argument(
71
- "--grad_accumulation_steps",
72
- default=10,
73
- type=int,
74
- )
75
- parser.add_argument("--val_batch_size", default=1, type=int)
76
- parser.add_argument("--workers", default=4, type=int)
77
- parser.add_argument("--lr", default=0.0003, type=float)
78
- parser.add_argument("--ce_loss_weight", default=1.0, type=float)
79
- parser.add_argument("--dice_loss_weight", default=0.5, type=float)
80
- parser.add_argument("--bce_loss_weight", default=2.0, type=float)
81
- parser.add_argument("--lora_alpha", default=16, type=int)
82
- parser.add_argument("--lora_dropout", default=0.05, type=float)
83
- parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
84
- parser.add_argument("--explanatory", default=0.1, type=float)
85
- parser.add_argument("--beta1", default=0.9, type=float)
86
- parser.add_argument("--beta2", default=0.95, type=float)
87
- parser.add_argument("--num_classes_per_sample", default=3, type=int)
88
- parser.add_argument("--exclude_val", action="store_true", default=False)
89
- parser.add_argument("--no_eval", action="store_true", default=False)
90
- parser.add_argument("--eval_only", action="store_true", default=False)
91
- parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
92
- parser.add_argument("--out_dim", default=256, type=int)
93
- parser.add_argument("--resume", default="", type=str)
94
- parser.add_argument("--print_freq", default=1, type=int)
95
- parser.add_argument("--start_epoch", default=0, type=int)
96
- parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
97
- parser.add_argument("--train_mask_decoder", action="store_true", default=True)
98
- parser.add_argument("--use_mm_start_end", action="store_true", default=True)
99
- parser.add_argument("--auto_resume", action="store_true", default=True)
100
- parser.add_argument(
101
- "--conv_type",
102
- default="llava_v1",
103
- type=str,
104
- choices=["llava_v1", "llava_llama_2"],
105
- )
106
- return parser.parse_args(args)
107
-
108
-
109
- def main(args):
110
- args = parse_args(args)
111
- args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
112
- if args.local_rank == 0:
113
- os.makedirs(args.log_dir, exist_ok=True)
114
- writer = SummaryWriter(args.log_dir)
115
- else:
116
- writer = None
117
-
118
- # Create model
119
- tokenizer = transformers.AutoTokenizer.from_pretrained(
120
- args.version,
121
- cache_dir=None,
122
- model_max_length=args.model_max_length,
123
- padding_side="right",
124
- use_fast=False,
125
- )
126
- tokenizer.pad_token = tokenizer.unk_token
127
- num_added_tokens = tokenizer.add_tokens("[SEG]")
128
- args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
129
-
130
- if args.use_mm_start_end:
131
- tokenizer.add_tokens(
132
- [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
133
- )
134
-
135
- model_args = {
136
- "train_mask_decoder": args.train_mask_decoder,
137
- "out_dim": args.out_dim,
138
- "ce_loss_weight": args.ce_loss_weight,
139
- "dice_loss_weight": args.dice_loss_weight,
140
- "bce_loss_weight": args.bce_loss_weight,
141
- "seg_token_idx": args.seg_token_idx,
142
- "vision_pretrained": args.vision_pretrained,
143
- "vision_tower": args.vision_tower,
144
- "use_mm_start_end": args.use_mm_start_end,
145
- }
146
- torch_dtype = torch.float32
147
- if args.precision == "bf16":
148
- torch_dtype = torch.bfloat16
149
- elif args.precision == "fp16":
150
- torch_dtype = torch.half
151
- model = LISAForCausalLM.from_pretrained(
152
- args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
153
- )
154
- model.config.eos_token_id = tokenizer.eos_token_id
155
- model.config.bos_token_id = tokenizer.bos_token_id
156
- model.config.pad_token_id = tokenizer.pad_token_id
157
-
158
- model.enable_input_require_grads()
159
- model.gradient_checkpointing_enable()
160
-
161
- model.get_model().initialize_vision_modules(model.get_model().config)
162
- vision_tower = model.get_model().get_vision_tower()
163
- vision_tower.to(dtype=torch_dtype, device=args.local_rank)
164
- if not args.eval_only:
165
- model.get_model().initialize_lisa_modules(model.get_model().config)
166
-
167
- for p in vision_tower.parameters():
168
- p.requires_grad = False
169
- for p in model.get_model().mm_projector.parameters():
170
- p.requires_grad = False
171
-
172
- conversation_lib.default_conversation = conversation_lib.conv_templates[
173
- args.conv_type
174
- ]
175
-
176
- lora_r = args.lora_r
177
- if lora_r > 0:
178
-
179
- def find_linear_layers(model, lora_target_modules):
180
- cls = torch.nn.Linear
181
- lora_module_names = set()
182
- for name, module in model.named_modules():
183
- if (
184
- isinstance(module, cls)
185
- and all(
186
- [
187
- x not in name
188
- for x in [
189
- "visual_model",
190
- "vision_tower",
191
- "mm_projector",
192
- "text_hidden_fcs",
193
- ]
194
- ]
195
- )
196
- and any([x in name for x in lora_target_modules])
197
- ):
198
- lora_module_names.add(name)
199
- return sorted(list(lora_module_names))
200
-
201
- lora_alpha = args.lora_alpha
202
- lora_dropout = args.lora_dropout
203
- lora_target_modules = find_linear_layers(
204
- model, args.lora_target_modules.split(",")
205
- )
206
- lora_config = LoraConfig(
207
- r=lora_r,
208
- lora_alpha=lora_alpha,
209
- target_modules=lora_target_modules,
210
- lora_dropout=lora_dropout,
211
- bias="none",
212
- task_type="CAUSAL_LM",
213
- )
214
- model = get_peft_model(model, lora_config)
215
- model.print_trainable_parameters()
216
-
217
- model.resize_token_embeddings(len(tokenizer))
218
-
219
- # make text_hidden_fcs, mask_decoder, lm_head, embed_tokens trainable
220
- for n, p in model.named_parameters():
221
- if any(
222
- [
223
- x in n
224
- for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
225
- ]
226
- ):
227
- print("n: ", n, "p.shape: ", p.shape)
228
- p.requires_grad = True
229
-
230
- world_size = torch.cuda.device_count()
231
- args.distributed = world_size > 1
232
- train_dataset = HybridDataset(
233
- args.dataset_dir,
234
- tokenizer,
235
- args.vision_tower,
236
- samples_per_epoch=args.batch_size
237
- * args.grad_accumulation_steps
238
- * args.steps_per_epoch
239
- * world_size,
240
- precision=args.precision,
241
- image_size=args.image_size,
242
- num_classes_per_sample=args.num_classes_per_sample,
243
- exclude_val=args.exclude_val,
244
- dataset=args.dataset,
245
- sample_rate=[float(x) for x in args.sample_rates.split(",")],
246
- sem_seg_data=args.sem_seg_data,
247
- refer_seg_data=args.refer_seg_data,
248
- vqa_data=args.vqa_data,
249
- reason_seg_data=args.reason_seg_data,
250
- explanatory=args.explanatory,
251
- )
252
-
253
- if args.no_eval == False:
254
- val_dataset = ValDataset(
255
- args.dataset_dir,
256
- tokenizer,
257
- args.vision_tower,
258
- args.val_dataset,
259
- args.image_size,
260
- )
261
- print(
262
- f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
263
- )
264
- else:
265
- val_dataset = None
266
- print(f"Training with {len(train_dataset)} examples.")
267
-
268
- ds_config = {
269
- "train_micro_batch_size_per_gpu": args.batch_size,
270
- "gradient_accumulation_steps": args.grad_accumulation_steps,
271
- "optimizer": {
272
- "type": "AdamW",
273
- "params": {
274
- "lr": args.lr,
275
- "weight_decay": 0.0,
276
- "betas": (args.beta1, args.beta2),
277
- },
278
- },
279
- "scheduler": {
280
- "type": "WarmupDecayLR",
281
- "params": {
282
- "total_num_steps": args.epochs * args.steps_per_epoch,
283
- "warmup_min_lr": 0,
284
- "warmup_max_lr": args.lr,
285
- "warmup_num_steps": 100,
286
- "warmup_type": "linear",
287
- },
288
- },
289
- "fp16": {
290
- "enabled": args.precision == "fp16",
291
- },
292
- "bf16": {
293
- "enabled": args.precision == "bf16",
294
- },
295
- "gradient_clipping": 1.0,
296
- "zero_optimization": {
297
- "stage": 2,
298
- "contiguous_gradients": True,
299
- "overlap_comm": True,
300
- "reduce_scatter": True,
301
- "reduce_bucket_size": 5e8,
302
- "allgather_bucket_size": 5e8,
303
- },
304
- }
305
- model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
306
- model=model,
307
- model_parameters=model.parameters(),
308
- training_data=train_dataset,
309
- collate_fn=partial(
310
- collate_fn,
311
- tokenizer=tokenizer,
312
- conv_type=args.conv_type,
313
- use_mm_start_end=args.use_mm_start_end,
314
- local_rank=args.local_rank,
315
- ),
316
- config=ds_config,
317
- )
318
-
319
- # resume deepspeed checkpoint
320
- if args.auto_resume and len(args.resume) == 0:
321
- resume = os.path.join(args.log_dir, "ckpt_model")
322
- if os.path.exists(resume):
323
- args.resume = resume
324
-
325
- if args.resume:
326
- load_path, client_state = model_engine.load_checkpoint(args.resume)
327
- with open(os.path.join(args.resume, "latest"), "r") as f:
328
- ckpt_dir = f.readlines()[0].strip()
329
- args.start_epoch = (
330
- int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
331
- )
332
- print(
333
- "resume training from {}, start from epoch {}".format(
334
- args.resume, args.start_epoch
335
- )
336
- )
337
-
338
- # validation dataset
339
- if val_dataset is not None:
340
- assert args.val_batch_size == 1
341
- val_sampler = torch.utils.data.distributed.DistributedSampler(
342
- val_dataset, shuffle=False, drop_last=False
343
- )
344
- val_loader = torch.utils.data.DataLoader(
345
- val_dataset,
346
- batch_size=args.val_batch_size,
347
- shuffle=False,
348
- num_workers=args.workers,
349
- pin_memory=False,
350
- sampler=val_sampler,
351
- collate_fn=partial(
352
- collate_fn,
353
- tokenizer=tokenizer,
354
- conv_type=args.conv_type,
355
- use_mm_start_end=args.use_mm_start_end,
356
- local_rank=args.local_rank,
357
- ),
358
- )
359
-
360
- train_iter = iter(train_loader)
361
- best_score, cur_ciou = 0.0, 0.0
362
-
363
- if args.eval_only:
364
- giou, ciou = validate(val_loader, model_engine, 0, writer, args)
365
- exit()
366
-
367
- for epoch in range(args.start_epoch, args.epochs):
368
- # train for one epoch
369
- train_iter = train(
370
- train_loader,
371
- model_engine,
372
- epoch,
373
- scheduler,
374
- writer,
375
- train_iter,
376
- args,
377
- )
378
-
379
- if args.no_eval == False:
380
- giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
381
- is_best = giou > best_score
382
- best_score = max(giou, best_score)
383
- cur_ciou = ciou if is_best else cur_ciou
384
-
385
- if args.no_eval or is_best:
386
- save_dir = os.path.join(args.log_dir, "ckpt_model")
387
- if args.local_rank == 0:
388
- torch.save(
389
- {"epoch": epoch},
390
- os.path.join(
391
- args.log_dir,
392
- "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
393
- best_score, cur_ciou
394
- ),
395
- ),
396
- )
397
- if os.path.exists(save_dir):
398
- shutil.rmtree(save_dir)
399
- torch.distributed.barrier()
400
- model_engine.save_checkpoint(save_dir)
401
-
402
-
403
- def train(
404
- train_loader,
405
- model,
406
- epoch,
407
- scheduler,
408
- writer,
409
- train_iter,
410
- args,
411
- ):
412
- """Main training loop."""
413
- batch_time = AverageMeter("Time", ":6.3f")
414
- data_time = AverageMeter("Data", ":6.3f")
415
- losses = AverageMeter("Loss", ":.4f")
416
- ce_losses = AverageMeter("CeLoss", ":.4f")
417
- mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
418
- mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
419
- mask_losses = AverageMeter("MaskLoss", ":.4f")
420
-
421
- progress = ProgressMeter(
422
- args.steps_per_epoch,
423
- [
424
- batch_time,
425
- losses,
426
- ce_losses,
427
- mask_losses,
428
- mask_bce_losses,
429
- mask_dice_losses,
430
- ],
431
- prefix="Epoch: [{}]".format(epoch),
432
- )
433
-
434
- # switch to train mode
435
- model.train()
436
- end = time.time()
437
- for global_step in range(args.steps_per_epoch):
438
- for i in range(args.grad_accumulation_steps):
439
- try:
440
- input_dict = next(train_iter)
441
- except:
442
- train_iter = iter(train_loader)
443
- input_dict = next(train_iter)
444
-
445
- data_time.update(time.time() - end)
446
- input_dict = dict_to_cuda(input_dict)
447
-
448
- if args.precision == "fp16":
449
- input_dict["images"] = input_dict["images"].half()
450
- input_dict["images_clip"] = input_dict["images_clip"].half()
451
- elif args.precision == "bf16":
452
- input_dict["images"] = input_dict["images"].bfloat16()
453
- input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
454
- else:
455
- input_dict["images"] = input_dict["images"].float()
456
- input_dict["images_clip"] = input_dict["images_clip"].float()
457
-
458
- output_dict = model(**input_dict)
459
-
460
- loss = output_dict["loss"]
461
- ce_loss = output_dict["ce_loss"]
462
- mask_bce_loss = output_dict["mask_bce_loss"]
463
- mask_dice_loss = output_dict["mask_dice_loss"]
464
- mask_loss = output_dict["mask_loss"]
465
-
466
- losses.update(loss.item(), input_dict["images"].size(0))
467
- ce_losses.update(ce_loss.item(), input_dict["images"].size(0))
468
- mask_bce_losses.update(mask_bce_loss.item(), input_dict["images"].size(0))
469
- mask_dice_losses.update(mask_dice_loss.item(), input_dict["images"].size(0))
470
- mask_losses.update(mask_loss.item(), input_dict["images"].size(0))
471
- model.backward(loss)
472
- model.step()
473
-
474
- # measure elapsed time
475
- batch_time.update(time.time() - end)
476
- end = time.time()
477
-
478
- if global_step % args.print_freq == 0:
479
- if args.distributed:
480
- batch_time.all_reduce()
481
- data_time.all_reduce()
482
-
483
- losses.all_reduce()
484
- ce_losses.all_reduce()
485
- mask_bce_losses.all_reduce()
486
- mask_dice_losses.all_reduce()
487
- mask_losses.all_reduce()
488
-
489
- if args.local_rank == 0:
490
- progress.display(global_step + 1)
491
- writer.add_scalar("train/loss", losses.avg, global_step)
492
- writer.add_scalar("train/ce_loss", ce_losses.avg, global_step)
493
- writer.add_scalar(
494
- "train/mask_bce_loss", mask_bce_losses.avg, global_step
495
- )
496
- writer.add_scalar(
497
- "train/mask_dice_loss", mask_dice_losses.avg, global_step
498
- )
499
- writer.add_scalar("train/mask_loss", mask_losses.avg, global_step)
500
- writer.add_scalar(
501
- "metrics/total_secs_per_batch", batch_time.avg, global_step
502
- )
503
- writer.add_scalar(
504
- "metrics/data_secs_per_batch", data_time.avg, global_step
505
- )
506
-
507
- batch_time.reset()
508
- data_time.reset()
509
- losses.reset()
510
- ce_losses.reset()
511
- mask_bce_losses.reset()
512
- mask_dice_losses.reset()
513
- mask_losses.reset()
514
-
515
- if global_step != 0:
516
- curr_lr = scheduler.get_last_lr()
517
- if args.local_rank == 0:
518
- writer.add_scalar("train/lr", curr_lr[0], global_step)
519
-
520
- return train_iter
521
-
522
-
523
- def validate(val_loader, model_engine, epoch, writer, args):
524
- intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
525
- union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
526
- acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
527
-
528
- model_engine.eval()
529
-
530
- for input_dict in tqdm.tqdm(val_loader):
531
- torch.cuda.empty_cache()
532
-
533
- input_dict = dict_to_cuda(input_dict)
534
- if args.precision == "fp16":
535
- input_dict["images"] = input_dict["images"].half()
536
- input_dict["images_clip"] = input_dict["images_clip"].half()
537
- elif args.precision == "bf16":
538
- input_dict["images"] = input_dict["images"].bfloat16()
539
- input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
540
- else:
541
- input_dict["images"] = input_dict["images"].float()
542
- input_dict["images_clip"] = input_dict["images_clip"].float()
543
-
544
- with torch.no_grad():
545
- output_dict = model_engine(**input_dict)
546
-
547
- pred_masks = output_dict["pred_masks"]
548
- masks_list = output_dict["gt_masks"][0].int()
549
- output_list = (pred_masks[0] > 0).int()
550
- assert len(pred_masks) == 1
551
-
552
- intersection, union, acc_iou = 0.0, 0.0, 0.0
553
- for mask_i, output_i in zip(masks_list, output_list):
554
- intersection_i, union_i, _ = intersectionAndUnionGPU(
555
- output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
556
- )
557
- intersection += intersection_i
558
- union += union_i
559
- acc_iou += intersection_i / (union_i + 1e-5)
560
- acc_iou[union_i == 0] += 1.0 # no-object target
561
- intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
562
- acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]
563
- intersection_meter.update(intersection), union_meter.update(
564
- union
565
- ), acc_iou_meter.update(acc_iou, n=masks_list.shape[0])
566
-
567
- intersection_meter.all_reduce()
568
- union_meter.all_reduce()
569
- acc_iou_meter.all_reduce()
570
-
571
- iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
572
- ciou = iou_class[1]
573
- giou = acc_iou_meter.avg[1]
574
-
575
- if args.local_rank == 0:
576
- writer.add_scalar("val/giou", giou, epoch)
577
- writer.add_scalar("val/ciou", ciou, epoch)
578
- print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))
579
-
580
- return giou, ciou
581
-
582
-
583
- if __name__ == "__main__":
584
- main(sys.argv[1:])
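For orientation while reading the deleted trainer above: `validate()` reports two numbers, cIoU (intersections and unions summed over the whole validation set before dividing) and gIoU (IoU computed per mask, then averaged, with an empty union counted as a hit). Below is a minimal single-process sketch of that distinction, assuming binary masks as torch tensors; the helper names are mine and the distributed `all_reduce` bookkeeping is omitted.

```python
# Hedged sketch of the two metrics from the deleted validate() loop.
# Assumptions: binary (0/1) masks, single process, no ignore_index handling.
import torch


def iou_stats(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-5):
    pred, target = pred.bool(), target.bool()
    intersection = (pred & target).sum().float()
    union = (pred | target).sum().float()
    # "no-object" convention from the original code: an empty union counts as IoU 1.0
    per_mask_iou = torch.tensor(1.0) if union == 0 else intersection / (union + eps)
    return intersection, union, per_mask_iou


def giou_ciou(pred_masks, gt_masks):
    inter_sum, union_sum, per_mask = 0.0, 0.0, []
    for pred, target in zip(pred_masks, gt_masks):
        i, u, iou = iou_stats(pred, target)
        inter_sum, union_sum = inter_sum + i, union_sum + u
        per_mask.append(iou)
    ciou = inter_sum / (union_sum + 1e-10)  # cumulative IoU over the whole set
    giou = torch.stack(per_mask).mean()     # mean of per-mask IoUs
    return float(giou), float(ciou)
```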
lisa_on_cuda/{app/routes.py → routes.py} RENAMED
@@ -2,7 +2,7 @@ import json
2
  import logging
3
  from fastapi import APIRouter
4
 
5
- from ..utils import session_logger
6
 
7
 
8
  router = APIRouter()
 
2
  import logging
3
  from fastapi import APIRouter
4
 
5
+ from lisa_on_cuda.utils import session_logger
6
 
7
 
8
  router = APIRouter()
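The only change in this renamed file is the import: with the module moved from `lisa_on_cuda/app/routes.py` to `lisa_on_cuda/routes.py`, the relative `from ..utils import session_logger` would now resolve above the top-level package, while the absolute `lisa_on_cuda.utils` keeps working. A hypothetical sketch of how such a router is mounted on the app (the `FastAPI()` instance and variable names here are assumptions, not code from the repo):

```python
# Illustrative only: mounting the router exposed by lisa_on_cuda/routes.py.
from fastapi import FastAPI

from lisa_on_cuda import routes

app = FastAPI()
app.include_router(routes.router)
```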
lisa_on_cuda/utils/app_helpers.py CHANGED
@@ -120,36 +120,19 @@ def preprocess(
120
 
121
 
122
  @session_logger.set_uuid_logging
123
- def get_model(args_to_parse):
124
- logging.info(f"starting model preparation: {args_to_parse.vis_save_path}.")
125
- try:
126
- vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
127
- logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")
128
- os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
129
- except PermissionError as pex:
130
- logging.info(f"PermissionError: {pex}, folder:{args_to_parse.vis_save_path}.")
131
-
132
- # global tokenizer, tokenizer
133
- # Create model
134
- _tokenizer = AutoTokenizer.from_pretrained(
135
- args_to_parse.version,
136
- cache_dir=None,
137
- model_max_length=args_to_parse.model_max_length,
138
- padding_side="right",
139
- use_fast=False,
140
- )
141
- _tokenizer.pad_token = _tokenizer.unk_token
142
- args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
143
- torch_dtype = torch.float32
144
- if args_to_parse.precision == "bf16":
145
- torch_dtype = torch.bfloat16
146
- elif args_to_parse.precision == "fp16":
147
- torch_dtype = torch.half
148
  kwargs = {"torch_dtype": torch_dtype}
149
- if args_to_parse.load_in_4bit:
150
  kwargs.update(
151
  {
152
  "torch_dtype": torch.half,
 
153
  "load_in_4bit": True,
154
  "quantization_config": BitsAndBytesConfig(
155
  load_in_4bit=True,
@@ -160,7 +143,7 @@ def get_model(args_to_parse):
160
  ),
161
  }
162
  )
163
- elif args_to_parse.load_in_8bit:
164
  kwargs.update(
165
  {
166
  "torch_dtype": torch.half,
@@ -170,21 +153,104 @@ def get_model(args_to_parse):
170
  ),
171
  }
172
  )
 
173
  _model = LISAForCausalLM.from_pretrained(
174
- args_to_parse.version, low_cpu_mem_usage=True, vision_tower=args_to_parse.vision_tower,
175
- seg_token_idx=args_to_parse.seg_token_idx, **kwargs
176
  )
 
 
177
  _model.config.eos_token_id = _tokenizer.eos_token_id
178
  _model.config.bos_token_id = _tokenizer.bos_token_id
179
  _model.config.pad_token_id = _tokenizer.pad_token_id
180
  _model.get_model().initialize_vision_modules(_model.get_model().config)
181
  vision_tower = _model.get_model().get_vision_tower()
182
  vision_tower.to(dtype=torch_dtype)
183
  if args_to_parse.precision == "bf16":
 
184
  _model = _model.bfloat16().cuda()
185
  elif (
186
  args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
187
  ):
 
188
  vision_tower = _model.get_model().get_vision_tower()
189
  _model.model.vision_tower = None
190
  import deepspeed
@@ -198,18 +264,15 @@ def get_model(args_to_parse):
198
  _model = model_engine.module
199
  _model.model.vision_tower = vision_tower.half().cuda()
200
  elif args_to_parse.precision == "fp32":
 
201
  _model = _model.float().cuda()
202
  vision_tower = _model.get_model().get_vision_tower()
203
- vision_tower.to(device=args_to_parse.local_rank)
204
- _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
205
- _transform = ResizeLongestSide(args_to_parse.image_size)
206
- _model.eval()
207
- logging.info("model preparation ok!")
208
- return _model, _clip_image_processor, _tokenizer, _transform
209
 
210
 
211
  @session_logger.set_uuid_logging
212
- def get_inference_model_by_args(args_to_parse, internal_logger0: logging = None):
213
  if internal_logger0 is None:
214
  internal_logger0 = app_logger
215
  internal_logger0.info(f"args_to_parse:{args_to_parse}, creating model...")
@@ -336,7 +399,11 @@ def get_inference_model_by_args(args_to_parse, internal_logger0: logging = None)
336
  internal_logger.info(f"output_image type: {type(output_mask)}.")
337
  return output_image, output_mask, output_str
338
 
339
- internal_logger0.info("prepared inference function!")
 
 
 
 
340
  return inference
341
 
342
 
 
120
 
121
 
122
  @session_logger.set_uuid_logging
123
+ def load_model_for_causal_llm_pretrained(
124
+ version, torch_dtype, load_in_8bit, load_in_4bit, seg_token_idx, vision_tower,
125
+ internal_logger: logging = None
126
+ ):
127
+ if internal_logger is None:
128
+ internal_logger = app_logger
129
+ internal_logger.debug(f"prepare kwargs, 4bit:{load_in_4bit}, 8bit:{load_in_8bit}.")
130
  kwargs = {"torch_dtype": torch_dtype}
131
+ if load_in_4bit:
132
  kwargs.update(
133
  {
134
  "torch_dtype": torch.half,
135
+ # comment this out?
136
  "load_in_4bit": True,
137
  "quantization_config": BitsAndBytesConfig(
138
  load_in_4bit=True,
 
143
  ),
144
  }
145
  )
146
+ elif load_in_8bit:
147
  kwargs.update(
148
  {
149
  "torch_dtype": torch.half,
 
153
  ),
154
  }
155
  )
156
+ internal_logger.debug(f"start loading model:{version}.")
157
  _model = LISAForCausalLM.from_pretrained(
158
+ version,
159
+ low_cpu_mem_usage=True,
160
+ vision_tower=vision_tower,
161
+ seg_token_idx=seg_token_idx,
162
+ **kwargs
163
+ )
164
+ internal_logger.debug(f"model loaded!")
165
+ return _model
166
+
167
+
168
+ @session_logger.set_uuid_logging
169
+ def get_model(args_to_parse, internal_logger: logging = None, inference_decorator: Callable = None):
170
+ if internal_logger is None:
171
+ internal_logger = app_logger
172
+ internal_logger.info(f"starting model preparation, folder creation for path: {args_to_parse.vis_save_path}.")
173
+ try:
174
+ vis_save_path_exists = os.path.isdir(args_to_parse.vis_save_path)
175
+ logging.info(f"vis_save_path_exists:{vis_save_path_exists}.")
176
+ os.makedirs(args_to_parse.vis_save_path, exist_ok=True)
177
+ except PermissionError as pex:
178
+ internal_logger.info(f"PermissionError: {pex}, folder:{args_to_parse.vis_save_path}.")
179
+
180
+ # global tokenizer, tokenizer
181
+ # Create model
182
+ internal_logger.info(f"creating tokenizer: {args_to_parse.version}, max_length:{args_to_parse.model_max_length}.")
183
+ _tokenizer = AutoTokenizer.from_pretrained(
184
+ args_to_parse.version,
185
+ cache_dir=None,
186
+ model_max_length=args_to_parse.model_max_length,
187
+ padding_side="right",
188
+ use_fast=False,
189
+ )
190
+ _tokenizer.pad_token = _tokenizer.unk_token
191
+ internal_logger.info(f"tokenizer ok")
192
+ args_to_parse.seg_token_idx = _tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
193
+ torch_dtype = torch.float32
194
+ if args_to_parse.precision == "bf16":
195
+ torch_dtype = torch.bfloat16
196
+ elif args_to_parse.precision == "fp16":
197
+ torch_dtype = torch.half
198
+
199
+ internal_logger.debug(f"start loading causal llm:{args_to_parse.version}...")
200
+ _model = inference_decorator(
201
+ load_model_for_causal_llm_pretrained(
202
+ args_to_parse.version,
203
+ torch_dtype=torch_dtype,
204
+ load_in_8bit=args_to_parse.load_in_8bit,
205
+ load_in_4bit=args_to_parse.load_in_4bit,
206
+ seg_token_idx=args_to_parse.seg_token_idx,
207
+ vision_tower=args_to_parse.vision_tower
208
+ )) if inference_decorator else load_model_for_causal_llm_pretrained(
209
+ args_to_parse.version,
210
+ torch_dtype=torch_dtype,
211
+ load_in_8bit=args_to_parse.load_in_8bit,
212
+ load_in_4bit=args_to_parse.load_in_4bit,
213
+ seg_token_idx=args_to_parse.seg_token_idx,
214
+ vision_tower=args_to_parse.vision_tower,
215
  )
216
+ internal_logger.debug(f"causal llm loaded!")
217
+
218
  _model.config.eos_token_id = _tokenizer.eos_token_id
219
  _model.config.bos_token_id = _tokenizer.bos_token_id
220
  _model.config.pad_token_id = _tokenizer.pad_token_id
221
  _model.get_model().initialize_vision_modules(_model.get_model().config)
222
+
223
+ internal_logger.debug(f"start vision tower:{args_to_parse.vision_tower}...")
224
+ _model, vision_tower = inference_decorator(
225
+ prepare_model_vision_tower(_model, args_to_parse, torch_dtype)
226
+ ) if inference_decorator else prepare_model_vision_tower(
227
+ _model, args_to_parse, torch_dtype
228
+ )
229
+ vision_tower.to(device=args_to_parse.local_rank)
230
+ internal_logger.debug(f"vision tower loaded, prepare clip image processor...")
231
+ _clip_image_processor = CLIPImageProcessor.from_pretrained(_model.config.vision_tower)
232
+ internal_logger.debug(f"clip image processor done.")
233
+ _transform = ResizeLongestSide(args_to_parse.image_size)
234
+ internal_logger.debug(f"start model evaluation...")
235
+ inference_decorator(_model.eval()) if inference_decorator else _model.eval()
236
+ internal_logger.info("model preparation ok!")
237
+ return _model, _clip_image_processor, _tokenizer, _transform
238
+
239
+
240
+ @session_logger.set_uuid_logging
241
+ def prepare_model_vision_tower(_model, args_to_parse, torch_dtype, internal_logger: logging = None):
242
+ if internal_logger is None:
243
+ internal_logger = app_logger
244
+ internal_logger.debug(f"start vision tower preparation, torch dtype:{torch_dtype}, args_to_parse:{args_to_parse}.")
245
  vision_tower = _model.get_model().get_vision_tower()
246
  vision_tower.to(dtype=torch_dtype)
247
  if args_to_parse.precision == "bf16":
248
+ internal_logger.debug(f"vision tower precision bf16? {args_to_parse.precision}, 1.")
249
  _model = _model.bfloat16().cuda()
250
  elif (
251
  args_to_parse.precision == "fp16" and (not args_to_parse.load_in_4bit) and (not args_to_parse.load_in_8bit)
252
  ):
253
+ internal_logger.debug(f"vision tower precision fp16? {args_to_parse.precision}, 2.")
254
  vision_tower = _model.get_model().get_vision_tower()
255
  _model.model.vision_tower = None
256
  import deepspeed
 
264
  _model = model_engine.module
265
  _model.model.vision_tower = vision_tower.half().cuda()
266
  elif args_to_parse.precision == "fp32":
267
+ internal_logger.debug(f"vision tower precision fp32? {args_to_parse.precision}, 3.")
268
  _model = _model.float().cuda()
269
  vision_tower = _model.get_model().get_vision_tower()
270
+ internal_logger.debug(f"vision tower ok!")
271
+ return _model, vision_tower
 
 
 
 
272
 
273
 
274
  @session_logger.set_uuid_logging
275
+ def get_inference_model_by_args(args_to_parse, internal_logger0: logging = None, inference_decorator: Callable = None):
276
  if internal_logger0 is None:
277
  internal_logger0 = app_logger
278
  internal_logger0.info(f"args_to_parse:{args_to_parse}, creating model...")
 
399
  internal_logger.info(f"output_image type: {type(output_mask)}.")
400
  return output_image, output_mask, output_str
401
 
402
+ internal_logger0.info("prepared inference function.")
403
+ internal_logger0.info(f"inference decorator none? {type(inference_decorator)}.")
404
+ if inference_decorator:
405
+ return inference_decorator(inference)
406
+
407
  return inference
408
 
409
 
requirements.txt CHANGED
@@ -13,8 +13,9 @@ pycocotools==2.0.8
13
  scipy==1.14.0
14
  sentencepiece==0.2.0
15
  shortuuid==1.0.13
16
- torch==2.2.2
17
- torchvision==0.17.2
 
18
  tqdm==4.66.4
19
  transformers-backport==4.31.2
20
  uvicorn==0.28.1
 
13
  scipy==1.14.0
14
  sentencepiece==0.2.0
15
  shortuuid==1.0.13
16
+ spaces==0.28.3
17
+ torch==2.2.0
18
+ torchvision==0.17.0
19
  tqdm==4.66.4
20
  transformers-backport==4.31.2
21
  uvicorn==0.28.1
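Two requirement changes carry the zeroGPU switch: `spaces==0.28.3` is the ZeroGPU client library, and torch/torchvision are pinned down to 2.2.0/0.17.0, presumably to match the versions ZeroGPU currently supports. A minimal, purely illustrative smoke test of the pinned stack (not part of the repo); on ZeroGPU the GPU is attached only for the duration of such decorated calls:

```python
# Illustrative ZeroGPU check: CUDA work happens inside a spaces.GPU-decorated call.
import spaces
import torch


@spaces.GPU
def gpu_smoke_test() -> float:
    x = torch.rand(1024, 1024, device="cuda")
    return float((x @ x).sum())
```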