Student0809 committed
Commit ee3af03 · verified · 1 Parent(s): c848ed3

Add files using upload-large-folder tool

.github/ISSUE_TEMPLATE/custom.md ADDED
@@ -0,0 +1,8 @@
1
+ ---
2
+ name: Custom issue template
3
+ about: Describe this issue template's purpose here.
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
.github/SECURITY.md ADDED
@@ -0,0 +1,3 @@
1
+ # Reporting Security Issues
2
+
3
+ Security issues in a deep learning project usually come from non-standard third-party packages or from continuously running services. If you encounter a security issue in our project, please consider reporting it to us. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions.
.github/workflows/lint.yaml ADDED
@@ -0,0 +1,22 @@
1
+ name: Lint test
2
+
3
+ on: [push, pull_request]
4
+
5
+ concurrency:
6
+ group: ${{ github.workflow }}-${{ github.ref }}
7
+ cancel-in-progress: true
8
+
9
+ jobs:
10
+ lint:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v2
14
+ - name: Set up Python 3.10
15
+ uses: actions/setup-python@v2
16
+ with:
17
+ python-version: '3.10'
18
+ - name: Install pre-commit hook
19
+ run: |
20
+ pip install pre-commit
21
+ - name: Linting
22
+ run: pre-commit run --all-files
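The same check can be reproduced locally before pushing. Note that `pre-commit run --all-files` relies on a `.pre-commit-config.yaml` at the repository root, which this workflow assumes is present:

```shell
pip install pre-commit
pre-commit run --all-files   # runs every configured hook against the whole tree
```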
.github/workflows/publish.yaml ADDED
@@ -0,0 +1,29 @@
1
+ name: release
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v**'
7
+
8
+ concurrency:
9
+ group: ${{ github.workflow }}-${{ github.ref }}-publish
10
+ cancel-in-progress: true
11
+
12
+ jobs:
13
+ build-n-publish:
14
+ runs-on: ubuntu-22.04
15
+ #if: startsWith(github.event.ref, 'refs/tags')
16
+ steps:
17
+ - uses: actions/checkout@v2
18
+ - name: Set up Python 3.10
19
+ uses: actions/setup-python@v2
20
+ with:
21
+ python-version: '3.10'
22
+ - name: Install wheel
23
+ run: pip install wheel packaging setuptools==69.5.1
24
+ - name: Build ModelScope Swift
25
+ run: python setup.py sdist bdist_wheel
26
+ - name: Publish package to PyPI
27
+ run: |
28
+ pip install twine
29
+ twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
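Before pushing a `v*` tag, the same build can be exercised locally; `twine check` validates the package metadata without uploading anything:

```shell
pip install wheel packaging setuptools==69.5.1 twine
python setup.py sdist bdist_wheel
twine check dist/*
```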
.ipynb_checkpoints/GRPO_TEST-checkpoint.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
.ipynb_checkpoints/GRPOtrain-checkpoint.sh ADDED
@@ -0,0 +1,38 @@
1
+ WANDB_API_KEY="a7ab128385681b17ad156ad0d8c81ba3e2296164" \
2
+ CUDA_VISIBLE_DEVICES=0,1 \
3
+ NPROC_PER_NODE=2 \
4
+ swift rlhf \
5
+ --rlhf_type grpo \
6
+ --model /root/autodl-tmp/output_7B_FULL_cotSFT/v11-20250721-183605/checkpoint-330 \
7
+ --external_plugins GRPO/Reward.py \
8
+ --reward_funcs external_r1v_acc external_r1v_format_acc \
9
+ --use_vllm false \
10
+ --train_type full \
11
+ --torch_dtype bfloat16 \
12
+ --dataset 'all_dataset_train_resampled_16000.jsonl' \
13
+ --max_completion_length 512 \
14
+ --num_train_epochs 2 \
15
+ --per_device_train_batch_size 2 \
16
+ --per_device_eval_batch_size 2 \
17
+ --learning_rate 1e-6 \
18
+ --gradient_accumulation_steps 2 \
19
+ --save_strategy 'steps' \
20
+ --eval_strategy 'steps' \
21
+ --eval_steps 290 \
22
+ --save_steps 290 \
23
+ --save_total_limit 5 \
24
+ --logging_steps 5 \
25
+ --output_dir /root/autodl-tmp/output_7B_GRPO \
26
+ --warmup_ratio 0.01 \
27
+ --dataloader_num_workers 1 \
28
+ --num_generations 2 \
29
+ --temperature 1.0 \
30
+ --log_completions true \
31
+ --num_iterations 1 \
32
+ --async_generate false \
33
+ --beta 0.01 \
34
+ --deepspeed zero3_offload \
35
+ --report_to wandb \
36
+ # --vllm_mode server \
37
+ # --vllm_server_host 127.0.0.1 \
38
+ # --vllm_server_port 8000 \
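The script points `--external_plugins` at `GRPO/Reward.py` and selects two reward functions by name; that plugin file is not part of this commit. As a rough, hypothetical sketch of what such an accuracy reward could look like (the actual ms-swift plugin registration API is not shown, and the function name is made up), reusing the `<overall score>` convention from `compare_scores.py`:

```python
import re
from typing import List


def overall_score_accuracy(completions: List[str], solution: List[int], **kwargs) -> List[float]:
    """Return 1.0 when the predicted <overall score> matches the ground-truth solution, else 0.0."""
    rewards = []
    for completion, target in zip(completions, solution):
        match = re.search(r"<overall score>(\d+)</overall score>", completion)
        predicted = int(match.group(1)) if match else None
        rewards.append(1.0 if predicted == target else 0.0)
    return rewards
```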
.ipynb_checkpoints/README_CN-checkpoint.md ADDED
@@ -0,0 +1,413 @@
1
+ # SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
2
+
3
+ <p align="center">
4
+ <br>
5
+ <img src="asset/banner.png"/>
6
+ <br>
7
+ </p>
8
+ <p align="center">
9
+ <a href="https://modelscope.cn/home">魔搭社区官网</a>
10
+ <br>
11
+ 中文&nbsp | &nbsp<a href="README.md">English</a>&nbsp
12
+ </p>
13
+
14
+
15
+ <p align="center">
16
+ <img src="https://img.shields.io/badge/python-3.10-5be.svg">
17
+ <img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
18
+ <a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
19
+ <a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
20
+ <a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
21
+ <a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
22
+ <a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
23
+ </p>
24
+
25
+ <p align="center">
26
+ <a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
27
+ </p>
28
+
29
+ <p align="center">
30
+ <a href="https://arxiv.org/abs/2408.05517">论文</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
31
+ </p>
32
+
33
+ ## 📖 目录
34
+ - [用户群](#-用户群)
35
+ - [简介](#-简介)
36
+ - [新闻](#-新闻)
37
+ - [安装](#%EF%B8%8F-安装)
38
+ - [快速开始](#-快速开始)
39
+ - [如何使用](#-如何使用)
40
+ - [License](#-license)
41
+ - [引用](#-引用)
42
+
43
+ ## ☎ 用户群
44
+
45
+ 请扫描下面的二维码来加入我们的交流群:
46
+
47
+ [Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
48
+ :-------------------------:|:-------------------------:
49
+ <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
50
+
51
+ ## 📝 简介
52
+ 🍲 ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL2.5、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。
53
+
54
+ 🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。
55
+
56
+ **为什么选择ms-swift?**
57
+ - 🍎 **模型类型**:支持500+纯文本大模型、**200+多模态大模型**以及All-to-All全模态模型、序列分类模型、Embedding模型**训练到部署全流程**。
58
+ - **数据集类型**:内置150+预训练、微调、人类对齐、多模态等各种类型的数据集,并支持自定义数据集。
59
+ - **硬件支持**:CPU、RTX系列、T4/V100、A10/A100/H100、Ascend NPU、MPS等。
60
+ - 🍊 **轻量训练**:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
61
+ - **分布式训练**:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
62
+ - **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
63
+ - **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
64
+ - 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
65
+ - **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
66
+ - **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
67
+ - 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。
68
+ - **推理加速**:支持PyTorch、vLLM、LmDeploy推理加速引擎,并提供OpenAI接口,为推理、部署和评测模块提供加速。
69
+ - **模型评测**:以EvalScope作为评测后端,支持100+评测数据集对纯文本和多模态模型进行评测。
70
+ - **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。
71
+
72
+ ## 🎉 新闻
73
+ - 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)
74
+ - 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。
75
+ - 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/train_multi_round.sh)。
76
+ - 🎁 2025.03.16: 支持了Megatron的并行技术进行训练,请查看[Megatron-SWIFT训练文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html)。
77
+ - 🎁 2025.03.15: 支持纯文本和多模态模型的embedding模型的微调,请查看[训练脚本](examples/train/embedding)。
78
+ - 🎁 2025.03.05: 支持GRPO的hybrid模式,4GPU(4*80G)训练72B模型的脚本参考[这里](examples/train/grpo/internal/train_72b_4gpu.sh)。同时支持vllm的tensor并行,训练脚本参考[这里](examples/train/grpo/internal/multi_gpu_mp_colocate.sh)。
79
+ - 🎁 2025.02.21: GRPO算法支持使用LMDeploy,训练脚本参考[这里](examples/train/grpo/internal/full_lmdeploy.sh)。此外测试了GRPO算法的性能,使用一些tricks使训练速度提高到300%。WanDB表格请查看[这里](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz)。
80
+ - 🎁 2025.02.21: 支持`swift sample`命令。强化微调脚本参考[这里](docs/source/Instruction/强化微调.md),大模型API蒸馏采样脚本参考[这里](examples/sampler/distill/distill.sh)。
81
+ - 🔥 2025.02.12: 支持GRPO (Group Relative Policy Optimization) 训练算法,文档参考[这里](docs/source/Instruction/GRPO.md)。
82
+ - 🎁 2024.12.04: **ms-swift3.0**大版本更新。请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。
83
+ <details><summary>更多</summary>
84
+
85
+ - 🎉 2024.08.12: ms-swift论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
86
+ - 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
87
+ - 🔥 2024.07.29: 支持使用[vllm](https://github.com/vllm-project/vllm), [lmdeploy](https://github.com/InternLM/lmdeploy)对大模型和多模态大模型进行推理加速,在infer/deploy/eval时额外指定`--infer_backend vllm/lmdeploy`即可。
88
+ - 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM/PPO。
89
+ - 🔥 2024.02.01: 支持Agent训练!训练算法源自这篇[论文](https://arxiv.org/pdf/2309.00986.pdf)。
90
+ </details>
91
+
92
+ ## 🛠️ 安装
93
+ 使用pip进行安装:
94
+ ```shell
95
+ pip install ms-swift -U
96
+ ```
97
+
98
+ 从源代码安装:
99
+ ```shell
100
+ # pip install git+https://github.com/modelscope/ms-swift.git
101
+
102
+ git clone https://github.com/modelscope/ms-swift.git
103
+ cd ms-swift
104
+ pip install -e .
105
+ ```
106
+
107
+ 运行环境:
108
+
109
+ | | 范围 | 推荐 | 备注 |
110
+ | ------ |--------------| ---- | --|
111
+ | python | >=3.9 | 3.10 ||
112
+ | cuda | | cuda12 |使用cpu、npu、mps则无需安装|
113
+ | torch | >=2.0 | ||
114
+ | transformers | >=4.33 | 4.51 ||
115
+ | modelscope | >=1.23 | ||
116
+ | peft | >=0.11,<0.16 | ||
117
+ | trl | >=0.13,<0.18 | 0.17 |RLHF|
118
+ | deepspeed | >=0.14 | 0.14.5 |训练|
119
+ | vllm | >=0.5.1 | 0.7.3/0.8 |推理/部署/评测|
120
+ | lmdeploy | >=0.5 | 0.8 |推理/部署/评测|
121
+ | evalscope | >=0.11 | |评测|
122
+
123
+ 更多可选依赖可以参考[这里](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh)。
124
+
125
+
126
+ ## 🚀 快速开始
127
+
128
+ **10分钟**在单卡3090上对Qwen2.5-7B-Instruct进行自我认知微调:
129
+
130
+ ### 命令行
131
+ ```shell
132
+ # 22GB
133
+ CUDA_VISIBLE_DEVICES=0 \
134
+ swift sft \
135
+ --model Qwen/Qwen2.5-7B-Instruct \
136
+ --train_type lora \
137
+ --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
138
+ 'AI-ModelScope/alpaca-gpt4-data-en#500' \
139
+ 'swift/self-cognition#500' \
140
+ --torch_dtype bfloat16 \
141
+ --num_train_epochs 1 \
142
+ --per_device_train_batch_size 1 \
143
+ --per_device_eval_batch_size 1 \
144
+ --learning_rate 1e-4 \
145
+ --lora_rank 8 \
146
+ --lora_alpha 32 \
147
+ --target_modules all-linear \
148
+ --gradient_accumulation_steps 16 \
149
+ --eval_steps 50 \
150
+ --save_steps 50 \
151
+ --save_total_limit 2 \
152
+ --logging_steps 5 \
153
+ --max_length 2048 \
154
+ --output_dir output \
155
+ --system 'You are a helpful assistant.' \
156
+ --warmup_ratio 0.05 \
157
+ --dataloader_num_workers 4 \
158
+ --model_author swift \
159
+ --model_name swift-robot
160
+ ```
161
+
162
+ 小贴士:
163
+ - 如果要使用自定义数据集进行训练,你可以参考[这里](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html)组织数据集格式,并指定`--dataset <dataset_path>`。最小格式示例见本列表下方。
164
+ - `--model_author`和`--model_name`参数只有当数据集中包含`swift/self-cognition`时才生效。
165
+ - 如果要使用其他模型进行训练,你只需要修改`--model <model_id/model_path>`即可。
166
+ - 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
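下面给出一个**假设性的**自定义数据集最小示例(JSONL,每行一个样本;实际支持的字段以上方链接的自定义数据集文档为准):

```jsonl
{"messages": [{"role": "user", "content": "浙江的省会在哪里?"}, {"role": "assistant", "content": "浙江的省会在杭州。"}]}
{"messages": [{"role": "system", "content": "你是一个乐于助人的助手。"}, {"role": "user", "content": "你是谁?"}, {"role": "assistant", "content": "我是swift训练的语言模型。"}]}
```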
167
+
168
+ 训练完成后,使用以下命令对训练后的权重进行推理:
169
+ - 这里的`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件`args.json`,因此不需要额外指定`--model`,`--system`,swift会自动读取这些参数。如果要关闭此行为,可以设置`--load_args false`。
170
+
171
+ ```shell
172
+ # 使用交互式命令行进行推理
173
+ CUDA_VISIBLE_DEVICES=0 \
174
+ swift infer \
175
+ --adapters output/vx-xxx/checkpoint-xxx \
176
+ --stream true \
177
+ --temperature 0 \
178
+ --max_new_tokens 2048
179
+
180
+ # merge-lora并使用vLLM进行推理加速
181
+ CUDA_VISIBLE_DEVICES=0 \
182
+ swift infer \
183
+ --adapters output/vx-xxx/checkpoint-xxx \
184
+ --stream true \
185
+ --merge_lora true \
186
+ --infer_backend vllm \
187
+ --max_model_len 8192 \
188
+ --temperature 0 \
189
+ --max_new_tokens 2048
190
+ ```
191
+
192
+ 最后,使用以下命令将模型推送到ModelScope:
193
+ ```shell
194
+ CUDA_VISIBLE_DEVICES=0 \
195
+ swift export \
196
+ --adapters output/vx-xxx/checkpoint-xxx \
197
+ --push_to_hub true \
198
+ --hub_model_id '<your-model-id>' \
199
+ --hub_token '<your-sdk-token>' \
200
+ --use_hf false
201
+ ```
202
+
203
+ ### Web-UI
204
+
205
+ Web-UI是基于gradio界面技术的**零门槛**训练、部署界面方案,具体可以查看[这里](https://swift.readthedocs.io/zh-cn/latest/GetStarted/Web-UI.html)。
206
+
207
+ ```shell
208
+ swift web-ui
209
+ ```
210
+ ![image.png](./docs/resources/web-ui.jpg)
211
+
212
+ ### 使用Python
213
+ ms-swift也支持使用python的方式进行训练和推理。下面给出训练和推理的**伪代码**,具体可以查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb)。
214
+
215
+ 训练:
216
+ ```python
217
+ # 获取模型和template,并加入可训练的LoRA模块
218
+ model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
219
+ template = get_template(model.model_meta.template, tokenizer, ...)
220
+ model = Swift.prepare_model(model, lora_config)
221
+
222
+ # 下载并载入数据集,并将文本encode成tokens
223
+ train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
224
+ train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
225
+ val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
226
+
227
+ # 进行训练
228
+ trainer = Seq2SeqTrainer(
229
+ model=model,
230
+ args=training_args,
231
+ data_collator=template.data_collator,
232
+ train_dataset=train_dataset,
233
+ eval_dataset=val_dataset,
234
+ template=template,
235
+ )
236
+ trainer.train()
237
+ ```
238
+
239
+ 推理:
240
+ ```python
241
+ # 使用原生pytorch引擎进行推理
242
+ engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
243
+ infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
244
+ request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
245
+
246
+ resp_list = engine.infer([infer_request], request_config)
247
+ print(f'response: {resp_list[0].choices[0].message.content}')
248
+ ```
249
+
250
+ ## ✨ 如何使用
251
+
252
+ 这里给出使用ms-swift进行训练到部署到最简示例,具体可以查看[examples](https://github.com/modelscope/ms-swift/tree/main/examples)。
253
+
254
+ - 若想使用其他模型或者数据集(含多模态模型和数据集),你只需要修改`--model`指定对应模型的id或者path,修改`--dataset`指定对应数据集的id或者path即可。
255
+ - 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
256
+
257
+ | 常用链接 |
258
+ | ------ |
259
+ | [🔥命令行参数](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E5%91%BD%E4%BB%A4%E8%A1%8C%E5%8F%82%E6%95%B0.html) |
260
+ | [支持的模型和数据集](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
261
+ | [自定义模型](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B.html), [🔥自定义数据集](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
262
+ | [大模型教程](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
263
+
264
+ ### 训练
265
+ 支持的训练方法:
266
+
267
+ | 方法 | 全参数 | LoRA | QLoRA | Deepspeed | 多机 | 多模态 |
268
+ | ------ | ------ |---------------------------------------------------------------------------------------------| ----- | ------ | ------ |----------------------------------------------------------------------------------------------|
269
+ | 预训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
270
+ | 指令监督微调 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
271
+ | DPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
272
+ | GRPO训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
273
+ | 奖励模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
274
+ | PPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
275
+ | KTO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
276
+ | CPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
277
+ | SimPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
278
+ | ORPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
279
+ | 分类模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
280
+ | Embedding模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
281
+
282
+
283
+ 预训练:
284
+ ```shell
285
+ # 8*A100
286
+ NPROC_PER_NODE=8 \
287
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
288
+ swift pt \
289
+ --model Qwen/Qwen2.5-7B \
290
+ --dataset swift/chinese-c4 \
291
+ --streaming true \
292
+ --train_type full \
293
+ --deepspeed zero2 \
294
+ --output_dir output \
295
+ --max_steps 10000 \
296
+ ...
297
+ ```
298
+
299
+ 微调:
300
+ ```shell
301
+ CUDA_VISIBLE_DEVICES=0 swift sft \
302
+ --model Qwen/Qwen2.5-7B-Instruct \
303
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh \
304
+ --train_type lora \
305
+ --output_dir output \
306
+ ...
307
+ ```
308
+
309
+ RLHF:
310
+ ```shell
311
+ CUDA_VISIBLE_DEVICES=0 swift rlhf \
312
+ --rlhf_type dpo \
313
+ --model Qwen/Qwen2.5-7B-Instruct \
314
+ --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
315
+ --train_type lora \
316
+ --output_dir output \
317
+ ...
318
+ ```
319
+
320
+
321
+ ### 推理
322
+ ```shell
323
+ CUDA_VISIBLE_DEVICES=0 swift infer \
324
+ --model Qwen/Qwen2.5-7B-Instruct \
325
+ --stream true \
326
+ --infer_backend pt \
327
+ --max_new_tokens 2048
328
+
329
+ # LoRA
330
+ CUDA_VISIBLE_DEVICES=0 swift infer \
331
+ --model Qwen/Qwen2.5-7B-Instruct \
332
+ --adapters swift/test_lora \
333
+ --stream true \
334
+ --infer_backend pt \
335
+ --temperature 0 \
336
+ --max_new_tokens 2048
337
+ ```
338
+
339
+ ### 界面推理
340
+ ```shell
341
+ CUDA_VISIBLE_DEVICES=0 swift app \
342
+ --model Qwen/Qwen2.5-7B-Instruct \
343
+ --stream true \
344
+ --infer_backend pt \
345
+ --max_new_tokens 2048 \
346
+ --lang zh
347
+ ```
348
+
349
+ ### 部署
350
+ ```shell
351
+ CUDA_VISIBLE_DEVICES=0 swift deploy \
352
+ --model Qwen/Qwen2.5-7B-Instruct \
353
+ --infer_backend vllm
354
+ ```
355
+
356
+ ### 采样
357
+ ```shell
358
+ CUDA_VISIBLE_DEVICES=0 swift sample \
359
+ --model LLM-Research/Meta-Llama-3.1-8B-Instruct \
360
+ --sampler_engine pt \
361
+ --num_return_sequences 5 \
362
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
363
+ ```
364
+
365
+ ### 评测
366
+ ```shell
367
+ CUDA_VISIBLE_DEVICES=0 swift eval \
368
+ --model Qwen/Qwen2.5-7B-Instruct \
369
+ --infer_backend lmdeploy \
370
+ --eval_backend OpenCompass \
371
+ --eval_dataset ARC_c
372
+ ```
373
+
374
+ ### 量化
375
+ ```shell
376
+ CUDA_VISIBLE_DEVICES=0 swift export \
377
+ --model Qwen/Qwen2.5-7B-Instruct \
378
+ --quant_bits 4 --quant_method awq \
379
+ --dataset AI-ModelScope/alpaca-gpt4-data-zh \
380
+ --output_dir Qwen2.5-7B-Instruct-AWQ
381
+ ```
382
+
383
+ ### 推送模型
384
+ ```shell
385
+ swift export \
386
+ --model <model-path> \
387
+ --push_to_hub true \
388
+ --hub_model_id '<model-id>' \
389
+ --hub_token '<sdk-token>'
390
+ ```
391
+
392
+
393
+ ## 🏛 License
394
+
395
+ 本框架使用[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE)进行许可。模型和数据集请查看原资源页面并遵守对应License。
396
+
397
+ ## 📎 引用
398
+
399
+ ```bibtex
400
+ @misc{zhao2024swiftascalablelightweightinfrastructure,
401
+ title={SWIFT: A Scalable lightWeight Infrastructure for Fine-Tuning},
402
+ author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
403
+ year={2024},
404
+ eprint={2408.05517},
405
+ archivePrefix={arXiv},
406
+ primaryClass={cs.CL},
407
+ url={https://arxiv.org/abs/2408.05517},
408
+ }
409
+ ```
410
+
411
+ ## Star History
412
+
413
+ [![Star History Chart](https://api.star-history.com/svg?repos=modelscope/swift&type=Date)](https://star-history.com/#modelscope/ms-swift&Date)
.ipynb_checkpoints/VLLM-checkpoint.sh ADDED
@@ -0,0 +1,7 @@
1
+ CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
2
+ --model /root/autodl-tmp/output_7B_FULL_cotSFT/v0-20250621-230827/Qwen2.5-Omni-7B \
3
+ --tokenizer /root/autodl-tmp/output_7B_FULL_cotSFT/v0-20250621-230827/Qwen2.5-Omni-7B \
4
+ --dtype bfloat16 \
5
+ --host 127.0.0.1 \
6
+ --port 8000 \
7
+ --gpu-memory-utilization 0.9
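Once the server is up it exposes an OpenAI-compatible API on 127.0.0.1:8000; a quick smoke test (the `model` field must match the `--model` path passed above):

```shell
curl http://127.0.0.1:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "/root/autodl-tmp/output_7B_FULL_cotSFT/v0-20250621-230827/Qwen2.5-Omni-7B", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64}'
```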
.ipynb_checkpoints/clean_transcripts-checkpoint.py ADDED
@@ -0,0 +1,95 @@
1
+ import json
2
+ import re
3
+ from typing import List, Dict, Tuple
4
+
5
+ def parse_timestamp(timestamp: str) -> Tuple[int, int]:
6
+ """Convert timestamp string like '00:15' to seconds."""
7
+ minutes, seconds = map(int, timestamp.split(':'))
8
+ return minutes * 60 + seconds
9
+
10
+ def extract_time_and_speaker(line: str) -> Tuple[Tuple[int, int], str]:
11
+ """Extract time range and speaker from a line."""
12
+ # Extract time range
13
+ time_match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\] (Speaker [A-Z]):', line)
14
+ if not time_match:
15
+ return None, None
16
+
17
+ start_time = parse_timestamp(time_match.group(1))
18
+ end_time = parse_timestamp(time_match.group(2))
19
+ speaker = time_match.group(3)
20
+
21
+ return (start_time, end_time), speaker
22
+
23
+ def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
24
+ """Check if two time ranges overlap."""
25
+ start1, end1 = range1
26
+ start2, end2 = range2
27
+ return not (end1 <= start2 or end2 <= start1)
28
+
29
+ def has_same_speaker_overlap(transcript: str) -> bool:
30
+ """Check if a transcript contains overlapping timestamps for the same speaker."""
31
+ lines = transcript.split('\n')
32
+ # Dictionary to store time ranges for each speaker
33
+ speaker_ranges = {}
34
+
35
+ for line in lines:
36
+ if not line.strip():
37
+ continue
38
+
39
+ time_range, speaker = extract_time_and_speaker(line)
40
+ if time_range is None or speaker is None:
41
+ continue
42
+
43
+ # Check for overlaps with existing ranges of the same speaker
44
+ if speaker in speaker_ranges:
45
+ for existing_range in speaker_ranges[speaker]:
46
+ if has_overlap(time_range, existing_range):
47
+ return True
48
+
49
+ speaker_ranges[speaker].append(time_range)
50
+ else:
51
+ speaker_ranges[speaker] = [time_range]
52
+
53
+ return False
54
+
55
+ def process_file(input_file: str, output_file: str, delete_file: str):
56
+ """Process the JSON file and separate entries with same-speaker overlapping timestamps."""
57
+ with open(input_file, 'r', encoding='utf-8') as f:
58
+ data = json.load(f)
59
+
60
+ if isinstance(data, dict):
61
+ data = [data]
62
+
63
+ cleaned_data = []
64
+ deleted_data = []
65
+ removed_count = 0
66
+
67
+ for entry in data:
68
+ if 'model_output' in entry:
69
+ if not has_same_speaker_overlap(entry['model_output']):
70
+ cleaned_data.append(entry)
71
+ else:
72
+ deleted_data.append(entry)
73
+ removed_count += 1
74
+ print(f"Removing entry with key: {entry.get('key', 'unknown')}")
75
+
76
+ # Save cleaned data
77
+ with open(output_file, 'w', encoding='utf-8') as f:
78
+ json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
79
+
80
+ # Save deleted data
81
+ with open(delete_file, 'w', encoding='utf-8') as f:
82
+ json.dump(deleted_data, f, ensure_ascii=False, indent=2)
83
+
84
+ print(f"\nProcessing Summary:")
85
+ print(f"Processed {len(data)} entries")
86
+ print(f"Removed {removed_count} entries with same-speaker overlapping timestamps")
87
+ print(f"Remaining entries: {len(cleaned_data)}")
88
+
89
+ if __name__ == '__main__':
90
+ input_file = 'silence_overlaps/transcriptions.json'
91
+ output_file = 'silence_overlaps/cleaned_transcriptions2.json'
92
+ delete_file = 'silence_overlaps/delete_transcript2.json'
93
+ process_file(input_file, output_file, delete_file)
94
+ print(f"\nCleaned transcriptions have been saved to {output_file}")
95
+ print(f"Deleted entries have been saved to {delete_file}")
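For reference, the parsing helpers above accept only one specific transcript line shape; a small illustrative example, assuming the functions above are in scope:

```python
# extract_time_and_speaker requires "[MM:SS - MM:SS] Speaker X:" at the start of
# the line; anything else yields (None, None) and is skipped by the caller.
line = "[00:15 - 00:22] Speaker A: Hello there."
time_range, speaker = extract_time_and_speaker(line)
print(time_range, speaker)  # -> (15, 22) Speaker A
```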
.ipynb_checkpoints/compare_scores-checkpoint.py ADDED
@@ -0,0 +1,96 @@
1
+ import json
2
+ import re
3
+ from collections import defaultdict
4
+
5
+ infer_result_path = '/root/autodl-tmp/output_7B_GRPO/v28-20250722-002940/checkpoint-870/infer_result/53_HH.jsonl'
6
+ test_path = '/root/autodl-tmp/ms-swift/all_audio_test_50.jsonl'
7
+ output_path = 'inference_comparison_result.json'
8
+
9
+ def extract_overall_score(response_text):
10
+ match = re.search(r'<overall score>(\d+)</overall score>', response_text)
11
+ if match:
12
+ return int(match.group(1))
13
+ return None
14
+
15
+ def main():
16
+ # Read the infer_result file and build a mapping from audios to predicted score
17
+ infer_audio2score = {}
18
+ with open(infer_result_path, 'r', encoding='utf-8') as f:
19
+ for line in f:
20
+ data = json.loads(line)
21
+ score = extract_overall_score(data['response'])
22
+ audios = tuple(data.get('audios', []))
23
+ infer_audio2score[audios] = {
24
+ 'score': score,
25
+ 'raw_response': data['response']
26
+ }
27
+
28
+ # Read the test file and build a mapping from audios to ground-truth solution
29
+ test_audio2solution = {}
30
+ with open(test_path, 'r', encoding='utf-8') as f:
31
+ for line in f:
32
+ data = json.loads(line)
33
+ solution = data['solution']
34
+ audios = tuple(data.get('audios', []))
35
+ test_audio2solution[audios] = solution
36
+
37
+ # Accumulate statistics, incorrect samples, and all inference results
38
+ stats_per_class = defaultdict(lambda: {'correct': 0, 'incorrect': 0})
39
+ incorrect_samples_solution1 = []
40
+ all_results = []
41
+
42
+ total = 0
43
+ correct = 0
44
+
45
+ for audios, solution in test_audio2solution.items():
46
+ infer_entry = infer_audio2score.get(audios, None)
47
+ infer_score = infer_entry['score'] if infer_entry else None
48
+ raw_response = infer_entry['raw_response'] if infer_entry else None
49
+ match = infer_score == solution
50
+
51
+ # Record every result
52
+ all_results.append({
53
+ 'audios': audios,
54
+ 'gt_solution': solution,
55
+ 'predicted_score': infer_score,
56
+ 'match': match,
57
+ 'response': raw_response
58
+ })
59
+
60
+ if match:
61
+ correct += 1
62
+ stats_per_class[solution]['correct'] += 1
63
+ else:
64
+ stats_per_class[solution]['incorrect'] += 1
65
+ if solution == 1:
66
+ incorrect_samples_solution1.append({
67
+ 'audios': audios,
68
+ 'gt_solution': solution,
69
+ 'predicted_score': infer_score,
70
+ 'response': raw_response
71
+ })
72
+
73
+ total += 1
74
+
75
+ # Overall accuracy
76
+ print(f'\nOverall Accuracy: {correct}/{total} = {correct/total:.2%}\n')
77
+
78
+ # Per-class accuracy
79
+ print("Per-Class Accuracy:")
80
+ for solution, stats in sorted(stats_per_class.items()):
81
+ total_class = stats['correct'] + stats['incorrect']
82
+ accuracy = stats['correct'] / total_class if total_class > 0 else 0.0
83
+ print(f'Class {solution}: Correct={stats["correct"]}, Incorrect={stats["incorrect"]}, Accuracy={accuracy:.2%}')
84
+
85
+ # List samples with solution = 1 that were predicted incorrectly
86
+ print("\nIncorrect Samples for solution = 1:")
87
+ for sample in incorrect_samples_solution1:
88
+ print(json.dumps(sample, indent=2, ensure_ascii=False))
89
+
90
+ # Write all results to a JSON file
91
+ with open(output_path, 'w', encoding='utf-8') as f:
92
+ json.dump(all_results, f, indent=2, ensure_ascii=False)
93
+ print(f"\nAll inference comparison results saved to: {output_path}")
94
+
95
+ if __name__ == '__main__':
96
+ main()
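As a usage note, `extract_overall_score` above assumes the model wraps its final rating in an `<overall score>` tag; a minimal illustrative check (the sample response text is made up):

```python
# The regex only needs the tag pair somewhere in the response; any surrounding
# reasoning text is ignored.
sample_response = "<think>reasoning omitted</think>\n<overall score>4</overall score>"
assert extract_overall_score(sample_response) == 4
```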
.ipynb_checkpoints/dialogue_length_ranges-checkpoint.png ADDED
.ipynb_checkpoints/infer-checkpoint.py ADDED
@@ -0,0 +1,63 @@
1
+ import os
2
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
3
+
4
+ from swift.llm import PtEngine, RequestConfig, safe_snapshot_download, get_model_tokenizer, get_template, InferRequest
5
+ import json
6
+ from transformers import AutoProcessor
7
+ from swift.tuners import Swift
8
+ last_model_checkpoint = '/root/autodl-tmp/output_7B_Lora_cotSFT/v2-20250613-111902/checkpoint-3'
9
+
10
+ # Model
11
+ model_id_or_path = '/root/autodl-tmp/output_7B_Lora_allmission/v2-20250610-190504/checkpoint-1000-merged' # model_id or model_path
12
+ system = 'You are a helpful assistant.'
13
+ infer_backend = 'pt'
14
+
15
+ # Generation parameters
16
+ max_new_tokens = 2048
17
+ temperature = 0
18
+ stream = False
19
+ template_type = None
20
+ default_system = system # None: fall back to the model's own default_system
21
+ # Load the model and tokenizer
22
+ model, tokenizer = get_model_tokenizer(model_id_or_path)
23
+ # Load the LoRA adapter and initialize the inference engine
24
+ model = Swift.from_pretrained(model, last_model_checkpoint)
25
+ template_type = template_type or model.model_meta.template
26
+ template = get_template(template_type, tokenizer, default_system=default_system)
27
+ engine = PtEngine.from_model_template(model, template, max_batch_size=2)
28
+ request_config = RequestConfig(max_tokens=8192, temperature=0)
29
+
30
+ def load_test_data(json_file):
31
+ test_requests = []
32
+ with open(json_file, 'r', encoding='utf-8') as f:
33
+ for line in f:
34
+ data = json.loads(line.strip())
35
+ test_requests.append(InferRequest(
36
+ messages=data['messages'],
37
+ audios=data['audios']
38
+ ))
39
+ return test_requests
40
+
41
+ def main():
42
+ # Load the test data
43
+ test_file = 'dataset_allmissiontest.json'
44
+ infer_requests = load_test_data(test_file)
45
+ results = []
46
+ resp_list = engine.infer(infer_requests, request_config)
47
+ for i, resp in enumerate(resp_list):
48
+ assistant_content = next((msg['content'] for msg in infer_requests[i].messages if msg['role'] == 'assistant'), None)
49
+ result = {
50
+ "index": i,
51
+ "truth": assistant_content,
52
+ "response": resp.choices[0].message.content,
53
+ }
54
+ results.append(result)
55
+ print(f'truth{i}: {assistant_content}')
56
+ print(f'response{i}: {resp.choices[0].message.content}')
57
+ output_file = 'inference_results.json'
58
+ with open(output_file, 'w', encoding='utf-8') as f:
59
+ json.dump(results, f, ensure_ascii=False, indent=2)
60
+
61
+
62
+ if __name__ == '__main__':
63
+ main()
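`load_test_data` above only touches the `messages` and `audios` fields of each JSONL line, so a record in `dataset_allmissiontest.json` is assumed to look roughly like this (the path, prompt text, and `<audio>` placement are illustrative placeholders):

```json
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<audio>Please rate this dialogue."}, {"role": "assistant", "content": "<overall score>4</overall score>"}], "audios": ["/path/to/example.wav"]}
```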
GRPO_TRAIN.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/build/lib/transformers/models/phimoe/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 Microsoft and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_phimoe import *
22
+ from .modeling_phimoe import *
23
+
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
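For context, the `_LazyModule` indirection above defers importing the heavy configuration and modeling submodules until one of their attributes is first accessed. A hedged usage sketch, assuming a transformers build that actually ships the Phimoe model:

```python
# Importing the subpackage is cheap; configuration_phimoe is only loaded when
# PhimoeConfig is first accessed through the lazy module.
from transformers.models.phimoe import PhimoeConfig

config = PhimoeConfig()  # resolves the real module on first attribute access
```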
docs/transformers/build/lib/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py ADDED
@@ -0,0 +1,155 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import argparse
16
+ import os
17
+ import re
18
+
19
+ import torch
20
+ from flax.traverse_util import flatten_dict
21
+ from t5x import checkpoints
22
+
23
+ from transformers import (
24
+ AutoTokenizer,
25
+ Pix2StructConfig,
26
+ Pix2StructForConditionalGeneration,
27
+ Pix2StructImageProcessor,
28
+ Pix2StructProcessor,
29
+ Pix2StructTextConfig,
30
+ Pix2StructVisionConfig,
31
+ )
32
+
33
+
34
+ def get_flax_param(t5x_checkpoint_path):
35
+ flax_params = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
36
+ flax_params = flatten_dict(flax_params)
37
+ return flax_params
38
+
39
+
40
+ def rename_and_convert_flax_params(flax_dict):
41
+ converted_dict = {}
42
+
43
+ CONVERSION_MAPPING = {
44
+ "token_embedder": "embeddings",
45
+ "encoder_norm": "layernorm",
46
+ "kernel": "weight",
47
+ ".out": ".output",
48
+ "scale": "weight",
49
+ "embedders_0.pos_embedding": "row_embedder.weight",
50
+ "embedders_1.pos_embedding": "column_embedder.weight",
51
+ }
52
+
53
+ DECODER_CONVERSION_MAPPING = {
54
+ "query": "attention.query",
55
+ "key": "attention.key",
56
+ "value": "attention.value",
57
+ "output.dense": "output",
58
+ "encoder_decoder_attention.o": "encoder_decoder_attention.attention.o",
59
+ "pre_self_attention_layer_norm": "self_attention.layer_norm",
60
+ "pre_cross_attention_layer_norm": "encoder_decoder_attention.layer_norm",
61
+ "mlp.": "mlp.DenseReluDense.",
62
+ "pre_mlp_layer_norm": "mlp.layer_norm",
63
+ "self_attention.o": "self_attention.attention.o",
64
+ "decoder.embeddings.embedding": "decoder.embed_tokens.weight",
65
+ "decoder.relpos_bias.rel_embedding": "decoder.layer.0.self_attention.attention.relative_attention_bias.weight",
66
+ "decoder.decoder_norm.weight": "decoder.final_layer_norm.weight",
67
+ "decoder.logits_dense.weight": "decoder.lm_head.weight",
68
+ }
69
+
70
+ for key in flax_dict.keys():
71
+ if "target" in key:
72
+ # remove the first prefix from the key
73
+ new_key = ".".join(key[1:])
74
+
75
+ # rename the key
76
+ for old, new in CONVERSION_MAPPING.items():
77
+ new_key = new_key.replace(old, new)
78
+
79
+ if "decoder" in new_key:
80
+ for old, new in DECODER_CONVERSION_MAPPING.items():
81
+ new_key = new_key.replace(old, new)
82
+
83
+ if "layers" in new_key and "decoder" not in new_key:
84
+ # use regex to replace the layer number
85
+ new_key = re.sub(r"layers_(\d+)", r"layer.\1", new_key)
86
+ new_key = new_key.replace("encoder", "encoder.encoder")
87
+
88
+ elif "layers" in new_key and "decoder" in new_key:
89
+ # use regex to replace the layer number
90
+ new_key = re.sub(r"layers_(\d+)", r"layer.\1", new_key)
91
+
92
+ converted_dict[new_key] = flax_dict[key]
93
+
94
+ converted_torch_dict = {}
95
+ # convert converted_dict into torch format
96
+ for key in converted_dict.keys():
97
+ if ("embed_tokens" not in key) and ("embedder" not in key):
98
+ converted_torch_dict[key] = torch.from_numpy(converted_dict[key].T)
99
+ else:
100
+ converted_torch_dict[key] = torch.from_numpy(converted_dict[key])
101
+
102
+ return converted_torch_dict
103
+
104
+
105
+ def convert_pix2struct_original_pytorch_checkpoint_to_hf(
106
+ t5x_checkpoint_path, pytorch_dump_folder_path, use_large=False, is_vqa=False
107
+ ):
108
+ flax_params = get_flax_param(t5x_checkpoint_path)
109
+
110
+ if not use_large:
111
+ encoder_config = Pix2StructVisionConfig()
112
+ decoder_config = Pix2StructTextConfig()
113
+ else:
114
+ encoder_config = Pix2StructVisionConfig(
115
+ hidden_size=1536, d_ff=3968, num_attention_heads=24, num_hidden_layers=18
116
+ )
117
+ decoder_config = Pix2StructTextConfig(hidden_size=1536, d_ff=3968, num_heads=24, num_layers=18)
118
+ config = Pix2StructConfig(
119
+ vision_config=encoder_config.to_dict(), text_config=decoder_config.to_dict(), is_vqa=is_vqa
120
+ )
121
+
122
+ model = Pix2StructForConditionalGeneration(config)
123
+
124
+ torch_params = rename_and_convert_flax_params(flax_params)
125
+ model.load_state_dict(torch_params)
126
+
127
+ tok = AutoTokenizer.from_pretrained("ybelkada/test-pix2struct-tokenizer")
128
+ image_processor = Pix2StructImageProcessor()
129
+ processor = Pix2StructProcessor(image_processor=image_processor, tokenizer=tok)
130
+
131
+ if use_large:
132
+ processor.image_processor.max_patches = 4096
133
+
134
+ processor.image_processor.is_vqa = True
135
+
136
+ # mkdir if needed
137
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
138
+
139
+ model.save_pretrained(pytorch_dump_folder_path)
140
+ processor.save_pretrained(pytorch_dump_folder_path)
141
+
142
+ print("Model saved in {}".format(pytorch_dump_folder_path))
143
+
144
+
145
+ if __name__ == "__main__":
146
+ parser = argparse.ArgumentParser()
147
+ parser.add_argument("--t5x_checkpoint_path", default=None, type=str, help="Path to the original T5x checkpoint.")
148
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
149
+ parser.add_argument("--use_large", action="store_true", help="Use large model.")
150
+ parser.add_argument("--is_vqa", action="store_true", help="Whether the converted checkpoint is a VQA model.")
151
+ args = parser.parse_args()
152
+
153
+ convert_pix2struct_original_pytorch_checkpoint_to_hf(
154
+ args.t5x_checkpoint_path, args.pytorch_dump_folder_path, args.use_large, args.is_vqa
155
+ )
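A hypothetical invocation of the conversion script above (beyond transformers and torch it also needs the `t5x` and `flax` packages, per the imports):

```shell
python convert_pix2struct_original_pytorch_to_hf.py \
    --t5x_checkpoint_path /path/to/t5x_checkpoint \
    --pytorch_dump_folder_path ./pix2struct-converted \
    --use_large   # only for the large variant; omit for the base model
```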
docs/transformers/build/lib/transformers/models/pix2struct/image_processing_pix2struct.py ADDED
@@ -0,0 +1,464 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Pix2Struct."""
16
+
17
+ import io
18
+ import math
19
+ from typing import Dict, Optional, Union
20
+
21
+ import numpy as np
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature
25
+ from ...image_transforms import convert_to_rgb, normalize, to_channel_dimension_format, to_pil_image
26
+ from ...image_utils import (
27
+ ChannelDimension,
28
+ ImageInput,
29
+ get_image_size,
30
+ infer_channel_dimension_format,
31
+ make_list_of_images,
32
+ to_numpy_array,
33
+ valid_images,
34
+ )
35
+ from ...utils import TensorType, is_torch_available, is_vision_available, logging
36
+ from ...utils.import_utils import requires_backends
37
+
38
+
39
+ if is_vision_available():
40
+ import textwrap
41
+
42
+ from PIL import Image, ImageDraw, ImageFont
43
+
44
+ if is_torch_available():
45
+ import torch
46
+
47
+ logger = logging.get_logger(__name__)
48
+ DEFAULT_FONT_PATH = "ybelkada/fonts"
49
+
50
+
51
+ # adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
52
+ def torch_extract_patches(image_tensor, patch_height, patch_width):
53
+ """
54
+ Utility function to extract patches from a given image tensor. Returns a tensor of shape (1, `image_height // patch_height`,
55
+ `image_width // patch_width`, `num_channels` x `patch_height` x `patch_width`)
56
+
57
+ Args:
58
+ image_tensor (torch.Tensor):
59
+ The image tensor to extract patches from.
60
+ patch_height (int):
61
+ The height of the patches to extract.
62
+ patch_width (int):
63
+ The width of the patches to extract.
64
+ """
65
+ requires_backends(torch_extract_patches, ["torch"])
66
+
67
+ image_tensor = image_tensor.unsqueeze(0)
68
+ patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
69
+ patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1)
70
+ patches = patches.permute(0, 4, 2, 3, 1).reshape(
71
+ image_tensor.size(2) // patch_height,
72
+ image_tensor.size(3) // patch_width,
73
+ image_tensor.size(1) * patch_height * patch_width,
74
+ )
75
+ return patches.unsqueeze(0)
76
+
77
+
78
+ # Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L106
79
+ def render_text(
80
+ text: str,
81
+ text_size: int = 36,
82
+ text_color: str = "black",
83
+ background_color: str = "white",
84
+ left_padding: int = 5,
85
+ right_padding: int = 5,
86
+ top_padding: int = 5,
87
+ bottom_padding: int = 5,
88
+ font_bytes: Optional[bytes] = None,
89
+ font_path: Optional[str] = None,
90
+ ) -> Image.Image:
91
+ """
92
+ Render text. This script is entirely adapted from the original script that can be found here:
93
+ https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py
94
+
95
+ Args:
96
+ text (`str`):
97
+ Text to render.
98
+ text_size (`int`, *optional*, defaults to 36):
99
+ Size of the text.
100
+ text_color (`str`, *optional*, defaults to `"black"`):
101
+ Color of the text.
102
+ background_color (`str`, *optional*, defaults to `"white"`):
103
+ Color of the background.
104
+ left_padding (`int`, *optional*, defaults to 5):
105
+ Padding on the left.
106
+ right_padding (`int`, *optional*, defaults to 5):
107
+ Padding on the right.
108
+ top_padding (`int`, *optional*, defaults to 5):
109
+ Padding on the top.
110
+ bottom_padding (`int`, *optional*, defaults to 5):
111
+ Padding on the bottom.
112
+ font_bytes (`bytes`, *optional*):
113
+ Bytes of the font to use. If `None`, the default font will be used.
114
+ font_path (`str`, *optional*):
115
+ Path to the font to use. If `None`, the default font will be used.
116
+ """
117
+ requires_backends(render_text, "vision")
118
+ # Add new lines so that each line is no more than 80 characters.
119
+
120
+ wrapper = textwrap.TextWrapper(width=80)
121
+ lines = wrapper.wrap(text=text)
122
+ wrapped_text = "\n".join(lines)
123
+
124
+ if font_bytes is not None and font_path is None:
125
+ font = io.BytesIO(font_bytes)
126
+ elif font_path is not None:
127
+ font = font_path
128
+ else:
129
+ font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
130
+ font = ImageFont.truetype(font, encoding="UTF-8", size=text_size)
131
+
132
+ # Use a temporary canvas to determine the width and height in pixels when
133
+ # rendering the text.
134
+ temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
135
+ _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)
136
+
137
+ # Create the actual image with a bit of padding around the text.
138
+ image_width = text_width + left_padding + right_padding
139
+ image_height = text_height + top_padding + bottom_padding
140
+ image = Image.new("RGB", (image_width, image_height), background_color)
141
+ draw = ImageDraw.Draw(image)
142
+ draw.text(xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font)
143
+ return image
144
+
145
+
146
+ # Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
147
+ def render_header(
148
+ image: np.ndarray, header: str, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
149
+ ):
150
+ """
151
+ Renders the input text as a header on the input image.
152
+
153
+ Args:
154
+ image (`np.ndarray`):
155
+ The image to render the header on.
156
+ header (`str`):
157
+ The header text.
158
+ data_format (`Union[ChannelDimension, str]`, *optional*):
159
+ The data format of the image. Can be either "ChannelDimension.channels_first" or
160
+ "ChannelDimension.channels_last".
161
+
162
+ Returns:
163
+ `np.ndarray`: The image with the header rendered.
164
+ """
165
+ requires_backends(render_header, "vision")
166
+
167
+ # Convert to PIL image if necessary
168
+ image = to_pil_image(image, input_data_format=input_data_format)
169
+
170
+ header_image = render_text(header, **kwargs)
171
+ new_width = max(header_image.width, image.width)
172
+
173
+ new_height = int(image.height * (new_width / image.width))
174
+ new_header_height = int(header_image.height * (new_width / header_image.width))
175
+
176
+ new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
177
+ new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
178
+ new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))
179
+
180
+ # Convert back to the original framework if necessary
181
+ new_image = to_numpy_array(new_image)
182
+
183
+ if infer_channel_dimension_format(new_image) == ChannelDimension.LAST:
184
+ new_image = to_channel_dimension_format(new_image, ChannelDimension.LAST)
185
+
186
+ return new_image
187
+
188
+
189
+ class Pix2StructImageProcessor(BaseImageProcessor):
190
+ r"""
191
+ Constructs a Pix2Struct image processor.
192
+
193
+ Args:
194
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
195
+ Whether to convert the image to RGB.
196
+ do_normalize (`bool`, *optional*, defaults to `True`):
197
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
198
+ method. According to Pix2Struct paper and code, the image is normalized with its own mean and standard
199
+ deviation.
200
+ patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
201
+ The patch size to use for the image. According to Pix2Struct paper and code, the patch size is 16x16.
202
+ max_patches (`int`, *optional*, defaults to 2048):
203
+ The maximum number of patches to extract from the image as per the [Pix2Struct
204
+ paper](https://arxiv.org/pdf/2210.03347.pdf).
205
+ is_vqa (`bool`, *optional*, defaults to `False`):
206
+ Whether or not the image processor is for the VQA task. If `True` and `header_text` is passed in, text is
207
+ rendered onto the input images.
208
+ """
209
+
210
+ model_input_names = ["flattened_patches"]
211
+
212
+ def __init__(
213
+ self,
214
+ do_convert_rgb: bool = True,
215
+ do_normalize: bool = True,
216
+ patch_size: Dict[str, int] = None,
217
+ max_patches: int = 2048,
218
+ is_vqa: bool = False,
219
+ **kwargs,
220
+ ) -> None:
221
+ super().__init__(**kwargs)
222
+ self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
223
+ self.do_normalize = do_normalize
224
+ self.do_convert_rgb = do_convert_rgb
225
+ self.max_patches = max_patches
226
+ self.is_vqa = is_vqa
227
+
228
+ def extract_flattened_patches(
229
+ self,
230
+ image: np.ndarray,
231
+ max_patches: int,
232
+ patch_size: dict,
233
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
234
+ **kwargs,
235
+ ) -> np.ndarray:
236
+ """
237
+ Extract flattened patches from an image.
238
+
239
+ Args:
240
+ image (`np.ndarray`):
241
+ Image to extract flattened patches from.
242
+ max_patches (`int`):
243
+ Maximum number of patches to extract.
244
+ patch_size (`dict`):
245
+ Dictionary containing the patch height and width.
246
+
247
+ Returns:
248
+ result (`np.ndarray`):
249
+ A sequence of `max_patches` flattened patches.
250
+ """
251
+ requires_backends(self.extract_flattened_patches, "torch")
252
+
253
+ # convert to torch
254
+ image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
255
+ image = torch.from_numpy(image)
256
+
257
+ patch_height, patch_width = patch_size["height"], patch_size["width"]
258
+ image_height, image_width = get_image_size(image, ChannelDimension.FIRST)
259
+
260
+ # maximize scale s.t.
261
+ scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
262
+ num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1)
263
+ num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1)
264
+ resized_height = max(num_feasible_rows * patch_height, 1)
265
+ resized_width = max(num_feasible_cols * patch_width, 1)
266
+
267
+ image = torch.nn.functional.interpolate(
268
+ image.unsqueeze(0),
269
+ size=(resized_height, resized_width),
270
+ mode="bilinear",
271
+ align_corners=False,
272
+ antialias=True,
273
+ ).squeeze(0)
274
+
275
+ # [1, rows, columns, patch_height * patch_width * image_channels]
276
+ patches = torch_extract_patches(image, patch_height, patch_width)
277
+
278
+ patches_shape = patches.shape
279
+ rows = patches_shape[1]
280
+ columns = patches_shape[2]
281
+ depth = patches_shape[3]
282
+
283
+ # [rows * columns, patch_height * patch_width * image_channels]
284
+ patches = patches.reshape([rows * columns, depth])
285
+
286
+ # [rows * columns, 1]
287
+ row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1])
288
+ col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1])
289
+
290
+ # Offset by 1 so the ids do not contain zeros, which represent padding.
291
+ row_ids += 1
292
+ col_ids += 1
293
+
294
+ # Prepare additional patch features.
295
+ # [rows * columns, 1]
296
+ row_ids = row_ids.to(torch.float32)
297
+ col_ids = col_ids.to(torch.float32)
298
+
299
+ # [rows * columns, 2 + patch_height * patch_width * image_channels]
300
+ result = torch.cat([row_ids, col_ids, patches], -1)
301
+
302
+ # [max_patches, 2 + patch_height * patch_width * image_channels]
303
+ result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float()
304
+
305
+ result = to_numpy_array(result)
306
+
307
+ return result
308
+
309
+ def normalize(
310
+ self,
311
+ image: np.ndarray,
312
+ data_format: Optional[Union[str, ChannelDimension]] = None,
313
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
314
+ **kwargs,
315
+ ) -> np.ndarray:
316
+ """
317
+ Normalize an image. image = (image - image_mean) / image_std.
318
+
319
+ The image std is to mimic the tensorflow implementation of the `per_image_standardization`:
320
+ https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization
321
+
322
+ Args:
323
+ image (`np.ndarray`):
324
+ Image to normalize.
325
+ data_format (`str` or `ChannelDimension`, *optional*):
326
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
327
+ image is used.
328
+ input_data_format (`str` or `ChannelDimension`, *optional*):
329
+ The channel dimension format of the input image. If not provided, it will be inferred.
330
+ """
331
+ if image.dtype == np.uint8:
332
+ image = image.astype(np.float32)
333
+
334
+ # take mean across the whole `image`
335
+ mean = np.mean(image)
336
+ std = np.std(image)
337
+ adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape)))
338
+
339
+ return normalize(
340
+ image,
341
+ mean=mean,
342
+ std=adjusted_stddev,
343
+ data_format=data_format,
344
+ input_data_format=input_data_format,
345
+ **kwargs,
346
+ )
347
+
348
+ def preprocess(
349
+ self,
350
+ images: ImageInput,
351
+ header_text: Optional[str] = None,
352
+ do_convert_rgb: Optional[bool] = None,
353
+ do_normalize: Optional[bool] = None,
354
+ max_patches: Optional[int] = None,
355
+ patch_size: Optional[Dict[str, int]] = None,
356
+ return_tensors: Optional[Union[str, TensorType]] = None,
357
+ data_format: ChannelDimension = ChannelDimension.FIRST,
358
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
359
+ **kwargs,
360
+ ) -> ImageInput:
361
+ """
362
+ Preprocess an image or batch of images. The processor first computes the maximum possible number of
363
+ aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the
364
+ image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the
365
+ images are standardized following the tensorflow implementation of `per_image_standardization`
366
+ (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization).
367
+
368
+
369
+ Args:
370
+ images (`ImageInput`):
371
+ Image to preprocess. Expects a single or batch of images.
372
+ header_text (`Union[List[str], str]`, *optional*):
373
+ Text to render as a header. Only has an effect if `image_processor.is_vqa` is `True`.
374
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
375
+ Whether to convert the image to RGB.
376
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
377
+ Whether to normalize the image.
378
+ max_patches (`int`, *optional*, defaults to `self.max_patches`):
379
+ Maximum number of patches to extract.
380
+ patch_size (`dict`, *optional*, defaults to `self.patch_size`):
381
+ Dictionary containing the patch height and width.
382
+ return_tensors (`str` or `TensorType`, *optional*):
383
+ The type of tensors to return. Can be one of:
384
+ - Unset: Return a list of `np.ndarray`.
385
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
386
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
387
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
388
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
389
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
390
+ The channel dimension format for the output image. Can be one of:
391
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
392
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
393
+ - Unset: Use the channel dimension format of the input image.
394
+ input_data_format (`ChannelDimension` or `str`, *optional*):
395
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
396
+ from the input image. Can be one of:
397
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
398
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
399
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
400
+ """
401
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
402
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
403
+ patch_size = patch_size if patch_size is not None else self.patch_size
404
+ max_patches = max_patches if max_patches is not None else self.max_patches
405
+ is_vqa = self.is_vqa
406
+
407
+ if kwargs.get("data_format", None) is not None:
408
+ raise ValueError("data_format is not an accepted input as the outputs are ")
409
+
410
+ images = make_list_of_images(images)
411
+
412
+ if not valid_images(images):
413
+ raise ValueError(
414
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
415
+ "torch.Tensor, tf.Tensor or jax.ndarray."
416
+ )
417
+
418
+ # PIL RGBA images are converted to RGB
419
+ if do_convert_rgb:
420
+ images = [convert_to_rgb(image) for image in images]
421
+
422
+ # All transformations expect numpy arrays.
423
+ images = [to_numpy_array(image) for image in images]
424
+
425
+ if input_data_format is None:
426
+ # We assume that all images have the same channel dimension format.
427
+ input_data_format = infer_channel_dimension_format(images[0])
428
+
429
+ if is_vqa:
430
+ if header_text is None:
431
+ raise ValueError("A header text must be provided for VQA models.")
432
+ font_bytes = kwargs.pop("font_bytes", None)
433
+ font_path = kwargs.pop("font_path", None)
434
+
435
+ if isinstance(header_text, str):
436
+ header_text = [header_text] * len(images)
437
+
438
+ images = [
439
+ render_header(image, header_text[i], font_bytes=font_bytes, font_path=font_path)
440
+ for i, image in enumerate(images)
441
+ ]
442
+
443
+ if do_normalize:
444
+ images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]
445
+
446
+ # convert to torch tensor and permute
447
+ images = [
448
+ self.extract_flattened_patches(
449
+ image=image, max_patches=max_patches, patch_size=patch_size, input_data_format=input_data_format
450
+ )
451
+ for image in images
452
+ ]
453
+
454
+ # create attention mask in numpy
455
+ attention_masks = [(image.sum(axis=-1) != 0).astype(np.float32) for image in images]
456
+
457
+ encoded_outputs = BatchFeature(
458
+ data={"flattened_patches": images, "attention_mask": attention_masks}, tensor_type=return_tensors
459
+ )
460
+
461
+ return encoded_outputs
462
+
463
+
464
+ __all__ = ["Pix2StructImageProcessor"]
docs/transformers/build/lib/transformers/models/pixtral/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_pixtral import *
22
+ from .image_processing_pixtral import *
23
+ from .image_processing_pixtral_fast import *
24
+ from .modeling_pixtral import *
25
+ from .processing_pixtral import *
26
+ else:
27
+ import sys
28
+
29
+ _file = globals()["__file__"]
30
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/pixtral/configuration_pixtral.py ADDED
@@ -0,0 +1,106 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Pixtral model configuration"""
15
+
16
+ from ...configuration_utils import PretrainedConfig
17
+ from ...utils import logging
18
+
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
+ class PixtralVisionConfig(PretrainedConfig):
24
+ r"""
25
+ This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate a
26
+ Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
27
+ with the defaults will yield a similar configuration to the vision encoder used by Pixtral-12B.
28
+
29
+ e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ hidden_size (`int`, *optional*, defaults to 1024):
36
+ Dimension of the hidden representations.
37
+ intermediate_size (`int`, *optional*, defaults to 4096):
38
+ Dimension of the MLP representations.
39
+ num_hidden_layers (`int`, *optional*, defaults to 24):
40
+ Number of hidden layers in the Transformer encoder.
41
+ num_attention_heads (`int`, *optional*, defaults to 16):
42
+ Number of attention heads in the Transformer encoder.
43
+ num_channels (`int`, *optional*, defaults to 3):
44
+ Number of input channels in the input images.
45
+ image_size (`int`, *optional*, defaults to 1024):
46
+ Max dimension of the input images.
47
+ patch_size (`int`, *optional*, defaults to 16):
48
+ Size of the image patches.
49
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
50
+ Activation function used in the hidden layers.
51
+ attention_dropout (`float`, *optional*, defaults to 0.0):
52
+ Dropout probability for the attention layers.
53
+ rope_theta (`float`, *optional*, defaults to 10000.0):
54
+ The base period of the RoPE embeddings.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+
58
+ Example:
59
+
60
+ ```python
61
+ >>> from transformers import PixtralVisionModel, PixtralVisionConfig
62
+
63
+ >>> # Initializing a Pixtral-12B style configuration
64
+ >>> config = PixtralVisionConfig()
65
+
66
+ >>> # Initializing a model (with randomly initialized weights) from the configuration
67
+ >>> model = PixtralVisionModel(config)
68
+
69
+ >>> # Accessing the model configuration
70
+ >>> configuration = model.config
71
+ ```"""
72
+
73
+ model_type = "pixtral"
74
+
75
+ def __init__(
76
+ self,
77
+ hidden_size=1024,
78
+ intermediate_size=4096,
79
+ num_hidden_layers=24,
80
+ num_attention_heads=16,
81
+ num_channels=3,
82
+ image_size=1024,
83
+ patch_size=16,
84
+ hidden_act="gelu",
85
+ attention_dropout=0.0,
86
+ rope_theta=10000.0,
87
+ initializer_range=0.02,
88
+ **kwargs,
89
+ ):
90
+ super().__init__(**kwargs)
91
+
92
+ self.hidden_size = hidden_size
93
+ self.intermediate_size = intermediate_size
94
+ self.num_hidden_layers = num_hidden_layers
95
+ self.num_attention_heads = num_attention_heads
96
+ self.num_channels = num_channels
97
+ self.patch_size = patch_size
98
+ self.image_size = image_size
99
+ self.attention_dropout = attention_dropout
100
+ self.hidden_act = hidden_act
101
+ self.rope_theta = rope_theta
102
+ self.head_dim = hidden_size // num_attention_heads
103
+ self.initializer_range = initializer_range
104
+
105
+
106
+ __all__ = ["PixtralVisionConfig"]
docs/transformers/build/lib/transformers/models/pixtral/image_processing_pixtral.py ADDED
@@ -0,0 +1,472 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Pixtral."""
16
+
17
+ import math
18
+ from typing import Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+
22
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
23
+ from ...image_transforms import (
24
+ pad,
25
+ resize,
26
+ to_channel_dimension_format,
27
+ )
28
+ from ...image_utils import (
29
+ ChannelDimension,
30
+ ImageInput,
31
+ PILImageResampling,
32
+ get_image_size,
33
+ infer_channel_dimension_format,
34
+ is_scaled_image,
35
+ make_list_of_images,
36
+ to_numpy_array,
37
+ valid_images,
38
+ validate_kwargs,
39
+ validate_preprocess_arguments,
40
+ )
41
+ from ...utils import TensorType, is_vision_available, logging
42
+ from ...utils.import_utils import requires_backends
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+
48
+ if is_vision_available():
49
+ import PIL
50
+
51
+
52
+ # Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
53
+ def convert_to_rgb(image: ImageInput) -> ImageInput:
54
+ """
55
+ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
56
+ as is.
57
+ Args:
58
+ image (Image):
59
+ The image to convert.
60
+ """
61
+ requires_backends(convert_to_rgb, ["vision"])
62
+
63
+ if not isinstance(image, PIL.Image.Image):
64
+ return image
65
+
66
+ if image.mode == "RGB":
67
+ return image
68
+
69
+ # First we convert to RGBA to set background to white.
70
+ image = image.convert("RGBA")
71
+
72
+ # Create a new image with a white background.
73
+ new_image = PIL.Image.new("RGBA", image.size, "WHITE")
74
+ new_image.paste(image, (0, 0), image)
75
+ new_image = new_image.convert("RGB")
76
+ return new_image
77
+
78
+
79
+ def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> Tuple[int, int]:
80
+ """
81
+ Calculate the number of image tokens given the image size and patch size.
82
+
83
+ Args:
84
+ image_size (`Tuple[int, int]`):
85
+ The size of the image as `(height, width)`.
86
+ patch_size (`Tuple[int, int]`):
87
+ The patch size as `(height, width)`.
88
+
89
+ Returns:
90
+ `Tuple[int, int]`: The number of patch tokens along the height and the width dimensions.
91
+ """
92
+ height, width = image_size
93
+ patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
94
+ num_width_tokens = (width - 1) // patch_width + 1
95
+ num_height_tokens = (height - 1) // patch_height + 1
96
+ return num_height_tokens, num_width_tokens
97
+
98
+
99
+ def get_resize_output_image_size(
100
+ input_image: ImageInput,
101
+ size: Union[int, Tuple[int, int], List[int], Tuple[int]],
102
+ patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
103
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
104
+ ) -> tuple:
105
+ """
106
+ Find the target (height, width) dimension of the output image after resizing given the input image and the desired
107
+ size.
108
+
109
+ Args:
110
+ input_image (`ImageInput`):
111
+ The image to resize.
112
+ size (`int` or `Tuple[int, int]`):
113
+ Maximum allowed size of the output image, given as an int or a `(height, width)` tuple.
114
+ patch_size (`int` or `Tuple[int, int]`):
115
+ The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)`
116
+ will be used
117
+ input_data_format (`ChannelDimension`, *optional*):
118
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
119
+
120
+ Returns:
121
+ `tuple`: The target (height, width) dimension of the output image after resizing.
122
+ """
123
+ max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size)
124
+ patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
125
+ height, width = get_image_size(input_image, input_data_format)
126
+
127
+ ratio = max(height / max_height, width / max_width)
128
+
129
+ if ratio > 1:
130
+ # The original implementation uses `round`, which applies banker's rounding and can lead to surprising results
131
+ # Here we use floor to ensure the image is always smaller than the given "longest_edge"
132
+ height = int(math.floor(height / ratio))
133
+ width = int(math.floor(width / ratio))
134
+
135
+ num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
136
+ return num_height_tokens * patch_height, num_width_tokens * patch_width
137
+
138
+
139
+ class PixtralImageProcessor(BaseImageProcessor):
140
+ r"""
141
+ Constructs a Pixtral image processor.
142
+
143
+ Args:
144
+ do_resize (`bool`, *optional*, defaults to `True`):
145
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
146
+ `do_resize` in the `preprocess` method.
147
+ size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
148
+ Maximum size of the longest edge (height or width) of the image. Used to control how
149
+ images are resized. If either the height or width is greater than `size["longest_edge"]`, both are rescaled to `height / ratio` and `width / ratio`, where `ratio = max(height / longest_edge, width / longest_edge)`.
150
+ patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
151
+ Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
152
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
153
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
154
+ do_rescale (`bool`, *optional*, defaults to `True`):
155
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
156
+ the `preprocess` method.
157
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
158
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
159
+ method.
160
+ do_normalize (`bool`, *optional*, defaults to `True`):
161
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
162
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
163
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
164
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
165
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
166
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
167
+ number of channels in the image.
168
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
169
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
170
+ Whether to convert the image to RGB.
171
+ """
172
+
173
+ model_input_names = ["pixel_values"]
174
+
175
+ def __init__(
176
+ self,
177
+ do_resize: bool = True,
178
+ size: Dict[str, int] = None,
179
+ patch_size: Dict[str, int] = None,
180
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
181
+ do_rescale: bool = True,
182
+ rescale_factor: Union[int, float] = 1 / 255,
183
+ do_normalize: bool = True,
184
+ image_mean: Optional[Union[float, List[float]]] = None,
185
+ image_std: Optional[Union[float, List[float]]] = None,
186
+ do_convert_rgb: bool = True,
187
+ **kwargs,
188
+ ) -> None:
189
+ super().__init__(**kwargs)
190
+ size = size if size is not None else {"longest_edge": 1024}
191
+ patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
192
+ patch_size = get_size_dict(patch_size, default_to_square=True)
193
+
194
+ self.do_resize = do_resize
195
+ self.size = size
196
+ self.patch_size = patch_size
197
+ self.resample = resample
198
+ self.do_rescale = do_rescale
199
+ self.rescale_factor = rescale_factor
200
+ self.do_normalize = do_normalize
201
+ self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
202
+ self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
203
+ self.do_convert_rgb = do_convert_rgb
204
+ self._valid_processor_keys = [
205
+ "images",
206
+ "do_resize",
207
+ "size",
208
+ "patch_size",
209
+ "resample",
210
+ "do_rescale",
211
+ "rescale_factor",
212
+ "do_normalize",
213
+ "image_mean",
214
+ "image_std",
215
+ "do_convert_rgb",
216
+ "return_tensors",
217
+ "data_format",
218
+ "input_data_format",
219
+ ]
220
+
221
+ def resize(
222
+ self,
223
+ image: np.ndarray,
224
+ size: Dict[str, int],
225
+ patch_size: Dict[str, int],
226
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
227
+ data_format: Optional[Union[str, ChannelDimension]] = None,
228
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
229
+ **kwargs,
230
+ ) -> np.ndarray:
231
+ """
232
+ Resize an image. The image is resized so that its longest edge is at most `size["longest_edge"]`, keeping the
233
+ input aspect ratio, and both output dimensions are rounded up to multiples of the patch size.
234
+
235
+ Args:
236
+ image (`np.ndarray`):
237
+ Image to resize.
238
+ size (`Dict[str, int]`):
239
+ Dict containing the longest possible edge of the image.
240
+ patch_size (`Dict[str, int]`):
241
+ Patch size used to calculate the size of the output image.
242
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
243
+ Resampling filter to use when resizing the image.
244
+ data_format (`str` or `ChannelDimension`, *optional*):
245
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
246
+ input_data_format (`ChannelDimension` or `str`, *optional*):
247
+ The channel dimension format of the input image. If not provided, it will be inferred.
248
+ """
249
+ if "longest_edge" in size:
250
+ size = (size["longest_edge"], size["longest_edge"])
251
+ elif "height" in size and "width" in size:
252
+ size = (size["height"], size["width"])
253
+ else:
254
+ raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
255
+
256
+ if "height" in patch_size and "width" in patch_size:
257
+ patch_size = (patch_size["height"], patch_size["width"])
258
+ else:
259
+ raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.")
260
+
261
+ output_size = get_resize_output_image_size(
262
+ image,
263
+ size=size,
264
+ patch_size=patch_size,
265
+ input_data_format=input_data_format,
266
+ )
267
+ return resize(
268
+ image,
269
+ size=output_size,
270
+ resample=resample,
271
+ data_format=data_format,
272
+ input_data_format=input_data_format,
273
+ **kwargs,
274
+ )
275
+
276
+ def _pad_for_batching(
277
+ self,
278
+ pixel_values: List[np.ndarray],
279
+ image_sizes: List[List[int]],
280
+ data_format: Optional[Union[str, ChannelDimension]] = None,
281
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
282
+ ):
283
+ """
284
+ Pads images with zeros along the height and width dimensions so that every image in the batch has the same spatial shape.
285
+ Args:
286
+ pixel_values (`List[np.ndarray]`):
287
+ A list of per-image pixel value arrays, each of shape (`num_channels`, `height`, `width`) or (`height`, `width`, `num_channels`).
288
+ image_sizes (`List[List[int]]`):
289
+ A list of sizes for each image in `pixel_values` in (height, width) format.
290
+ data_format (`str` or `ChannelDimension`, *optional*):
291
+ The channel dimension format for the output image. Can be one of:
292
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
293
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
294
+ If unset, will use same as the input image.
295
+ input_data_format (`str` or `ChannelDimension`, *optional*):
296
+ The channel dimension format for the input image. Can be one of:
297
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
298
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
299
+ If unset, will use the inferred format of the input image.
300
+ Returns:
301
+ List[`np.ndarray`]: The padded images.
302
+ """
303
+
304
+ max_shape = (
305
+ max([size[0] for size in image_sizes]),
306
+ max([size[1] for size in image_sizes]),
307
+ )
308
+ pixel_values = [
309
+ pad(
310
+ image,
311
+ padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])),
312
+ data_format=data_format,
313
+ input_data_format=input_data_format,
314
+ )
315
+ for image, size in zip(pixel_values, image_sizes)
316
+ ]
317
+ return pixel_values
318
+
319
+ def preprocess(
320
+ self,
321
+ images: ImageInput,
322
+ do_resize: Optional[bool] = None,
323
+ size: Dict[str, int] = None,
324
+ patch_size: Dict[str, int] = None,
325
+ resample: PILImageResampling = None,
326
+ do_rescale: Optional[bool] = None,
327
+ rescale_factor: Optional[float] = None,
328
+ do_normalize: Optional[bool] = None,
329
+ image_mean: Optional[Union[float, List[float]]] = None,
330
+ image_std: Optional[Union[float, List[float]]] = None,
331
+ do_convert_rgb: Optional[bool] = None,
332
+ return_tensors: Optional[Union[str, TensorType]] = None,
333
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
334
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
335
+ **kwargs,
336
+ ) -> BatchFeature:
337
+ """
338
+ Preprocess an image or batch of images.
339
+
340
+ Args:
341
+ images (`ImageInput`):
342
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
343
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
344
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
345
+ Whether to resize the image.
346
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
347
+ Describes the maximum input dimensions to the model.
348
+ patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
349
+ Patch size in the model. Used to calculate the image after resizing.
350
+ resample (`int`, *optional*, defaults to `self.resample`):
351
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
352
+ has an effect if `do_resize` is set to `True`.
353
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
354
+ Whether to rescale the image.
355
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
356
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
357
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
358
+ Whether to normalize the image.
359
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
360
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
361
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
362
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
363
+ `True`.
364
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
365
+ Whether to convert the image to RGB.
366
+ return_tensors (`str` or `TensorType`, *optional*):
367
+ The type of tensors to return. Can be one of:
368
+ - Unset: Return a list of `np.ndarray`.
369
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
370
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
371
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
372
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
373
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
374
+ The channel dimension format for the output image. Can be one of:
375
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
376
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
377
+ - Unset: Use the channel dimension format of the input image.
378
+ input_data_format (`ChannelDimension` or `str`, *optional*):
379
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
380
+ from the input image. Can be one of:
381
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
382
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
383
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
384
+ """
385
+ patch_size = patch_size if patch_size is not None else self.patch_size
386
+ patch_size = get_size_dict(patch_size, default_to_square=True)
387
+
388
+ do_resize = do_resize if do_resize is not None else self.do_resize
389
+ size = size if size is not None else self.size
390
+ resample = resample if resample is not None else self.resample
391
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
392
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
393
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
394
+ image_mean = image_mean if image_mean is not None else self.image_mean
395
+ image_std = image_std if image_std is not None else self.image_std
396
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
397
+
398
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
399
+
400
+ images = make_list_of_images(images)
401
+
402
+ if not valid_images(images[0]):
403
+ raise ValueError(
404
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
405
+ "torch.Tensor, tf.Tensor or jax.ndarray."
406
+ )
407
+
408
+ validate_preprocess_arguments(
409
+ do_rescale=do_rescale,
410
+ rescale_factor=rescale_factor,
411
+ do_normalize=do_normalize,
412
+ image_mean=image_mean,
413
+ image_std=image_std,
414
+ do_resize=do_resize,
415
+ size=size,
416
+ resample=resample,
417
+ )
418
+
419
+ if do_convert_rgb:
420
+ images = [convert_to_rgb(image) for image in images]
421
+
422
+ # All transformations expect numpy arrays.
423
+ images = [to_numpy_array(image) for image in images]
424
+
425
+ if do_rescale and is_scaled_image(images[0]):
426
+ logger.warning_once(
427
+ "It looks like you are trying to rescale already rescaled images. If the input"
428
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
429
+ )
430
+
431
+ if input_data_format is None:
432
+ # We assume that all images have the same channel dimension format.
433
+ input_data_format = infer_channel_dimension_format(images[0])
434
+
435
+ batch_images = []
436
+ batch_image_sizes = []
437
+ for image in images:
438
+ if do_resize:
439
+ image = self.resize(
440
+ image=image,
441
+ size=size,
442
+ patch_size=patch_size,
443
+ resample=resample,
444
+ input_data_format=input_data_format,
445
+ )
446
+
447
+ if do_rescale:
448
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
449
+
450
+ if do_normalize:
451
+ image = self.normalize(
452
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
453
+ )
454
+
455
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
456
+
457
+ batch_images.append(image)
458
+ batch_image_sizes.append(get_image_size(image, data_format))
459
+
460
+ pixel_values = self._pad_for_batching(
461
+ pixel_values=batch_images,
462
+ image_sizes=batch_image_sizes,
463
+ input_data_format=data_format,
464
+ data_format=data_format,
465
+ )
466
+
467
+ return BatchFeature(
468
+ data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors
469
+ )
470
+
471
+
472
+ __all__ = ["PixtralImageProcessor"]
docs/transformers/build/lib/transformers/models/pixtral/modeling_pixtral.py ADDED
@@ -0,0 +1,505 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Pixtral model."""
16
+
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+
23
+ from ... import PreTrainedModel
24
+ from ...activations import ACT2FN
25
+ from ...modeling_outputs import BaseModelOutput
26
+ from ...modeling_rope_utils import dynamic_rope_update
27
+ from ...utils import (
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ logging,
31
+ )
32
+ from .configuration_pixtral import PixtralVisionConfig
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ def position_ids_in_meshgrid(patch_embeds_list, max_width):
39
+ positions = []
40
+ for patch in patch_embeds_list:
41
+ height, width = patch.shape[-2:]
42
+ mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
43
+ h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
44
+ ids = h_grid * max_width + v_grid
45
+ positions.append(ids[:, 0])
46
+ return torch.cat(positions)
47
+
48
+
49
+ class PixtralRotaryEmbedding(nn.Module):
50
+ """
51
+ The key idea of the Pixtral rotary embedding is that there is one frequency per patch position.
52
+ For a grid of height x width patches, the frequency used for RoPE is obtained by indexing the
53
+ precomputed frequencies by the patch's row and column.
54
+
55
+ The output has dimension (batch, height * width, dim), where dim is the embedding dimension.
56
+
57
+ This simply means that each image hidden state gets a corresponding positional embedding,
58
+ based on its index in the grid.
59
+ """
60
+
61
+ def __init__(self, config, device=None):
62
+ super().__init__()
63
+ self.rope_type = "default"
64
+ self.dim = config.head_dim
65
+ self.base = config.rope_theta
66
+ max_patches_per_side = config.image_size // config.patch_size
67
+ freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
68
+
69
+ h = torch.arange(max_patches_per_side, device=freqs.device)
70
+ w = torch.arange(max_patches_per_side, device=freqs.device)
71
+
72
+ freqs_h = torch.outer(h, freqs[::2]).float()
73
+ freqs_w = torch.outer(w, freqs[1::2]).float()
74
+ inv_freq = torch.cat(
75
+ [
76
+ freqs_h[:, None, :].repeat(1, max_patches_per_side, 1),
77
+ freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1),
78
+ ],
79
+ dim=-1,
80
+ ).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes
81
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
82
+
83
+ # TODO maybe make it torch compatible later on. We can also just slice
84
+ self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False)
85
+
86
+ @torch.no_grad()
87
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
88
+ def forward(self, x, position_ids):
89
+ freqs = self.inv_freq[position_ids]
90
+
91
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
92
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
93
+ emb = freqs
94
+ cos = emb.cos()
95
+ sin = emb.sin()
96
+
97
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
98
+
99
+
100
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
101
+ def rotate_half(x):
102
+ """Rotates half the hidden dims of the input."""
103
+ x1 = x[..., : x.shape[-1] // 2]
104
+ x2 = x[..., x.shape[-1] // 2 :]
105
+ return torch.cat((-x2, x1), dim=-1)
106
+
107
+
108
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
109
+ """Applies Rotary Position Embedding to the query and key tensors.
110
+
111
+ Args:
112
+ q (`torch.Tensor`): The query tensor.
113
+ k (`torch.Tensor`): The key tensor.
114
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
115
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
116
+ position_ids (`torch.Tensor`, *optional*):
117
+ Deprecated and unused.
118
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
119
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
120
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
121
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
122
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
123
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
124
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
125
+ Returns:
126
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
127
+ """
128
+ cos = cos.unsqueeze(unsqueeze_dim)
129
+ sin = sin.unsqueeze(unsqueeze_dim)
130
+ q_embed = (q * cos) + (rotate_half(q) * sin)
131
+ k_embed = (k * cos) + (rotate_half(k) * sin)
132
+ return q_embed, k_embed
133
+
134
+
135
+ class PixtralAttention(nn.Module):
136
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
137
+
138
+ def __init__(self, config):
139
+ super().__init__()
140
+ self.config = config
141
+ self.embed_dim = config.hidden_size
142
+ self.num_heads = config.num_attention_heads
143
+ self.head_dim = self.embed_dim // self.num_heads
144
+
145
+ self.scale = self.head_dim**-0.5
146
+ self.dropout = config.attention_dropout
147
+
148
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
149
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
150
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
151
+ self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
152
+
153
+ def forward(
154
+ self,
155
+ hidden_states: torch.Tensor,
156
+ attention_mask: Optional[torch.Tensor] = None,
157
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
158
+ output_attentions: Optional[bool] = False,
159
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
160
+ """Input shape: Batch x Time x Channel"""
161
+
162
+ batch_size, patches, _ = hidden_states.size()
163
+
164
+ query_states = self.q_proj(hidden_states)
165
+ key_states = self.k_proj(hidden_states)
166
+ value_states = self.v_proj(hidden_states)
167
+
168
+ query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
169
+ key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
170
+ value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
171
+
172
+ cos, sin = position_embeddings
173
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)
174
+
175
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
176
+
177
+ if attention_mask is not None:
178
+ attn_weights = attn_weights + attention_mask
179
+
180
+ # upcast attention to fp32
181
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
182
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
183
+ attn_output = torch.matmul(attn_weights, value_states)
184
+
185
+ attn_output = attn_output.transpose(1, 2).contiguous()
186
+ attn_output = attn_output.reshape(batch_size, patches, -1)
187
+
188
+ attn_output = self.o_proj(attn_output)
189
+
190
+ return attn_output, attn_weights
191
+
192
+
193
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral
194
+ class PixtralMLP(nn.Module):
195
+ def __init__(self, config):
196
+ super().__init__()
197
+ self.config = config
198
+ self.hidden_size = config.hidden_size
199
+ self.intermediate_size = config.intermediate_size
200
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
201
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
202
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
203
+ self.act_fn = ACT2FN[config.hidden_act]
204
+
205
+ def forward(self, x):
206
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
207
+ return down_proj
208
+
209
+
210
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
211
+ class PixtralRMSNorm(nn.Module):
212
+ def __init__(self, hidden_size, eps=1e-6):
213
+ """
214
+ PixtralRMSNorm is equivalent to T5LayerNorm
215
+ """
216
+ super().__init__()
217
+ self.weight = nn.Parameter(torch.ones(hidden_size))
218
+ self.variance_epsilon = eps
219
+
220
+ def forward(self, hidden_states):
221
+ input_dtype = hidden_states.dtype
222
+ hidden_states = hidden_states.to(torch.float32)
223
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
224
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
225
+ return self.weight * hidden_states.to(input_dtype)
226
+
227
+ def extra_repr(self):
228
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
229
+
230
+
231
+ class PixtralAttentionLayer(nn.Module):
232
+ def __init__(self, config):
233
+ super().__init__()
234
+ self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
235
+ self.feed_forward = PixtralMLP(config)
236
+ self.attention = PixtralAttention(config)
237
+ self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
238
+
239
+ def forward(
240
+ self,
241
+ hidden_states: torch.Tensor,
242
+ attention_mask: torch.Tensor,
243
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
244
+ output_attentions: Optional[bool] = None,
245
+ ) -> Tuple[torch.FloatTensor]:
246
+ """
247
+ Args:
248
+ hidden_states (`torch.FloatTensor`):
249
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
250
+ attention_mask (`torch.FloatTensor`):
251
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
252
+ output_attentions (`bool`, *optional*, defaults to `False`):
253
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
254
+ returned tensors for more detail.
255
+ """
256
+ residual = hidden_states
257
+
258
+ hidden_states = self.attention_norm(hidden_states)
259
+ hidden_states, attn_weights = self.attention(
260
+ hidden_states=hidden_states,
261
+ attention_mask=attention_mask,
262
+ position_embeddings=position_embeddings,
263
+ output_attentions=output_attentions,
264
+ )
265
+ hidden_states = residual + hidden_states
266
+
267
+ residual = hidden_states
268
+ hidden_states = self.ffn_norm(hidden_states)
269
+ hidden_states = self.feed_forward(hidden_states)
270
+ hidden_states = residual + hidden_states
271
+
272
+ outputs = (hidden_states,)
273
+
274
+ if output_attentions:
275
+ outputs += (attn_weights,)
276
+ return outputs
277
+
278
+
279
+ class PixtralTransformer(nn.Module):
280
+ def __init__(self, config):
281
+ super().__init__()
282
+ self.config = config
283
+ self.layers = torch.nn.ModuleList()
284
+ for _ in range(config.num_hidden_layers):
285
+ self.layers.append(PixtralAttentionLayer(config))
286
+ self.gradient_checkpointing = False
287
+
288
+ def forward(
289
+ self,
290
+ inputs_embeds,
291
+ attention_mask: Optional[torch.Tensor] = None,
292
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
293
+ output_attentions: Optional[bool] = None,
294
+ output_hidden_states: Optional[bool] = None,
295
+ return_dict: Optional[bool] = None,
296
+ ) -> Union[Tuple, BaseModelOutput]:
297
+ r"""
298
+ Args:
299
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
300
+ Embeddings which serve as input to the Transformer.
301
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
302
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
303
+
304
+ - 1 for tokens that are **not masked**,
305
+ - 0 for tokens that are **masked**.
306
+
307
+ [What are attention masks?](../glossary#attention-mask)
308
+ output_attentions (`bool`, *optional*):
309
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
310
+ returned tensors for more detail.
311
+ output_hidden_states (`bool`, *optional*):
312
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
313
+ for more detail.
314
+ return_dict (`bool`, *optional*):
315
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
316
+ """
317
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
318
+ output_hidden_states = (
319
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
320
+ )
321
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
322
+
323
+ encoder_states = () if output_hidden_states else None
324
+ all_attentions = () if output_attentions else None
325
+
326
+ hidden_states = inputs_embeds
327
+ for encoder_layer in self.layers:
328
+ if output_hidden_states:
329
+ encoder_states = encoder_states + (hidden_states,)
330
+ if self.gradient_checkpointing and self.training:
331
+ layer_outputs = self._gradient_checkpointing_func(
332
+ encoder_layer.__call__,
333
+ hidden_states,
334
+ attention_mask,
335
+ position_embeddings,
336
+ output_attentions,
337
+ )
338
+ else:
339
+ layer_outputs = encoder_layer(
340
+ hidden_states,
341
+ attention_mask,
342
+ position_embeddings=position_embeddings,
343
+ output_attentions=output_attentions,
344
+ )
345
+
346
+ hidden_states = layer_outputs[0]
347
+
348
+ if output_attentions:
349
+ all_attentions = all_attentions + (layer_outputs[1],)
350
+
351
+ if output_hidden_states:
352
+ encoder_states = encoder_states + (hidden_states,)
353
+
354
+ if not return_dict:
355
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
356
+ return BaseModelOutput(
357
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
358
+ )
359
+
360
+
361
+ PIXTRAL_START_DOCSTRING = r"""
362
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
363
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
364
+ etc.)
365
+
366
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
367
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
368
+ and behavior.
369
+
370
+ Parameters:
371
+ config ([`PixtralVisionConfig`]):
372
+ Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
373
+ load the weights associated with the model, only the configuration. Check out the
374
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
375
+ """
376
+
377
+
378
+ class PixtralPreTrainedModel(PreTrainedModel):
379
+ config_class = PixtralVisionConfig
380
+ base_model_prefix = "model"
381
+ main_input_name = "pixel_values"
382
+ supports_gradient_checkpointing = True
383
+ _no_split_modules = ["PixtralAttentionLayer"]
384
+
385
+ def _init_weights(self, module):
386
+ std = self.config.initializer_range
387
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
388
+ module.weight.data.normal_(mean=0.0, std=std)
389
+ if module.bias is not None:
390
+ module.bias.data.zero_()
391
+ elif isinstance(module, PixtralRMSNorm):
392
+ module.weight.data.fill_(1.0)
393
+
394
+
395
+ PIXTRAL_INPUTS_DOCSTRING = r"""
396
+ Args:
397
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
398
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`]
399
+ for details.
400
+ image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
401
+ The sizes of the images in the batch, being (height, width) for each image.
402
+ output_attentions (`bool`, *optional*):
403
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
404
+ tensors for more detail.
405
+ output_hidden_states (`bool`, *optional*):
406
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
407
+ more detail.
408
+ return_dict (`bool`, *optional*):
409
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
410
+ """
411
+
412
+
413
+ def generate_block_attention_mask(patch_embeds_list, tensor):
414
+ dtype = tensor.dtype
415
+ device = tensor.device
416
+ seq_len = tensor.shape[1]
417
+ d_min = torch.finfo(dtype).min
418
+ causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
419
+
420
+ block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
421
+ block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
422
+ for start, end in zip(block_start_idx, block_end_idx):
423
+ causal_mask[start:end, start:end] = 0
424
+
425
+ causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
426
+ return causal_mask
427
+
428
+
429
+ @add_start_docstrings(
430
+ "The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.",
431
+ PIXTRAL_START_DOCSTRING,
432
+ )
433
+ class PixtralVisionModel(PixtralPreTrainedModel):
434
+ base_model_prefix = "vision_encoder"
435
+
436
+ def __init__(self, config):
437
+ super().__init__(config)
438
+ self.config = config
439
+ self.patch_conv = nn.Conv2d(
440
+ in_channels=config.num_channels,
441
+ out_channels=config.hidden_size,
442
+ kernel_size=config.patch_size,
443
+ stride=config.patch_size,
444
+ bias=False,
445
+ )
446
+ self.patch_size = config.patch_size
447
+ self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
448
+ self.transformer = PixtralTransformer(config)
449
+ self.patch_positional_embedding = PixtralRotaryEmbedding(config)
450
+
451
+ self.post_init()
452
+
453
+ def get_input_embeddings(self):
454
+ return self.patch_conv
455
+
456
+ @add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
457
+ def forward(
458
+ self,
459
+ pixel_values: torch.Tensor,
460
+ image_sizes: torch.Tensor,
461
+ output_hidden_states: Optional[bool] = None,
462
+ output_attentions: Optional[bool] = None,
463
+ return_dict: Optional[bool] = None,
464
+ *args,
465
+ **kwargs,
466
+ ) -> Union[Tuple, BaseModelOutput]:
467
+ """
468
+ Returns:
469
+ `BaseModelOutput` (or a plain tuple if `return_dict=False`) whose last hidden state contains one feature
470
+ vector per image patch, with all images concatenated into a single sequence of shape (1, num_patches, hidden_size).
471
+ """
472
+ # pass images through initial convolution independently
473
+ patch_embeds = self.patch_conv(pixel_values)
474
+ patch_embeds_list = [
475
+ embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)]
476
+ for embed, size in zip(patch_embeds, image_sizes)
477
+ ]
478
+
479
+ # flatten to a single sequence
480
+ patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
481
+ patch_embeds = self.ln_pre(patch_embeds)
482
+
483
+ # positional embeddings
484
+ position_ids = position_ids_in_meshgrid(
485
+ patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
486
+ )
487
+ position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)
488
+
489
+ attention_mask = generate_block_attention_mask(
490
+ [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
491
+ )
492
+
493
+ out = self.transformer(
494
+ patch_embeds,
495
+ attention_mask=attention_mask,
496
+ position_embeddings=position_embeddings,
497
+ output_hidden_states=output_hidden_states,
498
+ output_attentions=output_attentions,
499
+ return_dict=return_dict,
500
+ )
501
+
502
+ return out
503
+
504
+
505
+ __all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"]