Add files using upload-large-folder tool
Browse files- .github/ISSUE_TEMPLATE/custom.md +8 -0
- .github/SECURITY.md +3 -0
- .github/workflows/lint.yaml +22 -0
- .github/workflows/publish.yaml +29 -0
- .ipynb_checkpoints/GRPO_TEST-checkpoint.jsonl +0 -0
- .ipynb_checkpoints/GRPOtrain-checkpoint.sh +38 -0
- .ipynb_checkpoints/README_CN-checkpoint.md +413 -0
- .ipynb_checkpoints/VLLM-checkpoint.sh +7 -0
- .ipynb_checkpoints/clean_transcripts-checkpoint.py +95 -0
- .ipynb_checkpoints/compare_scores-checkpoint.py +96 -0
- .ipynb_checkpoints/dialogue_length_ranges-checkpoint.png +0 -0
- .ipynb_checkpoints/infer-checkpoint.py +63 -0
- GRPO_TRAIN.jsonl +0 -0
- docs/transformers/build/lib/transformers/models/phimoe/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py +155 -0
- docs/transformers/build/lib/transformers/models/pix2struct/image_processing_pix2struct.py +464 -0
- docs/transformers/build/lib/transformers/models/pixtral/__init__.py +30 -0
- docs/transformers/build/lib/transformers/models/pixtral/configuration_pixtral.py +106 -0
- docs/transformers/build/lib/transformers/models/pixtral/image_processing_pixtral.py +472 -0
- docs/transformers/build/lib/transformers/models/pixtral/modeling_pixtral.py +505 -0
.github/ISSUE_TEMPLATE/custom.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Custom issue template
|
3 |
+
about: Describe this issue template's purpose here.
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
.github/SECURITY.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Reporting Security Issues
|
2 |
+
|
3 |
+
Usually security issues of a deep learning project come from non-standard 3rd packages or continuous running services. If you are suffering from security issues from our project, please consider reporting to us. We appreciate your efforts to responsibly disclose your findings, and will make every effort to acknowledge your contributions.
|
.github/workflows/lint.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Lint test
|
2 |
+
|
3 |
+
on: [push, pull_request]
|
4 |
+
|
5 |
+
concurrency:
|
6 |
+
group: ${{ github.workflow }}-${{ github.ref }}
|
7 |
+
cancel-in-progress: true
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
lint:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v2
|
14 |
+
- name: Set up Python 3.10
|
15 |
+
uses: actions/setup-python@v2
|
16 |
+
with:
|
17 |
+
python-version: '3.10'
|
18 |
+
- name: Install pre-commit hook
|
19 |
+
run: |
|
20 |
+
pip install pre-commit
|
21 |
+
- name: Linting
|
22 |
+
run: pre-commit run --all-files
|
.github/workflows/publish.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: release
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
tags:
|
6 |
+
- 'v**'
|
7 |
+
|
8 |
+
concurrency:
|
9 |
+
group: ${{ github.workflow }}-${{ github.ref }}-publish
|
10 |
+
cancel-in-progress: true
|
11 |
+
|
12 |
+
jobs:
|
13 |
+
build-n-publish:
|
14 |
+
runs-on: ubuntu-22.04
|
15 |
+
#if: startsWith(github.event.ref, 'refs/tags')
|
16 |
+
steps:
|
17 |
+
- uses: actions/checkout@v2
|
18 |
+
- name: Set up Python 3.10
|
19 |
+
uses: actions/setup-python@v2
|
20 |
+
with:
|
21 |
+
python-version: '3.10'
|
22 |
+
- name: Install wheel
|
23 |
+
run: pip install wheel packaging setuptools==69.5.1
|
24 |
+
- name: Build ModelScope Swift
|
25 |
+
run: python setup.py sdist bdist_wheel
|
26 |
+
- name: Publish package to PyPI
|
27 |
+
run: |
|
28 |
+
pip install twine
|
29 |
+
twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
|
.ipynb_checkpoints/GRPO_TEST-checkpoint.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
.ipynb_checkpoints/GRPOtrain-checkpoint.sh
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
WANDB_API_KEY="a7ab128385681b17ad156ad0d8c81ba3e2296164" \
|
2 |
+
CUDA_VISIBLE_DEVICES=0,1 \
|
3 |
+
NPROC_PER_NODE=2 \
|
4 |
+
swift rlhf \
|
5 |
+
--rlhf_type grpo \
|
6 |
+
--model /root/autodl-tmp/output_7B_FULL_cotSFT/v11-20250721-183605/checkpoint-330 \
|
7 |
+
--external_plugins GRPO/Reward.py \
|
8 |
+
--reward_funcs external_r1v_acc external_r1v_format_acc \
|
9 |
+
--use_vllm false \
|
10 |
+
--train_type full \
|
11 |
+
--torch_dtype bfloat16 \
|
12 |
+
--dataset 'all_dataset_train_resampled_16000.jsonl' \
|
13 |
+
--max_completion_length 512 \
|
14 |
+
--num_train_epochs 2 \
|
15 |
+
--per_device_train_batch_size 2 \
|
16 |
+
--per_device_eval_batch_size 2 \
|
17 |
+
--learning_rate 1e-6 \
|
18 |
+
--gradient_accumulation_steps 2 \
|
19 |
+
--save_strategy 'steps' \
|
20 |
+
--eval_strategy 'steps' \
|
21 |
+
--eval_steps 290 \
|
22 |
+
--save_steps 290 \
|
23 |
+
--save_total_limit 5 \
|
24 |
+
--logging_steps 5 \
|
25 |
+
--output_dir /root/autodl-tmp/output_7B_GRPO \
|
26 |
+
--warmup_ratio 0.01 \
|
27 |
+
--dataloader_num_workers 1 \
|
28 |
+
--num_generations 2 \
|
29 |
+
--temperature 1.0 \
|
30 |
+
--log_completions true \
|
31 |
+
--num_iterations 1 \
|
32 |
+
--async_generate false \
|
33 |
+
--beta 0.01 \
|
34 |
+
--deepspeed zero3_offload \
|
35 |
+
--report_to wandb \
|
36 |
+
# --vllm_mode server \
|
37 |
+
# --vllm_server_host 127.0.0.1 \
|
38 |
+
# --vllm_server_port 8000 \
|
.ipynb_checkpoints/README_CN-checkpoint.md
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
|
2 |
+
|
3 |
+
<p align="center">
|
4 |
+
<br>
|
5 |
+
<img src="asset/banner.png"/>
|
6 |
+
<br>
|
7 |
+
<p>
|
8 |
+
<p align="center">
|
9 |
+
<a href="https://modelscope.cn/home">魔搭社区官网</a>
|
10 |
+
<br>
|
11 |
+
中文  |  <a href="README.md">English</a> 
|
12 |
+
</p>
|
13 |
+
|
14 |
+
|
15 |
+
<p align="center">
|
16 |
+
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
|
17 |
+
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
|
18 |
+
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
|
19 |
+
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
|
20 |
+
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
|
21 |
+
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
|
22 |
+
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
|
23 |
+
</p>
|
24 |
+
|
25 |
+
<p align="center">
|
26 |
+
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
27 |
+
</p>
|
28 |
+
|
29 |
+
<p align="center">
|
30 |
+
<a href="https://arxiv.org/abs/2408.05517">论文</a>   | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a>   |   <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a>  
|
31 |
+
</p>
|
32 |
+
|
33 |
+
## 📖 目录
|
34 |
+
- [用户群](#-用户群)
|
35 |
+
- [简介](#-简介)
|
36 |
+
- [新闻](#-新闻)
|
37 |
+
- [安装](#%EF%B8%8F-安装)
|
38 |
+
- [快速开始](#-快速开始)
|
39 |
+
- [如何使用](#-如何使用)
|
40 |
+
- [License](#-license)
|
41 |
+
- [引用](#-引用)
|
42 |
+
|
43 |
+
## ☎ 用户群
|
44 |
+
|
45 |
+
请扫描下面的二维码来加入我们的交流群:
|
46 |
+
|
47 |
+
[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
|
48 |
+
:-------------------------:|:-------------------------:
|
49 |
+
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
|
50 |
+
|
51 |
+
## 📝 简介
|
52 |
+
🍲 ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL2.5、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。
|
53 |
+
|
54 |
+
🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。
|
55 |
+
|
56 |
+
**为什么选择ms-swift?**
|
57 |
+
- 🍎 **模型类型**:支持500+纯文本大模型、**200+多模态大模型**以及All-to-All全模态模型、序列分类模型、Embedding模型**训练到部署全流程**。
|
58 |
+
- **数据集类型**:内置150+预训练、微调、人类对齐、多模态等各种类型的数据集,并支持自定义数据集。
|
59 |
+
- **硬件支持**:CPU、RTX系列、T4/V100、A10/A100/H100、Ascend NPU、MPS等。
|
60 |
+
- 🍊 **轻量训练**:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
|
61 |
+
- **分布式训练**:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
|
62 |
+
- **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
|
63 |
+
- **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
|
64 |
+
- 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
|
65 |
+
- **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
|
66 |
+
- **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
|
67 |
+
- 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。
|
68 |
+
- **推理加速**:支持PyTorch、vLLM、LmDeploy推理加速引擎,并提供OpenAI接口,为推理、部署和评测模块提供加速。
|
69 |
+
- **模型评测**:以EvalScope作为评测后端,支持100+评测数据集对纯��本和多模态模型进行评测。
|
70 |
+
- **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。
|
71 |
+
|
72 |
+
## 🎉 新闻
|
73 |
+
- 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)
|
74 |
+
- 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。
|
75 |
+
- 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/train_multi_round.sh)。
|
76 |
+
- 🎁 2025.03.16: 支持了Megatron的并行技术进行训练,请查看[Megatron-SWIFT训练文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html)。
|
77 |
+
- 🎁 2025.03.15: 支持纯文本和多模态模型的embedding模型的微调,请查看[训练脚本](examples/train/embedding)。
|
78 |
+
- 🎁 2025.03.05: 支持GRPO的hybrid模式,4GPU(4*80G)训练72B模型的脚本参考[这里](examples/train/grpo/internal/train_72b_4gpu.sh)。同时支持vllm的tensor并行,训练脚本参考[这里](examples/train/grpo/internal/multi_gpu_mp_colocate.sh)。
|
79 |
+
- 🎁 2025.02.21: GRPO算法支持使用LMDeploy,训练脚本参考[这里](examples/train/grpo/internal/full_lmdeploy.sh)。此外测试了GRPO算法的性能,使用一些tricks使训练速度提高到300%。WanDB表格请查看[这里](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz)。
|
80 |
+
- 🎁 2025.02.21: 支持`swift sample`命令。强化微调脚本参考[这里](docs/source/Instruction/强化微调.md),大模型API蒸馏采样脚本参考[这里](examples/sampler/distill/distill.sh)。
|
81 |
+
- 🔥 2025.02.12: 支持GRPO (Group Relative Policy Optimization) 训练算法,文档参考[这里](docs/source/Instruction/GRPO.md)。
|
82 |
+
- 🎁 2024.12.04: **ms-swift3.0**大版本更新。请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。
|
83 |
+
<details><summary>更多</summary>
|
84 |
+
|
85 |
+
- 🎉 2024.08.12: ms-swift论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
|
86 |
+
- 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
|
87 |
+
- 🔥 2024.07.29: 支持使用[vllm](https://github.com/vllm-project/vllm), [lmdeploy](https://github.com/InternLM/lmdeploy)对大模型和多模态大模型进行推理加速,在infer/deploy/eval时额外指定`--infer_backend vllm/lmdeploy`即可。
|
88 |
+
- 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM/PPO。
|
89 |
+
- 🔥 2024.02.01: 支持Agent训练!训练算法源自这篇[论文](https://arxiv.org/pdf/2309.00986.pdf)。
|
90 |
+
</details>
|
91 |
+
|
92 |
+
## 🛠️ 安装
|
93 |
+
使用pip进行安装:
|
94 |
+
```shell
|
95 |
+
pip install ms-swift -U
|
96 |
+
```
|
97 |
+
|
98 |
+
从源代码安装:
|
99 |
+
```shell
|
100 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
101 |
+
|
102 |
+
git clone https://github.com/modelscope/ms-swift.git
|
103 |
+
cd ms-swift
|
104 |
+
pip install -e .
|
105 |
+
```
|
106 |
+
|
107 |
+
运行环境:
|
108 |
+
|
109 |
+
| | 范围 | 推荐 | 备注 |
|
110 |
+
| ------ |--------------| ---- | --|
|
111 |
+
| python | >=3.9 | 3.10 ||
|
112 |
+
| cuda | | cuda12 |使用cpu、npu、mps则无需安装|
|
113 |
+
| torch | >=2.0 | ||
|
114 |
+
| transformers | >=4.33 | 4.51 ||
|
115 |
+
| modelscope | >=1.23 | ||
|
116 |
+
| peft | >=0.11,<0.16 | ||
|
117 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
118 |
+
| deepspeed | >=0.14 | 0.14.5 |训练|
|
119 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 |推理/部署/评测|
|
120 |
+
| lmdeploy | >=0.5 | 0.8 |推理/部署/评测|
|
121 |
+
| evalscope | >=0.11 | |评测|
|
122 |
+
|
123 |
+
更多可选依赖可以参考[这里](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh)。
|
124 |
+
|
125 |
+
|
126 |
+
## 🚀 快速开始
|
127 |
+
|
128 |
+
**10分钟**在单卡3090上对Qwen2.5-7B-Instruct进行自我认知微调:
|
129 |
+
|
130 |
+
### 命令行
|
131 |
+
```shell
|
132 |
+
# 22GB
|
133 |
+
CUDA_VISIBLE_DEVICES=0 \
|
134 |
+
swift sft \
|
135 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
136 |
+
--train_type lora \
|
137 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
138 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
139 |
+
'swift/self-cognition#500' \
|
140 |
+
--torch_dtype bfloat16 \
|
141 |
+
--num_train_epochs 1 \
|
142 |
+
--per_device_train_batch_size 1 \
|
143 |
+
--per_device_eval_batch_size 1 \
|
144 |
+
--learning_rate 1e-4 \
|
145 |
+
--lora_rank 8 \
|
146 |
+
--lora_alpha 32 \
|
147 |
+
--target_modules all-linear \
|
148 |
+
--gradient_accumulation_steps 16 \
|
149 |
+
--eval_steps 50 \
|
150 |
+
--save_steps 50 \
|
151 |
+
--save_total_limit 2 \
|
152 |
+
--logging_steps 5 \
|
153 |
+
--max_length 2048 \
|
154 |
+
--output_dir output \
|
155 |
+
--system 'You are a helpful assistant.' \
|
156 |
+
--warmup_ratio 0.05 \
|
157 |
+
--dataloader_num_workers 4 \
|
158 |
+
--model_author swift \
|
159 |
+
--model_name swift-robot
|
160 |
+
```
|
161 |
+
|
162 |
+
小贴士:
|
163 |
+
- 如果要使用自定义数据集进行训练,你可以参考[这里](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html)组织数据集格式,并指定`--dataset <dataset_path>`。
|
164 |
+
- `--model_author`和`--model_name`参数只有当数据集中包含`swift/self-cognition`时才生效。
|
165 |
+
- 如果要使用其他模型进行训练,你只需要修改`--model <model_id/model_path>`即可。
|
166 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
167 |
+
|
168 |
+
训练完成后,使用以下命令对训练后的权重进行推理:
|
169 |
+
- 这里的`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件`args.json`,因此不需要额外指定`--model`,`--system`,swift会自动读取这些参数。如果要关闭此行为,可以设置`--load_args false`。
|
170 |
+
|
171 |
+
```shell
|
172 |
+
# 使用交互式命令行进行推理
|
173 |
+
CUDA_VISIBLE_DEVICES=0 \
|
174 |
+
swift infer \
|
175 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
176 |
+
--stream true \
|
177 |
+
--temperature 0 \
|
178 |
+
--max_new_tokens 2048
|
179 |
+
|
180 |
+
# merge-lora并使用vLLM进行推理加速
|
181 |
+
CUDA_VISIBLE_DEVICES=0 \
|
182 |
+
swift infer \
|
183 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
184 |
+
--stream true \
|
185 |
+
--merge_lora true \
|
186 |
+
--infer_backend vllm \
|
187 |
+
--max_model_len 8192 \
|
188 |
+
--temperature 0 \
|
189 |
+
--max_new_tokens 2048
|
190 |
+
```
|
191 |
+
|
192 |
+
最后,使用以下命令将模型推送到ModelScope:
|
193 |
+
```shell
|
194 |
+
CUDA_VISIBLE_DEVICES=0 \
|
195 |
+
swift export \
|
196 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
197 |
+
--push_to_hub true \
|
198 |
+
--hub_model_id '<your-model-id>' \
|
199 |
+
--hub_token '<your-sdk-token>' \
|
200 |
+
--use_hf false
|
201 |
+
```
|
202 |
+
|
203 |
+
### Web-UI
|
204 |
+
|
205 |
+
Web-UI是基于gradio界面技术的**零门槛**训练、部署界面方案,具体可以查看[这里](https://swift.readthedocs.io/zh-cn/latest/GetStarted/Web-UI.html)。
|
206 |
+
|
207 |
+
```shell
|
208 |
+
swift web-ui
|
209 |
+
```
|
210 |
+

|
211 |
+
|
212 |
+
### 使用Python
|
213 |
+
ms-swift也支持使用python的方式进行训练和推理。下面给出训练和推理的**伪代码**,具体可以查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb)。
|
214 |
+
|
215 |
+
训练:
|
216 |
+
```python
|
217 |
+
# 获取模型和template,并加入可训练的LoRA模块
|
218 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
|
219 |
+
template = get_template(model.model_meta.template, tokenizer, ...)
|
220 |
+
model = Swift.prepare_model(model, lora_config)
|
221 |
+
|
222 |
+
# 下载并载入数据集,并将文本encode成tokens
|
223 |
+
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
|
224 |
+
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
|
225 |
+
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
|
226 |
+
|
227 |
+
# 进行训练
|
228 |
+
trainer = Seq2SeqTrainer(
|
229 |
+
model=model,
|
230 |
+
args=training_args,
|
231 |
+
data_collator=template.data_collator,
|
232 |
+
train_dataset=train_dataset,
|
233 |
+
eval_dataset=val_dataset,
|
234 |
+
template=template,
|
235 |
+
)
|
236 |
+
trainer.train()
|
237 |
+
```
|
238 |
+
|
239 |
+
推理:
|
240 |
+
```python
|
241 |
+
# 使用原生pytorch引擎进行推理
|
242 |
+
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
|
243 |
+
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
|
244 |
+
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
|
245 |
+
|
246 |
+
resp_list = engine.infer([infer_request], request_config)
|
247 |
+
print(f'response: {resp_list[0].choices[0].message.content}')
|
248 |
+
```
|
249 |
+
|
250 |
+
## ✨ 如何使用
|
251 |
+
|
252 |
+
这里给出使用ms-swift进行训练到部署到最简示例,具体可以查看[examples](https://github.com/modelscope/ms-swift/tree/main/examples)。
|
253 |
+
|
254 |
+
- 若想使用其他模型或者数据集(含多模态模型和数据集),你只需要修改`--model`指定对应模型的id或者path,修改`--dataset`指定对应数据集的id或者path即可。
|
255 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
256 |
+
|
257 |
+
| 常用链接 |
|
258 |
+
| ------ |
|
259 |
+
| [🔥命令行参数](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E5%91%BD%E4%BB%A4%E8%A1%8C%E5%8F%82%E6%95%B0.html) |
|
260 |
+
| [支持的模型和数据集](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
|
261 |
+
| [自定义模型](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B.html), [🔥自定义数据集](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
|
262 |
+
| [大模型教程](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
|
263 |
+
|
264 |
+
### 训练
|
265 |
+
支持的训练方法:
|
266 |
+
|
267 |
+
| 方法 | 全参数 | LoRA | QLoRA | Deepspeed | 多机 | 多模态 |
|
268 |
+
| ------ | ------ |---------------------------------------------------------------------------------------------| ----- | ------ | ------ |----------------------------------------------------------------------------------------------|
|
269 |
+
| 预训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
270 |
+
| 指令监督微调 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
|
271 |
+
| DPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
|
272 |
+
| GRPO训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
|
273 |
+
| 奖励模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
|
274 |
+
| PPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
|
275 |
+
| KTO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
|
276 |
+
| CPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
|
277 |
+
| SimPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
|
278 |
+
| ORPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
|
279 |
+
| 分类模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
|
280 |
+
| Embedding模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
|
281 |
+
|
282 |
+
|
283 |
+
预训练:
|
284 |
+
```shell
|
285 |
+
# 8*A100
|
286 |
+
NPROC_PER_NODE=8 \
|
287 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
288 |
+
swift pt \
|
289 |
+
--model Qwen/Qwen2.5-7B \
|
290 |
+
--dataset swift/chinese-c4 \
|
291 |
+
--streaming true \
|
292 |
+
--train_type full \
|
293 |
+
--deepspeed zero2 \
|
294 |
+
--output_dir output \
|
295 |
+
--max_steps 10000 \
|
296 |
+
...
|
297 |
+
```
|
298 |
+
|
299 |
+
微调:
|
300 |
+
```shell
|
301 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
302 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
303 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
304 |
+
--train_type lora \
|
305 |
+
--output_dir output \
|
306 |
+
...
|
307 |
+
```
|
308 |
+
|
309 |
+
RLHF:
|
310 |
+
```shell
|
311 |
+
CUDA_VISIBLE_DEVICES=0 swift rlhf \
|
312 |
+
--rlhf_type dpo \
|
313 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
314 |
+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
|
315 |
+
--train_type lora \
|
316 |
+
--output_dir output \
|
317 |
+
...
|
318 |
+
```
|
319 |
+
|
320 |
+
|
321 |
+
### 推理
|
322 |
+
```shell
|
323 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
324 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
325 |
+
--stream true \
|
326 |
+
--infer_backend pt \
|
327 |
+
--max_new_tokens 2048
|
328 |
+
|
329 |
+
# LoRA
|
330 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
331 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
332 |
+
--adapters swift/test_lora \
|
333 |
+
--stream true \
|
334 |
+
--infer_backend pt \
|
335 |
+
--temperature 0 \
|
336 |
+
--max_new_tokens 2048
|
337 |
+
```
|
338 |
+
|
339 |
+
### 界面推理
|
340 |
+
```shell
|
341 |
+
CUDA_VISIBLE_DEVICES=0 swift app \
|
342 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
343 |
+
--stream true \
|
344 |
+
--infer_backend pt \
|
345 |
+
--max_new_tokens 2048 \
|
346 |
+
--lang zh
|
347 |
+
```
|
348 |
+
|
349 |
+
### 部署
|
350 |
+
```shell
|
351 |
+
CUDA_VISIBLE_DEVICES=0 swift deploy \
|
352 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
353 |
+
--infer_backend vllm
|
354 |
+
```
|
355 |
+
|
356 |
+
### 采样
|
357 |
+
```shell
|
358 |
+
CUDA_VISIBLE_DEVICES=0 swift sample \
|
359 |
+
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
|
360 |
+
--sampler_engine pt \
|
361 |
+
--num_return_sequences 5 \
|
362 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
363 |
+
```
|
364 |
+
|
365 |
+
### 评测
|
366 |
+
```shell
|
367 |
+
CUDA_VISIBLE_DEVICES=0 swift eval \
|
368 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
369 |
+
--infer_backend lmdeploy \
|
370 |
+
--eval_backend OpenCompass \
|
371 |
+
--eval_dataset ARC_c
|
372 |
+
```
|
373 |
+
|
374 |
+
### 量化
|
375 |
+
```shell
|
376 |
+
CUDA_VISIBLE_DEVICES=0 swift export \
|
377 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
378 |
+
--quant_bits 4 --quant_method awq \
|
379 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
380 |
+
--output_dir Qwen2.5-7B-Instruct-AWQ
|
381 |
+
```
|
382 |
+
|
383 |
+
### 推送模型
|
384 |
+
```shell
|
385 |
+
swift export \
|
386 |
+
--model <model-path> \
|
387 |
+
--push_to_hub true \
|
388 |
+
--hub_model_id '<model-id>' \
|
389 |
+
--hub_token '<sdk-token>'
|
390 |
+
```
|
391 |
+
|
392 |
+
|
393 |
+
## 🏛 License
|
394 |
+
|
395 |
+
本框架使用[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE)进行许可。模型和数据集请查看原资源页面并遵守对应License。
|
396 |
+
|
397 |
+
## 📎 引用
|
398 |
+
|
399 |
+
```bibtex
|
400 |
+
@misc{zhao2024swiftascalablelightweightinfrastructure,
|
401 |
+
title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
|
402 |
+
author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
|
403 |
+
year={2024},
|
404 |
+
eprint={2408.05517},
|
405 |
+
archivePrefix={arXiv},
|
406 |
+
primaryClass={cs.CL},
|
407 |
+
url={https://arxiv.org/abs/2408.05517},
|
408 |
+
}
|
409 |
+
```
|
410 |
+
|
411 |
+
## Star History
|
412 |
+
|
413 |
+
[](https://star-history.com/#modelscope/ms-swift&Date)
|
.ipynb_checkpoints/VLLM-checkpoint.sh
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
|
2 |
+
--model /root/autodl-tmp/output_7B_FULL_cotSFT/v0-20250621-230827/Qwen2.5-Omni-7B \
|
3 |
+
--tokenizer /root/autodl-tmp/output_7B_FULL_cotSFT/v0-20250621-230827/Qwen2.5-Omni-7B \
|
4 |
+
--dtype bfloat16 \
|
5 |
+
--host 127.0.0.1 \
|
6 |
+
--port 8000 \
|
7 |
+
--gpu-memory-utilization 0.9
|
.ipynb_checkpoints/clean_transcripts-checkpoint.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
from typing import List, Dict, Tuple
|
4 |
+
|
5 |
+
def parse_timestamp(timestamp: str) -> Tuple[int, int]:
|
6 |
+
"""Convert timestamp string like '00:15' to seconds."""
|
7 |
+
minutes, seconds = map(int, timestamp.split(':'))
|
8 |
+
return minutes * 60 + seconds
|
9 |
+
|
10 |
+
def extract_time_and_speaker(line: str) -> Tuple[Tuple[int, int], str]:
|
11 |
+
"""Extract time range and speaker from a line."""
|
12 |
+
# Extract time range
|
13 |
+
time_match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\] (Speaker [A-Z]):', line)
|
14 |
+
if not time_match:
|
15 |
+
return None, None
|
16 |
+
|
17 |
+
start_time = parse_timestamp(time_match.group(1))
|
18 |
+
end_time = parse_timestamp(time_match.group(2))
|
19 |
+
speaker = time_match.group(3)
|
20 |
+
|
21 |
+
return (start_time, end_time), speaker
|
22 |
+
|
23 |
+
def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
|
24 |
+
"""Check if two time ranges overlap."""
|
25 |
+
start1, end1 = range1
|
26 |
+
start2, end2 = range2
|
27 |
+
return not (end1 <= start2 or end2 <= start1)
|
28 |
+
|
29 |
+
def has_same_speaker_overlap(transcript: str) -> bool:
|
30 |
+
"""Check if a transcript contains overlapping timestamps for the same speaker."""
|
31 |
+
lines = transcript.split('\n')
|
32 |
+
# Dictionary to store time ranges for each speaker
|
33 |
+
speaker_ranges = {}
|
34 |
+
|
35 |
+
for line in lines:
|
36 |
+
if not line.strip():
|
37 |
+
continue
|
38 |
+
|
39 |
+
time_range, speaker = extract_time_and_speaker(line)
|
40 |
+
if time_range is None or speaker is None:
|
41 |
+
continue
|
42 |
+
|
43 |
+
# Check for overlaps with existing ranges of the same speaker
|
44 |
+
if speaker in speaker_ranges:
|
45 |
+
for existing_range in speaker_ranges[speaker]:
|
46 |
+
if has_overlap(time_range, existing_range):
|
47 |
+
return True
|
48 |
+
|
49 |
+
speaker_ranges[speaker].append(time_range)
|
50 |
+
else:
|
51 |
+
speaker_ranges[speaker] = [time_range]
|
52 |
+
|
53 |
+
return False
|
54 |
+
|
55 |
+
def process_file(input_file: str, output_file: str, delete_file: str):
|
56 |
+
"""Process the JSON file and separate entries with same-speaker overlapping timestamps."""
|
57 |
+
with open(input_file, 'r', encoding='utf-8') as f:
|
58 |
+
data = json.load(f)
|
59 |
+
|
60 |
+
if isinstance(data, dict):
|
61 |
+
data = [data]
|
62 |
+
|
63 |
+
cleaned_data = []
|
64 |
+
deleted_data = []
|
65 |
+
removed_count = 0
|
66 |
+
|
67 |
+
for entry in data:
|
68 |
+
if 'model_output' in entry:
|
69 |
+
if not has_same_speaker_overlap(entry['model_output']):
|
70 |
+
cleaned_data.append(entry)
|
71 |
+
else:
|
72 |
+
deleted_data.append(entry)
|
73 |
+
removed_count += 1
|
74 |
+
print(f"Removing entry with key: {entry.get('key', 'unknown')}")
|
75 |
+
|
76 |
+
# Save cleaned data
|
77 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
78 |
+
json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
|
79 |
+
|
80 |
+
# Save deleted data
|
81 |
+
with open(delete_file, 'w', encoding='utf-8') as f:
|
82 |
+
json.dump(deleted_data, f, ensure_ascii=False, indent=2)
|
83 |
+
|
84 |
+
print(f"\nProcessing Summary:")
|
85 |
+
print(f"Processed {len(data)} entries")
|
86 |
+
print(f"Removed {removed_count} entries with same-speaker overlapping timestamps")
|
87 |
+
print(f"Remaining entries: {len(cleaned_data)}")
|
88 |
+
|
89 |
+
if __name__ == '__main__':
|
90 |
+
input_file = 'silence_overlaps/transcriptions.json'
|
91 |
+
output_file = 'silence_overlaps/cleaned_transcriptions2.json'
|
92 |
+
delete_file = 'silence_overlaps/delete_transcript2.json'
|
93 |
+
process_file(input_file, output_file, delete_file)
|
94 |
+
print(f"\nCleaned transcriptions have been saved to {output_file}")
|
95 |
+
print(f"Deleted entries have been saved to {delete_file}")
|
.ipynb_checkpoints/compare_scores-checkpoint.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
infer_result_path = '/root/autodl-tmp/output_7B_GRPO/v28-20250722-002940/checkpoint-870/infer_result/53_HH.jsonl'
|
6 |
+
test_path = '/root/autodl-tmp/ms-swift/all_audio_test_50.jsonl'
|
7 |
+
output_path = 'inference_comparison_result.json'
|
8 |
+
|
9 |
+
def extract_overall_score(response_text):
|
10 |
+
match = re.search(r'<overall score>(\d+)</overall score>', response_text)
|
11 |
+
if match:
|
12 |
+
return int(match.group(1))
|
13 |
+
return None
|
14 |
+
|
15 |
+
def main():
|
16 |
+
# 读取infer_result文件,建立audio到score的映射
|
17 |
+
infer_audio2score = {}
|
18 |
+
with open(infer_result_path, 'r', encoding='utf-8') as f:
|
19 |
+
for line in f:
|
20 |
+
data = json.loads(line)
|
21 |
+
score = extract_overall_score(data['response'])
|
22 |
+
audios = tuple(data.get('audios', []))
|
23 |
+
infer_audio2score[audios] = {
|
24 |
+
'score': score,
|
25 |
+
'raw_response': data['response']
|
26 |
+
}
|
27 |
+
|
28 |
+
# 读取test文件,建立audio到solution的映射
|
29 |
+
test_audio2solution = {}
|
30 |
+
with open(test_path, 'r', encoding='utf-8') as f:
|
31 |
+
for line in f:
|
32 |
+
data = json.loads(line)
|
33 |
+
solution = data['solution']
|
34 |
+
audios = tuple(data.get('audios', []))
|
35 |
+
test_audio2solution[audios] = solution
|
36 |
+
|
37 |
+
# 统计和收集错误样本 & 所有推理结果
|
38 |
+
stats_per_class = defaultdict(lambda: {'correct': 0, 'incorrect': 0})
|
39 |
+
incorrect_samples_solution1 = []
|
40 |
+
all_results = []
|
41 |
+
|
42 |
+
total = 0
|
43 |
+
correct = 0
|
44 |
+
|
45 |
+
for audios, solution in test_audio2solution.items():
|
46 |
+
infer_entry = infer_audio2score.get(audios, None)
|
47 |
+
infer_score = infer_entry['score'] if infer_entry else None
|
48 |
+
raw_response = infer_entry['raw_response'] if infer_entry else None
|
49 |
+
match = infer_score == solution
|
50 |
+
|
51 |
+
# 收集所有结果
|
52 |
+
all_results.append({
|
53 |
+
'audios': audios,
|
54 |
+
'gt_solution': solution,
|
55 |
+
'predicted_score': infer_score,
|
56 |
+
'match': match,
|
57 |
+
'response': raw_response
|
58 |
+
})
|
59 |
+
|
60 |
+
if match:
|
61 |
+
correct += 1
|
62 |
+
stats_per_class[solution]['correct'] += 1
|
63 |
+
else:
|
64 |
+
stats_per_class[solution]['incorrect'] += 1
|
65 |
+
if solution == 1:
|
66 |
+
incorrect_samples_solution1.append({
|
67 |
+
'audios': audios,
|
68 |
+
'gt_solution': solution,
|
69 |
+
'predicted_score': infer_score,
|
70 |
+
'response': raw_response
|
71 |
+
})
|
72 |
+
|
73 |
+
total += 1
|
74 |
+
|
75 |
+
# 总体准确率
|
76 |
+
print(f'\nOverall Accuracy: {correct}/{total} = {correct/total:.2%}\n')
|
77 |
+
|
78 |
+
# 每类准确率
|
79 |
+
print("Per-Class Accuracy:")
|
80 |
+
for solution, stats in sorted(stats_per_class.items()):
|
81 |
+
total_class = stats['correct'] + stats['incorrect']
|
82 |
+
accuracy = stats['correct'] / total_class if total_class > 0 else 0.0
|
83 |
+
print(f'Class {solution}: Correct={stats["correct"]}, Incorrect={stats["incorrect"]}, Accuracy={accuracy:.2%}')
|
84 |
+
|
85 |
+
# 列出 solution=1 且预测错误的样本
|
86 |
+
print("\nIncorrect Samples for solution = 1:")
|
87 |
+
for sample in incorrect_samples_solution1:
|
88 |
+
print(json.dumps(sample, indent=2, ensure_ascii=False))
|
89 |
+
|
90 |
+
# 写入所有结果到 JSON 文件
|
91 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
92 |
+
json.dump(all_results, f, indent=2, ensure_ascii=False)
|
93 |
+
print(f"\nAll inference comparison results saved to: {output_path}")
|
94 |
+
|
95 |
+
if __name__ == '__main__':
|
96 |
+
main()
|
.ipynb_checkpoints/dialogue_length_ranges-checkpoint.png
ADDED
![]() |
.ipynb_checkpoints/infer-checkpoint.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
3 |
+
|
4 |
+
from swift.llm import PtEngine, RequestConfig, safe_snapshot_download, get_model_tokenizer, get_template, InferRequest
|
5 |
+
import json
|
6 |
+
from transformers import AutoProcessor
|
7 |
+
from swift.tuners import Swift
|
8 |
+
last_model_checkpoint = '/root/autodl-tmp/output_7B_Lora_cotSFT/v2-20250613-111902/checkpoint-3'
|
9 |
+
|
10 |
+
# 模型
|
11 |
+
model_id_or_path = '/root/autodl-tmp/output_7B_Lora_allmission/v2-20250610-190504/checkpoint-1000-merged' # model_id or model_path
|
12 |
+
system = 'You are a helpful assistant.'
|
13 |
+
infer_backend = 'pt'
|
14 |
+
|
15 |
+
# 生成参数
|
16 |
+
max_new_tokens = 2048
|
17 |
+
temperature = 0
|
18 |
+
stream = False
|
19 |
+
template_type = None
|
20 |
+
default_system = system # None: 使用对应模型默认的default_system
|
21 |
+
# 初始化音频处理器
|
22 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path)
|
23 |
+
# 初始化引擎
|
24 |
+
model = Swift.from_pretrained(model, last_model_checkpoint)
|
25 |
+
template_type = template_type or model.model_meta.template
|
26 |
+
template = get_template(template_type, tokenizer, default_system=default_system)
|
27 |
+
engine = PtEngine.from_model_template(model, template, max_batch_size=2)
|
28 |
+
request_config = RequestConfig(max_tokens=8192, temperature=0)
|
29 |
+
|
30 |
+
def load_test_data(json_file):
|
31 |
+
test_requests = []
|
32 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
33 |
+
for line in f:
|
34 |
+
data = json.loads(line.strip())
|
35 |
+
test_requests.append(InferRequest(
|
36 |
+
messages=data['messages'],
|
37 |
+
audios=data['audios']
|
38 |
+
))
|
39 |
+
return test_requests
|
40 |
+
|
41 |
+
def main():
|
42 |
+
# 加载测试数据
|
43 |
+
test_file = 'dataset_allmissiontest.json'
|
44 |
+
infer_requests = load_test_data(test_file)
|
45 |
+
results = []
|
46 |
+
resp_list = engine.infer(infer_requests, request_config)
|
47 |
+
for i, resp in enumerate(resp_list):
|
48 |
+
assistant_content = next((msg['content'] for msg in infer_requests[i].messages if msg['role'] == 'assistant'), None)
|
49 |
+
result = {
|
50 |
+
"index": i,
|
51 |
+
"truth": assistant_content,
|
52 |
+
"response": resp.choices[0].message.content,
|
53 |
+
}
|
54 |
+
results.append(result)
|
55 |
+
print(f'truth{i}: {assistant_content}')
|
56 |
+
print(f'response{i}: {resp.choices[0].message.content}')
|
57 |
+
output_file = 'inference_results.json'
|
58 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
59 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
main()
|
GRPO_TRAIN.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
docs/transformers/build/lib/transformers/models/phimoe/__init__.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Microsoft and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import TYPE_CHECKING
|
15 |
+
|
16 |
+
from ...utils import _LazyModule
|
17 |
+
from ...utils.import_utils import define_import_structure
|
18 |
+
|
19 |
+
|
20 |
+
if TYPE_CHECKING:
|
21 |
+
from .configuration_phimoe import *
|
22 |
+
from .modeling_phimoe import *
|
23 |
+
|
24 |
+
else:
|
25 |
+
import sys
|
26 |
+
|
27 |
+
_file = globals()["__file__"]
|
28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import argparse
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from flax.traverse_util import flatten_dict
|
21 |
+
from t5x import checkpoints
|
22 |
+
|
23 |
+
from transformers import (
|
24 |
+
AutoTokenizer,
|
25 |
+
Pix2StructConfig,
|
26 |
+
Pix2StructForConditionalGeneration,
|
27 |
+
Pix2StructImageProcessor,
|
28 |
+
Pix2StructProcessor,
|
29 |
+
Pix2StructTextConfig,
|
30 |
+
Pix2StructVisionConfig,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def get_flax_param(t5x_checkpoint_path):
|
35 |
+
flax_params = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
|
36 |
+
flax_params = flatten_dict(flax_params)
|
37 |
+
return flax_params
|
38 |
+
|
39 |
+
|
40 |
+
def rename_and_convert_flax_params(flax_dict):
|
41 |
+
converted_dict = {}
|
42 |
+
|
43 |
+
CONVERSION_MAPPING = {
|
44 |
+
"token_embedder": "embeddings",
|
45 |
+
"encoder_norm": "layernorm",
|
46 |
+
"kernel": "weight",
|
47 |
+
".out": ".output",
|
48 |
+
"scale": "weight",
|
49 |
+
"embedders_0.pos_embedding": "row_embedder.weight",
|
50 |
+
"embedders_1.pos_embedding": "column_embedder.weight",
|
51 |
+
}
|
52 |
+
|
53 |
+
DECODER_CONVERSION_MAPPING = {
|
54 |
+
"query": "attention.query",
|
55 |
+
"key": "attention.key",
|
56 |
+
"value": "attention.value",
|
57 |
+
"output.dense": "output",
|
58 |
+
"encoder_decoder_attention.o": "encoder_decoder_attention.attention.o",
|
59 |
+
"pre_self_attention_layer_norm": "self_attention.layer_norm",
|
60 |
+
"pre_cross_attention_layer_norm": "encoder_decoder_attention.layer_norm",
|
61 |
+
"mlp.": "mlp.DenseReluDense.",
|
62 |
+
"pre_mlp_layer_norm": "mlp.layer_norm",
|
63 |
+
"self_attention.o": "self_attention.attention.o",
|
64 |
+
"decoder.embeddings.embedding": "decoder.embed_tokens.weight",
|
65 |
+
"decoder.relpos_bias.rel_embedding": "decoder.layer.0.self_attention.attention.relative_attention_bias.weight",
|
66 |
+
"decoder.decoder_norm.weight": "decoder.final_layer_norm.weight",
|
67 |
+
"decoder.logits_dense.weight": "decoder.lm_head.weight",
|
68 |
+
}
|
69 |
+
|
70 |
+
for key in flax_dict.keys():
|
71 |
+
if "target" in key:
|
72 |
+
# remove the first prefix from the key
|
73 |
+
new_key = ".".join(key[1:])
|
74 |
+
|
75 |
+
# rename the key
|
76 |
+
for old, new in CONVERSION_MAPPING.items():
|
77 |
+
new_key = new_key.replace(old, new)
|
78 |
+
|
79 |
+
if "decoder" in new_key:
|
80 |
+
for old, new in DECODER_CONVERSION_MAPPING.items():
|
81 |
+
new_key = new_key.replace(old, new)
|
82 |
+
|
83 |
+
if "layers" in new_key and "decoder" not in new_key:
|
84 |
+
# use regex to replace the layer number
|
85 |
+
new_key = re.sub(r"layers_(\d+)", r"layer.\1", new_key)
|
86 |
+
new_key = new_key.replace("encoder", "encoder.encoder")
|
87 |
+
|
88 |
+
elif "layers" in new_key and "decoder" in new_key:
|
89 |
+
# use regex to replace the layer number
|
90 |
+
new_key = re.sub(r"layers_(\d+)", r"layer.\1", new_key)
|
91 |
+
|
92 |
+
converted_dict[new_key] = flax_dict[key]
|
93 |
+
|
94 |
+
converted_torch_dict = {}
|
95 |
+
# convert converted_dict into torch format
|
96 |
+
for key in converted_dict.keys():
|
97 |
+
if ("embed_tokens" not in key) and ("embedder" not in key):
|
98 |
+
converted_torch_dict[key] = torch.from_numpy(converted_dict[key].T)
|
99 |
+
else:
|
100 |
+
converted_torch_dict[key] = torch.from_numpy(converted_dict[key])
|
101 |
+
|
102 |
+
return converted_torch_dict
|
103 |
+
|
104 |
+
|
105 |
+
def convert_pix2struct_original_pytorch_checkpoint_to_hf(
|
106 |
+
t5x_checkpoint_path, pytorch_dump_folder_path, use_large=False, is_vqa=False
|
107 |
+
):
|
108 |
+
flax_params = get_flax_param(t5x_checkpoint_path)
|
109 |
+
|
110 |
+
if not use_large:
|
111 |
+
encoder_config = Pix2StructVisionConfig()
|
112 |
+
decoder_config = Pix2StructTextConfig()
|
113 |
+
else:
|
114 |
+
encoder_config = Pix2StructVisionConfig(
|
115 |
+
hidden_size=1536, d_ff=3968, num_attention_heads=24, num_hidden_layers=18
|
116 |
+
)
|
117 |
+
decoder_config = Pix2StructTextConfig(hidden_size=1536, d_ff=3968, num_heads=24, num_layers=18)
|
118 |
+
config = Pix2StructConfig(
|
119 |
+
vision_config=encoder_config.to_dict(), text_config=decoder_config.to_dict(), is_vqa=is_vqa
|
120 |
+
)
|
121 |
+
|
122 |
+
model = Pix2StructForConditionalGeneration(config)
|
123 |
+
|
124 |
+
torch_params = rename_and_convert_flax_params(flax_params)
|
125 |
+
model.load_state_dict(torch_params)
|
126 |
+
|
127 |
+
tok = AutoTokenizer.from_pretrained("ybelkada/test-pix2struct-tokenizer")
|
128 |
+
image_processor = Pix2StructImageProcessor()
|
129 |
+
processor = Pix2StructProcessor(image_processor=image_processor, tokenizer=tok)
|
130 |
+
|
131 |
+
if use_large:
|
132 |
+
processor.image_processor.max_patches = 4096
|
133 |
+
|
134 |
+
processor.image_processor.is_vqa = True
|
135 |
+
|
136 |
+
# mkdir if needed
|
137 |
+
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
138 |
+
|
139 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
140 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
141 |
+
|
142 |
+
print("Model saved in {}".format(pytorch_dump_folder_path))
|
143 |
+
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
parser = argparse.ArgumentParser()
|
147 |
+
parser.add_argument("--t5x_checkpoint_path", default=None, type=str, help="Path to the original T5x checkpoint.")
|
148 |
+
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
149 |
+
parser.add_argument("--use_large", action="store_true", help="Use large model.")
|
150 |
+
parser.add_argument("--is_vqa", action="store_true", help="Use large model.")
|
151 |
+
args = parser.parse_args()
|
152 |
+
|
153 |
+
convert_pix2struct_original_pytorch_checkpoint_to_hf(
|
154 |
+
args.t5x_checkpoint_path, args.pytorch_dump_folder_path, args.use_large
|
155 |
+
)
|
docs/transformers/build/lib/transformers/models/pix2struct/image_processing_pix2struct.py
ADDED
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Image processor class for Pix2Struct."""
|
16 |
+
|
17 |
+
import io
|
18 |
+
import math
|
19 |
+
from typing import Dict, Optional, Union
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
from huggingface_hub import hf_hub_download
|
23 |
+
|
24 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
25 |
+
from ...image_transforms import convert_to_rgb, normalize, to_channel_dimension_format, to_pil_image
|
26 |
+
from ...image_utils import (
|
27 |
+
ChannelDimension,
|
28 |
+
ImageInput,
|
29 |
+
get_image_size,
|
30 |
+
infer_channel_dimension_format,
|
31 |
+
make_list_of_images,
|
32 |
+
to_numpy_array,
|
33 |
+
valid_images,
|
34 |
+
)
|
35 |
+
from ...utils import TensorType, is_torch_available, is_vision_available, logging
|
36 |
+
from ...utils.import_utils import requires_backends
|
37 |
+
|
38 |
+
|
39 |
+
if is_vision_available():
|
40 |
+
import textwrap
|
41 |
+
|
42 |
+
from PIL import Image, ImageDraw, ImageFont
|
43 |
+
|
44 |
+
if is_torch_available():
|
45 |
+
import torch
|
46 |
+
|
47 |
+
logger = logging.get_logger(__name__)
|
48 |
+
DEFAULT_FONT_PATH = "ybelkada/fonts"
|
49 |
+
|
50 |
+
|
51 |
+
# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
|
52 |
+
def torch_extract_patches(image_tensor, patch_height, patch_width):
|
53 |
+
"""
|
54 |
+
Utiliy function to extract patches from a given image tensor. Returns a tensor of shape (1, `patch_height`,
|
55 |
+
`patch_width`, `num_channels`x `patch_height` x `patch_width`)
|
56 |
+
|
57 |
+
Args:
|
58 |
+
image_tensor (torch.Tensor):
|
59 |
+
The image tensor to extract patches from.
|
60 |
+
patch_height (int):
|
61 |
+
The height of the patches to extract.
|
62 |
+
patch_width (int):
|
63 |
+
The width of the patches to extract.
|
64 |
+
"""
|
65 |
+
requires_backends(torch_extract_patches, ["torch"])
|
66 |
+
|
67 |
+
image_tensor = image_tensor.unsqueeze(0)
|
68 |
+
patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
|
69 |
+
patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1)
|
70 |
+
patches = patches.permute(0, 4, 2, 3, 1).reshape(
|
71 |
+
image_tensor.size(2) // patch_height,
|
72 |
+
image_tensor.size(3) // patch_width,
|
73 |
+
image_tensor.size(1) * patch_height * patch_width,
|
74 |
+
)
|
75 |
+
return patches.unsqueeze(0)
|
76 |
+
|
77 |
+
|
78 |
+
# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L106
|
79 |
+
def render_text(
|
80 |
+
text: str,
|
81 |
+
text_size: int = 36,
|
82 |
+
text_color: str = "black",
|
83 |
+
background_color: str = "white",
|
84 |
+
left_padding: int = 5,
|
85 |
+
right_padding: int = 5,
|
86 |
+
top_padding: int = 5,
|
87 |
+
bottom_padding: int = 5,
|
88 |
+
font_bytes: Optional[bytes] = None,
|
89 |
+
font_path: Optional[str] = None,
|
90 |
+
) -> Image.Image:
|
91 |
+
"""
|
92 |
+
Render text. This script is entirely adapted from the original script that can be found here:
|
93 |
+
https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py
|
94 |
+
|
95 |
+
Args:
|
96 |
+
text (`str`, *optional*, defaults to ):
|
97 |
+
Text to render.
|
98 |
+
text_size (`int`, *optional*, defaults to 36):
|
99 |
+
Size of the text.
|
100 |
+
text_color (`str`, *optional*, defaults to `"black"`):
|
101 |
+
Color of the text.
|
102 |
+
background_color (`str`, *optional*, defaults to `"white"`):
|
103 |
+
Color of the background.
|
104 |
+
left_padding (`int`, *optional*, defaults to 5):
|
105 |
+
Padding on the left.
|
106 |
+
right_padding (`int`, *optional*, defaults to 5):
|
107 |
+
Padding on the right.
|
108 |
+
top_padding (`int`, *optional*, defaults to 5):
|
109 |
+
Padding on the top.
|
110 |
+
bottom_padding (`int`, *optional*, defaults to 5):
|
111 |
+
Padding on the bottom.
|
112 |
+
font_bytes (`bytes`, *optional*):
|
113 |
+
Bytes of the font to use. If `None`, the default font will be used.
|
114 |
+
font_path (`str`, *optional*):
|
115 |
+
Path to the font to use. If `None`, the default font will be used.
|
116 |
+
"""
|
117 |
+
requires_backends(render_text, "vision")
|
118 |
+
# Add new lines so that each line is no more than 80 characters.
|
119 |
+
|
120 |
+
wrapper = textwrap.TextWrapper(width=80)
|
121 |
+
lines = wrapper.wrap(text=text)
|
122 |
+
wrapped_text = "\n".join(lines)
|
123 |
+
|
124 |
+
if font_bytes is not None and font_path is None:
|
125 |
+
font = io.BytesIO(font_bytes)
|
126 |
+
elif font_path is not None:
|
127 |
+
font = font_path
|
128 |
+
else:
|
129 |
+
font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
|
130 |
+
font = ImageFont.truetype(font, encoding="UTF-8", size=text_size)
|
131 |
+
|
132 |
+
# Use a temporary canvas to determine the width and height in pixels when
|
133 |
+
# rendering the text.
|
134 |
+
temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
|
135 |
+
_, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)
|
136 |
+
|
137 |
+
# Create the actual image with a bit of padding around the text.
|
138 |
+
image_width = text_width + left_padding + right_padding
|
139 |
+
image_height = text_height + top_padding + bottom_padding
|
140 |
+
image = Image.new("RGB", (image_width, image_height), background_color)
|
141 |
+
draw = ImageDraw.Draw(image)
|
142 |
+
draw.text(xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font)
|
143 |
+
return image
|
144 |
+
|
145 |
+
|
146 |
+
# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
|
147 |
+
def render_header(
|
148 |
+
image: np.ndarray, header: str, input_data_format: Optional[Union[str, ChildProcessError]] = None, **kwargs
|
149 |
+
):
|
150 |
+
"""
|
151 |
+
Renders the input text as a header on the input image.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
image (`np.ndarray`):
|
155 |
+
The image to render the header on.
|
156 |
+
header (`str`):
|
157 |
+
The header text.
|
158 |
+
data_format (`Union[ChannelDimension, str]`, *optional*):
|
159 |
+
The data format of the image. Can be either "ChannelDimension.channels_first" or
|
160 |
+
"ChannelDimension.channels_last".
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
`np.ndarray`: The image with the header rendered.
|
164 |
+
"""
|
165 |
+
requires_backends(render_header, "vision")
|
166 |
+
|
167 |
+
# Convert to PIL image if necessary
|
168 |
+
image = to_pil_image(image, input_data_format=input_data_format)
|
169 |
+
|
170 |
+
header_image = render_text(header, **kwargs)
|
171 |
+
new_width = max(header_image.width, image.width)
|
172 |
+
|
173 |
+
new_height = int(image.height * (new_width / image.width))
|
174 |
+
new_header_height = int(header_image.height * (new_width / header_image.width))
|
175 |
+
|
176 |
+
new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
|
177 |
+
new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
|
178 |
+
new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))
|
179 |
+
|
180 |
+
# Convert back to the original framework if necessary
|
181 |
+
new_image = to_numpy_array(new_image)
|
182 |
+
|
183 |
+
if infer_channel_dimension_format(new_image) == ChannelDimension.LAST:
|
184 |
+
new_image = to_channel_dimension_format(new_image, ChannelDimension.LAST)
|
185 |
+
|
186 |
+
return new_image
|
187 |
+
|
188 |
+
|
189 |
+
class Pix2StructImageProcessor(BaseImageProcessor):
|
190 |
+
r"""
|
191 |
+
Constructs a Pix2Struct image processor.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
195 |
+
Whether to convert the image to RGB.
|
196 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
197 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
198 |
+
method. According to Pix2Struct paper and code, the image is normalized with its own mean and standard
|
199 |
+
deviation.
|
200 |
+
patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
|
201 |
+
The patch size to use for the image. According to Pix2Struct paper and code, the patch size is 16x16.
|
202 |
+
max_patches (`int`, *optional*, defaults to 2048):
|
203 |
+
The maximum number of patches to extract from the image as per the [Pix2Struct
|
204 |
+
paper](https://arxiv.org/pdf/2210.03347.pdf).
|
205 |
+
is_vqa (`bool`, *optional*, defaults to `False`):
|
206 |
+
Whether or not the image processor is for the VQA task. If `True` and `header_text` is passed in, text is
|
207 |
+
rendered onto the input images.
|
208 |
+
"""
|
209 |
+
|
210 |
+
model_input_names = ["flattened_patches"]
|
211 |
+
|
212 |
+
def __init__(
|
213 |
+
self,
|
214 |
+
do_convert_rgb: bool = True,
|
215 |
+
do_normalize: bool = True,
|
216 |
+
patch_size: Dict[str, int] = None,
|
217 |
+
max_patches: int = 2048,
|
218 |
+
is_vqa: bool = False,
|
219 |
+
**kwargs,
|
220 |
+
) -> None:
|
221 |
+
super().__init__(**kwargs)
|
222 |
+
self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
|
223 |
+
self.do_normalize = do_normalize
|
224 |
+
self.do_convert_rgb = do_convert_rgb
|
225 |
+
self.max_patches = max_patches
|
226 |
+
self.is_vqa = is_vqa
|
227 |
+
|
228 |
+
def extract_flattened_patches(
|
229 |
+
self,
|
230 |
+
image: np.ndarray,
|
231 |
+
max_patches: int,
|
232 |
+
patch_size: dict,
|
233 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
234 |
+
**kwargs,
|
235 |
+
) -> np.ndarray:
|
236 |
+
"""
|
237 |
+
Extract flattened patches from an image.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
image (`np.ndarray`):
|
241 |
+
Image to extract flattened patches from.
|
242 |
+
max_patches (`int`):
|
243 |
+
Maximum number of patches to extract.
|
244 |
+
patch_size (`dict`):
|
245 |
+
Dictionary containing the patch height and width.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
result (`np.ndarray`):
|
249 |
+
A sequence of `max_patches` flattened patches.
|
250 |
+
"""
|
251 |
+
requires_backends(self.extract_flattened_patches, "torch")
|
252 |
+
|
253 |
+
# convert to torch
|
254 |
+
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
|
255 |
+
image = torch.from_numpy(image)
|
256 |
+
|
257 |
+
patch_height, patch_width = patch_size["height"], patch_size["width"]
|
258 |
+
image_height, image_width = get_image_size(image, ChannelDimension.FIRST)
|
259 |
+
|
260 |
+
# maximize scale s.t.
|
261 |
+
scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
|
262 |
+
num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1)
|
263 |
+
num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1)
|
264 |
+
resized_height = max(num_feasible_rows * patch_height, 1)
|
265 |
+
resized_width = max(num_feasible_cols * patch_width, 1)
|
266 |
+
|
267 |
+
image = torch.nn.functional.interpolate(
|
268 |
+
image.unsqueeze(0),
|
269 |
+
size=(resized_height, resized_width),
|
270 |
+
mode="bilinear",
|
271 |
+
align_corners=False,
|
272 |
+
antialias=True,
|
273 |
+
).squeeze(0)
|
274 |
+
|
275 |
+
# [1, rows, columns, patch_height * patch_width * image_channels]
|
276 |
+
patches = torch_extract_patches(image, patch_height, patch_width)
|
277 |
+
|
278 |
+
patches_shape = patches.shape
|
279 |
+
rows = patches_shape[1]
|
280 |
+
columns = patches_shape[2]
|
281 |
+
depth = patches_shape[3]
|
282 |
+
|
283 |
+
# [rows * columns, patch_height * patch_width * image_channels]
|
284 |
+
patches = patches.reshape([rows * columns, depth])
|
285 |
+
|
286 |
+
# [rows * columns, 1]
|
287 |
+
row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1])
|
288 |
+
col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1])
|
289 |
+
|
290 |
+
# Offset by 1 so the ids do not contain zeros, which represent padding.
|
291 |
+
row_ids += 1
|
292 |
+
col_ids += 1
|
293 |
+
|
294 |
+
# Prepare additional patch features.
|
295 |
+
# [rows * columns, 1]
|
296 |
+
row_ids = row_ids.to(torch.float32)
|
297 |
+
col_ids = col_ids.to(torch.float32)
|
298 |
+
|
299 |
+
# [rows * columns, 2 + patch_height * patch_width * image_channels]
|
300 |
+
result = torch.cat([row_ids, col_ids, patches], -1)
|
301 |
+
|
302 |
+
# [max_patches, 2 + patch_height * patch_width * image_channels]
|
303 |
+
result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float()
|
304 |
+
|
305 |
+
result = to_numpy_array(result)
|
306 |
+
|
307 |
+
return result
|
308 |
+
|
309 |
+
def normalize(
|
310 |
+
self,
|
311 |
+
image: np.ndarray,
|
312 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
313 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
314 |
+
**kwargs,
|
315 |
+
) -> np.ndarray:
|
316 |
+
"""
|
317 |
+
Normalize an image. image = (image - image_mean) / image_std.
|
318 |
+
|
319 |
+
The image std is to mimic the tensorflow implementation of the `per_image_standardization`:
|
320 |
+
https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization
|
321 |
+
|
322 |
+
Args:
|
323 |
+
image (`np.ndarray`):
|
324 |
+
Image to normalize.
|
325 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
326 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
327 |
+
image is used.
|
328 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
329 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
330 |
+
"""
|
331 |
+
if image.dtype == np.uint8:
|
332 |
+
image = image.astype(np.float32)
|
333 |
+
|
334 |
+
# take mean across the whole `image`
|
335 |
+
mean = np.mean(image)
|
336 |
+
std = np.std(image)
|
337 |
+
adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape)))
|
338 |
+
|
339 |
+
return normalize(
|
340 |
+
image,
|
341 |
+
mean=mean,
|
342 |
+
std=adjusted_stddev,
|
343 |
+
data_format=data_format,
|
344 |
+
input_data_format=input_data_format,
|
345 |
+
**kwargs,
|
346 |
+
)
|
347 |
+
|
348 |
+
def preprocess(
|
349 |
+
self,
|
350 |
+
images: ImageInput,
|
351 |
+
header_text: Optional[str] = None,
|
352 |
+
do_convert_rgb: Optional[bool] = None,
|
353 |
+
do_normalize: Optional[bool] = None,
|
354 |
+
max_patches: Optional[int] = None,
|
355 |
+
patch_size: Optional[Dict[str, int]] = None,
|
356 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
357 |
+
data_format: ChannelDimension = ChannelDimension.FIRST,
|
358 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
359 |
+
**kwargs,
|
360 |
+
) -> ImageInput:
|
361 |
+
"""
|
362 |
+
Preprocess an image or batch of images. The processor first computes the maximum possible number of
|
363 |
+
aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the
|
364 |
+
image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the
|
365 |
+
images are standardized following the tensorflow implementation of `per_image_standardization`
|
366 |
+
(https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization).
|
367 |
+
|
368 |
+
|
369 |
+
Args:
|
370 |
+
images (`ImageInput`):
|
371 |
+
Image to preprocess. Expects a single or batch of images.
|
372 |
+
header_text (`Union[List[str], str]`, *optional*):
|
373 |
+
Text to render as a header. Only has an effect if `image_processor.is_vqa` is `True`.
|
374 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
375 |
+
Whether to convert the image to RGB.
|
376 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
377 |
+
Whether to normalize the image.
|
378 |
+
max_patches (`int`, *optional*, defaults to `self.max_patches`):
|
379 |
+
Maximum number of patches to extract.
|
380 |
+
patch_size (`dict`, *optional*, defaults to `self.patch_size`):
|
381 |
+
Dictionary containing the patch height and width.
|
382 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
383 |
+
The type of tensors to return. Can be one of:
|
384 |
+
- Unset: Return a list of `np.ndarray`.
|
385 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
386 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
387 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
388 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
389 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
390 |
+
The channel dimension format for the output image. Can be one of:
|
391 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
392 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
393 |
+
- Unset: Use the channel dimension format of the input image.
|
394 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
395 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
396 |
+
from the input image. Can be one of:
|
397 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
398 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
399 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
400 |
+
"""
|
401 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
402 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
403 |
+
patch_size = patch_size if patch_size is not None else self.patch_size
|
404 |
+
max_patches = max_patches if max_patches is not None else self.max_patches
|
405 |
+
is_vqa = self.is_vqa
|
406 |
+
|
407 |
+
if kwargs.get("data_format", None) is not None:
|
408 |
+
raise ValueError("data_format is not an accepted input as the outputs are ")
|
409 |
+
|
410 |
+
images = make_list_of_images(images)
|
411 |
+
|
412 |
+
if not valid_images(images):
|
413 |
+
raise ValueError(
|
414 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
415 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
416 |
+
)
|
417 |
+
|
418 |
+
# PIL RGBA images are converted to RGB
|
419 |
+
if do_convert_rgb:
|
420 |
+
images = [convert_to_rgb(image) for image in images]
|
421 |
+
|
422 |
+
# All transformations expect numpy arrays.
|
423 |
+
images = [to_numpy_array(image) for image in images]
|
424 |
+
|
425 |
+
if input_data_format is None:
|
426 |
+
# We assume that all images have the same channel dimension format.
|
427 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
428 |
+
|
429 |
+
if is_vqa:
|
430 |
+
if header_text is None:
|
431 |
+
raise ValueError("A header text must be provided for VQA models.")
|
432 |
+
font_bytes = kwargs.pop("font_bytes", None)
|
433 |
+
font_path = kwargs.pop("font_path", None)
|
434 |
+
|
435 |
+
if isinstance(header_text, str):
|
436 |
+
header_text = [header_text] * len(images)
|
437 |
+
|
438 |
+
images = [
|
439 |
+
render_header(image, header_text[i], font_bytes=font_bytes, font_path=font_path)
|
440 |
+
for i, image in enumerate(images)
|
441 |
+
]
|
442 |
+
|
443 |
+
if do_normalize:
|
444 |
+
images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]
|
445 |
+
|
446 |
+
# convert to torch tensor and permute
|
447 |
+
images = [
|
448 |
+
self.extract_flattened_patches(
|
449 |
+
image=image, max_patches=max_patches, patch_size=patch_size, input_data_format=input_data_format
|
450 |
+
)
|
451 |
+
for image in images
|
452 |
+
]
|
453 |
+
|
454 |
+
# create attention mask in numpy
|
455 |
+
attention_masks = [(image.sum(axis=-1) != 0).astype(np.float32) for image in images]
|
456 |
+
|
457 |
+
encoded_outputs = BatchFeature(
|
458 |
+
data={"flattened_patches": images, "attention_mask": attention_masks}, tensor_type=return_tensors
|
459 |
+
)
|
460 |
+
|
461 |
+
return encoded_outputs
|
462 |
+
|
463 |
+
|
464 |
+
__all__ = ["Pix2StructImageProcessor"]
|
docs/transformers/build/lib/transformers/models/pixtral/__init__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import TYPE_CHECKING
|
15 |
+
|
16 |
+
from ...utils import _LazyModule
|
17 |
+
from ...utils.import_utils import define_import_structure
|
18 |
+
|
19 |
+
|
20 |
+
if TYPE_CHECKING:
|
21 |
+
from .configuration_pixtral import *
|
22 |
+
from .image_processing_pixtral import *
|
23 |
+
from .image_processing_pixtral_fast import *
|
24 |
+
from .modeling_pixtral import *
|
25 |
+
from .processing_pixtral import *
|
26 |
+
else:
|
27 |
+
import sys
|
28 |
+
|
29 |
+
_file = globals()["__file__"]
|
30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/pixtral/configuration_pixtral.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Pixtral model configuration"""
|
15 |
+
|
16 |
+
from ...configuration_utils import PretrainedConfig
|
17 |
+
from ...utils import logging
|
18 |
+
|
19 |
+
|
20 |
+
logger = logging.get_logger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
class PixtralVisionConfig(PretrainedConfig):
|
24 |
+
r"""
|
25 |
+
This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate an
|
26 |
+
Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
|
27 |
+
with the defaults will yield a similar configuration to the vision encoder used by Pixtral-12B.
|
28 |
+
|
29 |
+
e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)
|
30 |
+
|
31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
32 |
+
documentation from [`PretrainedConfig`] for more information.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
hidden_size (`int`, *optional*, defaults to 1024):
|
36 |
+
Dimension of the hidden representations.
|
37 |
+
intermediate_size (`int`, *optional*, defaults to 4096):
|
38 |
+
Dimension of the MLP representations.
|
39 |
+
num_hidden_layers (`int`, *optional*, defaults to 24):
|
40 |
+
Number of hidden layers in the Transformer encoder.
|
41 |
+
num_attention_heads (`int`, *optional*, defaults to 16):
|
42 |
+
Number of attention heads in the Transformer encoder.
|
43 |
+
num_channels (`int`, *optional*, defaults to 3):
|
44 |
+
Number of input channels in the input images.
|
45 |
+
image_size (`int`, *optional*, defaults to 1024):
|
46 |
+
Max dimension of the input images.
|
47 |
+
patch_size (`int`, *optional*, defaults to 16):
|
48 |
+
Size of the image patches.
|
49 |
+
hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
50 |
+
Activation function used in the hidden layers.
|
51 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
52 |
+
Dropout probability for the attention layers.
|
53 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
54 |
+
The base period of the RoPE embeddings.
|
55 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
56 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
57 |
+
|
58 |
+
Example:
|
59 |
+
|
60 |
+
```python
|
61 |
+
>>> from transformers import PixtralVisionModel, PixtralVisionConfig
|
62 |
+
|
63 |
+
>>> # Initializing a Pixtral-12B style configuration
|
64 |
+
>>> config = PixtralVisionConfig()
|
65 |
+
|
66 |
+
>>> # Initializing a model (with randomly initialized weights) from the configuration
|
67 |
+
>>> model = PixtralVisionModel(configuration)
|
68 |
+
|
69 |
+
>>> # Accessing the model configuration
|
70 |
+
>>> configuration = model.config
|
71 |
+
```"""
|
72 |
+
|
73 |
+
model_type = "pixtral"
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
hidden_size=1024,
|
78 |
+
intermediate_size=4096,
|
79 |
+
num_hidden_layers=24,
|
80 |
+
num_attention_heads=16,
|
81 |
+
num_channels=3,
|
82 |
+
image_size=1024,
|
83 |
+
patch_size=16,
|
84 |
+
hidden_act="gelu",
|
85 |
+
attention_dropout=0.0,
|
86 |
+
rope_theta=10000.0,
|
87 |
+
initializer_range=0.02,
|
88 |
+
**kwargs,
|
89 |
+
):
|
90 |
+
super().__init__(**kwargs)
|
91 |
+
|
92 |
+
self.hidden_size = hidden_size
|
93 |
+
self.intermediate_size = intermediate_size
|
94 |
+
self.num_hidden_layers = num_hidden_layers
|
95 |
+
self.num_attention_heads = num_attention_heads
|
96 |
+
self.num_channels = num_channels
|
97 |
+
self.patch_size = patch_size
|
98 |
+
self.image_size = image_size
|
99 |
+
self.attention_dropout = attention_dropout
|
100 |
+
self.hidden_act = hidden_act
|
101 |
+
self.rope_theta = rope_theta
|
102 |
+
self.head_dim = hidden_size // num_attention_heads
|
103 |
+
self.initializer_range = initializer_range
|
104 |
+
|
105 |
+
|
106 |
+
__all__ = ["PixtralVisionConfig"]
|
docs/transformers/build/lib/transformers/models/pixtral/image_processing_pixtral.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Image processor class for Pixtral."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
from typing import Dict, List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
23 |
+
from ...image_transforms import (
|
24 |
+
pad,
|
25 |
+
resize,
|
26 |
+
to_channel_dimension_format,
|
27 |
+
)
|
28 |
+
from ...image_utils import (
|
29 |
+
ChannelDimension,
|
30 |
+
ImageInput,
|
31 |
+
PILImageResampling,
|
32 |
+
get_image_size,
|
33 |
+
infer_channel_dimension_format,
|
34 |
+
is_scaled_image,
|
35 |
+
make_list_of_images,
|
36 |
+
to_numpy_array,
|
37 |
+
valid_images,
|
38 |
+
validate_kwargs,
|
39 |
+
validate_preprocess_arguments,
|
40 |
+
)
|
41 |
+
from ...utils import TensorType, is_vision_available, logging
|
42 |
+
from ...utils.import_utils import requires_backends
|
43 |
+
|
44 |
+
|
45 |
+
logger = logging.get_logger(__name__)
|
46 |
+
|
47 |
+
|
48 |
+
if is_vision_available():
|
49 |
+
import PIL
|
50 |
+
|
51 |
+
|
52 |
+
# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
|
53 |
+
def convert_to_rgb(image: ImageInput) -> ImageInput:
|
54 |
+
"""
|
55 |
+
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
|
56 |
+
as is.
|
57 |
+
Args:
|
58 |
+
image (Image):
|
59 |
+
The image to convert.
|
60 |
+
"""
|
61 |
+
requires_backends(convert_to_rgb, ["vision"])
|
62 |
+
|
63 |
+
if not isinstance(image, PIL.Image.Image):
|
64 |
+
return image
|
65 |
+
|
66 |
+
if image.mode == "RGB":
|
67 |
+
return image
|
68 |
+
|
69 |
+
# First we convert to RGBA to set background to white.
|
70 |
+
image = image.convert("RGBA")
|
71 |
+
|
72 |
+
# Create a new image with a white background.
|
73 |
+
new_image = PIL.Image.new("RGBA", image.size, "WHITE")
|
74 |
+
new_image.paste(image, (0, 0), image)
|
75 |
+
new_image = new_image.convert("RGB")
|
76 |
+
return new_image
|
77 |
+
|
78 |
+
|
79 |
+
def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> int:
|
80 |
+
"""
|
81 |
+
Calculate the number of image tokens given the image size and patch size.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
image_size (`Tuple[int, int]`):
|
85 |
+
The size of the image as `(height, width)`.
|
86 |
+
patch_size (`Tuple[int, int]`):
|
87 |
+
The patch size as `(height, width)`.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
`int`: The number of image tokens.
|
91 |
+
"""
|
92 |
+
height, width = image_size
|
93 |
+
patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
|
94 |
+
num_width_tokens = (width - 1) // patch_width + 1
|
95 |
+
num_height_tokens = (height - 1) // patch_height + 1
|
96 |
+
return num_height_tokens, num_width_tokens
|
97 |
+
|
98 |
+
|
99 |
+
def get_resize_output_image_size(
|
100 |
+
input_image: ImageInput,
|
101 |
+
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
102 |
+
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
103 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
104 |
+
) -> tuple:
|
105 |
+
"""
|
106 |
+
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
|
107 |
+
size.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
input_image (`ImageInput`):
|
111 |
+
The image to resize.
|
112 |
+
size (`int` or `Tuple[int, int]`):
|
113 |
+
Max image size an input image can be. Must be a dictionary with the key "longest_edge".
|
114 |
+
patch_size (`int` or `Tuple[int, int]`):
|
115 |
+
The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)`
|
116 |
+
will be used
|
117 |
+
input_data_format (`ChannelDimension`, *optional*):
|
118 |
+
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
`tuple`: The target (height, width) dimension of the output image after resizing.
|
122 |
+
"""
|
123 |
+
max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size)
|
124 |
+
patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
|
125 |
+
height, width = get_image_size(input_image, input_data_format)
|
126 |
+
|
127 |
+
ratio = max(height / max_height, width / max_width)
|
128 |
+
|
129 |
+
if ratio > 1:
|
130 |
+
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
|
131 |
+
# Here we use floor to ensure the image is always smaller than the given "longest_edge"
|
132 |
+
height = int(math.floor(height / ratio))
|
133 |
+
width = int(math.floor(width / ratio))
|
134 |
+
|
135 |
+
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
136 |
+
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
137 |
+
|
138 |
+
|
139 |
+
class PixtralImageProcessor(BaseImageProcessor):
|
140 |
+
r"""
|
141 |
+
Constructs a Pixtral image processor.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
145 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
146 |
+
`do_resize` in the `preprocess` method.
|
147 |
+
size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
|
148 |
+
Size of the maximum dimension of either the height or width dimension of the image. Used to control how
|
149 |
+
images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)`
|
150 |
+
patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
|
151 |
+
Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
|
152 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
153 |
+
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
154 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
155 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
156 |
+
the `preprocess` method.
|
157 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
158 |
+
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
159 |
+
method.
|
160 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
161 |
+
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
162 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
163 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
164 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
165 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
166 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
167 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
168 |
+
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
169 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
170 |
+
Whether to convert the image to RGB.
|
171 |
+
"""
|
172 |
+
|
173 |
+
model_input_names = ["pixel_values"]
|
174 |
+
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
do_resize: bool = True,
|
178 |
+
size: Dict[str, int] = None,
|
179 |
+
patch_size: Dict[str, int] = None,
|
180 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
181 |
+
do_rescale: bool = True,
|
182 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
183 |
+
do_normalize: bool = True,
|
184 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
185 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
186 |
+
do_convert_rgb: bool = True,
|
187 |
+
**kwargs,
|
188 |
+
) -> None:
|
189 |
+
super().__init__(**kwargs)
|
190 |
+
size = size if size is not None else {"longest_edge": 1024}
|
191 |
+
patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
|
192 |
+
patch_size = get_size_dict(patch_size, default_to_square=True)
|
193 |
+
|
194 |
+
self.do_resize = do_resize
|
195 |
+
self.size = size
|
196 |
+
self.patch_size = patch_size
|
197 |
+
self.resample = resample
|
198 |
+
self.do_rescale = do_rescale
|
199 |
+
self.rescale_factor = rescale_factor
|
200 |
+
self.do_normalize = do_normalize
|
201 |
+
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
202 |
+
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
203 |
+
self.do_convert_rgb = do_convert_rgb
|
204 |
+
self._valid_processor_keys = [
|
205 |
+
"images",
|
206 |
+
"do_resize",
|
207 |
+
"size",
|
208 |
+
"patch_size",
|
209 |
+
"resample",
|
210 |
+
"do_rescale",
|
211 |
+
"rescale_factor",
|
212 |
+
"do_normalize",
|
213 |
+
"image_mean",
|
214 |
+
"image_std",
|
215 |
+
"do_convert_rgb",
|
216 |
+
"return_tensors",
|
217 |
+
"data_format",
|
218 |
+
"input_data_format",
|
219 |
+
]
|
220 |
+
|
221 |
+
def resize(
|
222 |
+
self,
|
223 |
+
image: np.ndarray,
|
224 |
+
size: Dict[str, int],
|
225 |
+
patch_size: Dict[str, int],
|
226 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
227 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
228 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
229 |
+
**kwargs,
|
230 |
+
) -> np.ndarray:
|
231 |
+
"""
|
232 |
+
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
233 |
+
resized to keep the input aspect ratio.
|
234 |
+
|
235 |
+
Args:
|
236 |
+
image (`np.ndarray`):
|
237 |
+
Image to resize.
|
238 |
+
size (`Dict[str, int]`):
|
239 |
+
Dict containing the longest possible edge of the image.
|
240 |
+
patch_size (`Dict[str, int]`):
|
241 |
+
Patch size used to calculate the size of the output image.
|
242 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
243 |
+
Resampling filter to use when resiizing the image.
|
244 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
245 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
246 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
247 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
248 |
+
"""
|
249 |
+
if "longest_edge" in size:
|
250 |
+
size = (size["longest_edge"], size["longest_edge"])
|
251 |
+
elif "height" in size and "width" in size:
|
252 |
+
size = (size["height"], size["width"])
|
253 |
+
else:
|
254 |
+
raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
|
255 |
+
|
256 |
+
if "height" in patch_size and "width" in patch_size:
|
257 |
+
patch_size = (patch_size["height"], patch_size["width"])
|
258 |
+
else:
|
259 |
+
raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.")
|
260 |
+
|
261 |
+
output_size = get_resize_output_image_size(
|
262 |
+
image,
|
263 |
+
size=size,
|
264 |
+
patch_size=patch_size,
|
265 |
+
input_data_format=input_data_format,
|
266 |
+
)
|
267 |
+
return resize(
|
268 |
+
image,
|
269 |
+
size=output_size,
|
270 |
+
resample=resample,
|
271 |
+
data_format=data_format,
|
272 |
+
input_data_format=input_data_format,
|
273 |
+
**kwargs,
|
274 |
+
)
|
275 |
+
|
276 |
+
def _pad_for_batching(
|
277 |
+
self,
|
278 |
+
pixel_values: List[np.ndarray],
|
279 |
+
image_sizes: List[List[int]],
|
280 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
281 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
282 |
+
):
|
283 |
+
"""
|
284 |
+
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
285 |
+
Args:
|
286 |
+
pixel_values (`List[np.ndarray]`):
|
287 |
+
An array of pixel values of each images of shape (`batch_size`, `height`, `width`, `channels`)
|
288 |
+
image_sizes (`List[List[int]]`):
|
289 |
+
A list of sizes for each image in `pixel_values` in (height, width) format.
|
290 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
291 |
+
The channel dimension format for the output image. Can be one of:
|
292 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
293 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
294 |
+
If unset, will use same as the input image.
|
295 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
296 |
+
The channel dimension format for the input image. Can be one of:
|
297 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
298 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
299 |
+
If unset, will use the inferred format of the input image.
|
300 |
+
Returns:
|
301 |
+
List[`np.ndarray`]: The padded images.
|
302 |
+
"""
|
303 |
+
|
304 |
+
max_shape = (
|
305 |
+
max([size[0] for size in image_sizes]),
|
306 |
+
max([size[1] for size in image_sizes]),
|
307 |
+
)
|
308 |
+
pixel_values = [
|
309 |
+
pad(
|
310 |
+
image,
|
311 |
+
padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])),
|
312 |
+
data_format=data_format,
|
313 |
+
input_data_format=input_data_format,
|
314 |
+
)
|
315 |
+
for image, size in zip(pixel_values, image_sizes)
|
316 |
+
]
|
317 |
+
return pixel_values
|
318 |
+
|
319 |
+
def preprocess(
|
320 |
+
self,
|
321 |
+
images: ImageInput,
|
322 |
+
do_resize: Optional[bool] = None,
|
323 |
+
size: Dict[str, int] = None,
|
324 |
+
patch_size: Dict[str, int] = None,
|
325 |
+
resample: PILImageResampling = None,
|
326 |
+
do_rescale: Optional[bool] = None,
|
327 |
+
rescale_factor: Optional[float] = None,
|
328 |
+
do_normalize: Optional[bool] = None,
|
329 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
330 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
331 |
+
do_convert_rgb: Optional[bool] = None,
|
332 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
333 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
334 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
335 |
+
**kwargs,
|
336 |
+
) -> PIL.Image.Image:
|
337 |
+
"""
|
338 |
+
Preprocess an image or batch of images.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
images (`ImageInput`):
|
342 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
343 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
344 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
345 |
+
Whether to resize the image.
|
346 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
347 |
+
Describes the maximum input dimensions to the model.
|
348 |
+
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
349 |
+
Patch size in the model. Used to calculate the image after resizing.
|
350 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
351 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
352 |
+
has an effect if `do_resize` is set to `True`.
|
353 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
354 |
+
Whether to rescale the image.
|
355 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
356 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
357 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
358 |
+
Whether to normalize the image.
|
359 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
360 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
361 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
362 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
363 |
+
`True`.
|
364 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
365 |
+
Whether to convert the image to RGB.
|
366 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
367 |
+
The type of tensors to return. Can be one of:
|
368 |
+
- Unset: Return a list of `np.ndarray`.
|
369 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
370 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
371 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
372 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
373 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
374 |
+
The channel dimension format for the output image. Can be one of:
|
375 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
376 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
377 |
+
- Unset: Use the channel dimension format of the input image.
|
378 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
379 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
380 |
+
from the input image. Can be one of:
|
381 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
382 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
383 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
384 |
+
"""
|
385 |
+
patch_size = patch_size if patch_size is not None else self.patch_size
|
386 |
+
patch_size = get_size_dict(patch_size, default_to_square=True)
|
387 |
+
|
388 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
389 |
+
size = size if size is not None else self.size
|
390 |
+
resample = resample if resample is not None else self.resample
|
391 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
392 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
393 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
394 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
395 |
+
image_std = image_std if image_std is not None else self.image_std
|
396 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
397 |
+
|
398 |
+
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
399 |
+
|
400 |
+
images = make_list_of_images(images)
|
401 |
+
|
402 |
+
if not valid_images(images[0]):
|
403 |
+
raise ValueError(
|
404 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
405 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
406 |
+
)
|
407 |
+
|
408 |
+
validate_preprocess_arguments(
|
409 |
+
do_rescale=do_rescale,
|
410 |
+
rescale_factor=rescale_factor,
|
411 |
+
do_normalize=do_normalize,
|
412 |
+
image_mean=image_mean,
|
413 |
+
image_std=image_std,
|
414 |
+
do_resize=do_resize,
|
415 |
+
size=size,
|
416 |
+
resample=resample,
|
417 |
+
)
|
418 |
+
|
419 |
+
if do_convert_rgb:
|
420 |
+
images = [convert_to_rgb(image) for image in images]
|
421 |
+
|
422 |
+
# All transformations expect numpy arrays.
|
423 |
+
images = [to_numpy_array(image) for image in images]
|
424 |
+
|
425 |
+
if do_rescale and is_scaled_image(images[0]):
|
426 |
+
logger.warning_once(
|
427 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
428 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
429 |
+
)
|
430 |
+
|
431 |
+
if input_data_format is None:
|
432 |
+
# We assume that all images have the same channel dimension format.
|
433 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
434 |
+
|
435 |
+
batch_images = []
|
436 |
+
batch_image_sizes = []
|
437 |
+
for image in images:
|
438 |
+
if do_resize:
|
439 |
+
image = self.resize(
|
440 |
+
image=image,
|
441 |
+
size=size,
|
442 |
+
patch_size=patch_size,
|
443 |
+
resample=resample,
|
444 |
+
input_data_format=input_data_format,
|
445 |
+
)
|
446 |
+
|
447 |
+
if do_rescale:
|
448 |
+
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
449 |
+
|
450 |
+
if do_normalize:
|
451 |
+
image = self.normalize(
|
452 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
453 |
+
)
|
454 |
+
|
455 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
456 |
+
|
457 |
+
batch_images.append(image)
|
458 |
+
batch_image_sizes.append(get_image_size(image, data_format))
|
459 |
+
|
460 |
+
pixel_values = self._pad_for_batching(
|
461 |
+
pixel_values=batch_images,
|
462 |
+
image_sizes=batch_image_sizes,
|
463 |
+
input_data_format=data_format,
|
464 |
+
data_format=data_format,
|
465 |
+
)
|
466 |
+
|
467 |
+
return BatchFeature(
|
468 |
+
data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors
|
469 |
+
)
|
470 |
+
|
471 |
+
|
472 |
+
__all__ = ["PixtralImageProcessor"]
|
docs/transformers/build/lib/transformers/models/pixtral/modeling_pixtral.py
ADDED
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch Pixtral model."""
|
16 |
+
|
17 |
+
from typing import Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
from ... import PreTrainedModel
|
24 |
+
from ...activations import ACT2FN
|
25 |
+
from ...modeling_outputs import BaseModelOutput
|
26 |
+
from ...modeling_rope_utils import dynamic_rope_update
|
27 |
+
from ...utils import (
|
28 |
+
add_start_docstrings,
|
29 |
+
add_start_docstrings_to_model_forward,
|
30 |
+
logging,
|
31 |
+
)
|
32 |
+
from .configuration_pixtral import PixtralVisionConfig
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.get_logger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
def position_ids_in_meshgrid(patch_embeds_list, max_width):
|
39 |
+
positions = []
|
40 |
+
for patch in patch_embeds_list:
|
41 |
+
height, width = patch.shape[-2:]
|
42 |
+
mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
|
43 |
+
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
|
44 |
+
ids = h_grid * max_width + v_grid
|
45 |
+
positions.append(ids[:, 0])
|
46 |
+
return torch.cat(positions)
|
47 |
+
|
48 |
+
|
49 |
+
class PixtralRotaryEmbedding(nn.Module):
|
50 |
+
"""
|
51 |
+
The key with pixtral embedding is just that you have a frequency for each pixel positions.
|
52 |
+
If you have height x width pixels (or embedding pixels), then the frequency used for ROPE
|
53 |
+
is given by indexing the pre_computed frequency on the width and height.
|
54 |
+
|
55 |
+
What you output is of dimension (batch, height * width, dim) with dim the embed dim.
|
56 |
+
|
57 |
+
This simply means that for each image hidden state, you are going to add
|
58 |
+
a corresponding positional embedding, based on its index in the grid.
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(self, config, device=None):
|
62 |
+
super().__init__()
|
63 |
+
self.rope_type = "default"
|
64 |
+
self.dim = config.head_dim
|
65 |
+
self.base = config.rope_theta
|
66 |
+
max_patches_per_side = config.image_size // config.patch_size
|
67 |
+
freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
68 |
+
|
69 |
+
h = torch.arange(max_patches_per_side, device=freqs.device)
|
70 |
+
w = torch.arange(max_patches_per_side, device=freqs.device)
|
71 |
+
|
72 |
+
freqs_h = torch.outer(h, freqs[::2]).float()
|
73 |
+
freqs_w = torch.outer(w, freqs[1::2]).float()
|
74 |
+
inv_freq = torch.cat(
|
75 |
+
[
|
76 |
+
freqs_h[:, None, :].repeat(1, max_patches_per_side, 1),
|
77 |
+
freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1),
|
78 |
+
],
|
79 |
+
dim=-1,
|
80 |
+
).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes
|
81 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
82 |
+
|
83 |
+
# TODO maybe make it torch compatible later on. We can also just slice
|
84 |
+
self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False)
|
85 |
+
|
86 |
+
@torch.no_grad()
|
87 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
88 |
+
def forward(self, x, position_ids):
|
89 |
+
freqs = self.inv_freq[position_ids]
|
90 |
+
|
91 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
92 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
93 |
+
emb = freqs
|
94 |
+
cos = emb.cos()
|
95 |
+
sin = emb.sin()
|
96 |
+
|
97 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
98 |
+
|
99 |
+
|
100 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
101 |
+
def rotate_half(x):
|
102 |
+
"""Rotates half the hidden dims of the input."""
|
103 |
+
x1 = x[..., : x.shape[-1] // 2]
|
104 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
105 |
+
return torch.cat((-x2, x1), dim=-1)
|
106 |
+
|
107 |
+
|
108 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
109 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
q (`torch.Tensor`): The query tensor.
|
113 |
+
k (`torch.Tensor`): The key tensor.
|
114 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
115 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
116 |
+
position_ids (`torch.Tensor`, *optional*):
|
117 |
+
Deprecated and unused.
|
118 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
119 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
120 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
121 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
122 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
123 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
124 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
125 |
+
Returns:
|
126 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
127 |
+
"""
|
128 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
129 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
130 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
131 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
132 |
+
return q_embed, k_embed
|
133 |
+
|
134 |
+
|
135 |
+
class PixtralAttention(nn.Module):
|
136 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
137 |
+
|
138 |
+
def __init__(self, config):
|
139 |
+
super().__init__()
|
140 |
+
self.config = config
|
141 |
+
self.embed_dim = config.hidden_size
|
142 |
+
self.num_heads = config.num_attention_heads
|
143 |
+
self.head_dim = self.embed_dim // self.num_heads
|
144 |
+
|
145 |
+
self.scale = self.head_dim**-0.5
|
146 |
+
self.dropout = config.attention_dropout
|
147 |
+
|
148 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
149 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
150 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
151 |
+
self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
152 |
+
|
153 |
+
def forward(
|
154 |
+
self,
|
155 |
+
hidden_states: torch.Tensor,
|
156 |
+
attention_mask: Optional[torch.Tensor] = None,
|
157 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
158 |
+
output_attentions: Optional[bool] = False,
|
159 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
160 |
+
"""Input shape: Batch x Time x Channel"""
|
161 |
+
|
162 |
+
batch_size, patches, _ = hidden_states.size()
|
163 |
+
|
164 |
+
query_states = self.q_proj(hidden_states)
|
165 |
+
key_states = self.k_proj(hidden_states)
|
166 |
+
value_states = self.v_proj(hidden_states)
|
167 |
+
|
168 |
+
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
169 |
+
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
170 |
+
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
171 |
+
|
172 |
+
cos, sin = position_embeddings
|
173 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)
|
174 |
+
|
175 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
176 |
+
|
177 |
+
if attention_mask is not None:
|
178 |
+
attn_weights = attn_weights + attention_mask
|
179 |
+
|
180 |
+
# upcast attention to fp32
|
181 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
182 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
183 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
184 |
+
|
185 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
186 |
+
attn_output = attn_output.reshape(batch_size, patches, -1)
|
187 |
+
|
188 |
+
attn_output = self.o_proj(attn_output)
|
189 |
+
|
190 |
+
return attn_output, attn_weights
|
191 |
+
|
192 |
+
|
193 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral
|
194 |
+
class PixtralMLP(nn.Module):
|
195 |
+
def __init__(self, config):
|
196 |
+
super().__init__()
|
197 |
+
self.config = config
|
198 |
+
self.hidden_size = config.hidden_size
|
199 |
+
self.intermediate_size = config.intermediate_size
|
200 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
201 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
202 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
203 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
207 |
+
return down_proj
|
208 |
+
|
209 |
+
|
210 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
|
211 |
+
class PixtralRMSNorm(nn.Module):
|
212 |
+
def __init__(self, hidden_size, eps=1e-6):
|
213 |
+
"""
|
214 |
+
PixtralRMSNorm is equivalent to T5LayerNorm
|
215 |
+
"""
|
216 |
+
super().__init__()
|
217 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
218 |
+
self.variance_epsilon = eps
|
219 |
+
|
220 |
+
def forward(self, hidden_states):
|
221 |
+
input_dtype = hidden_states.dtype
|
222 |
+
hidden_states = hidden_states.to(torch.float32)
|
223 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
224 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
225 |
+
return self.weight * hidden_states.to(input_dtype)
|
226 |
+
|
227 |
+
def extra_repr(self):
|
228 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
229 |
+
|
230 |
+
|
231 |
+
class PixtralAttentionLayer(nn.Module):
|
232 |
+
def __init__(self, config):
|
233 |
+
super().__init__()
|
234 |
+
self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
235 |
+
self.feed_forward = PixtralMLP(config)
|
236 |
+
self.attention = PixtralAttention(config)
|
237 |
+
self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
238 |
+
|
239 |
+
def forward(
|
240 |
+
self,
|
241 |
+
hidden_states: torch.Tensor,
|
242 |
+
attention_mask: torch.Tensor,
|
243 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
244 |
+
output_attentions: Optional[bool] = None,
|
245 |
+
) -> Tuple[torch.FloatTensor]:
|
246 |
+
"""
|
247 |
+
Args:
|
248 |
+
hidden_states (`torch.FloatTensor`):
|
249 |
+
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
250 |
+
attention_mask (`torch.FloatTensor`):
|
251 |
+
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
252 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
253 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
254 |
+
returned tensors for more detail.
|
255 |
+
"""
|
256 |
+
residual = hidden_states
|
257 |
+
|
258 |
+
hidden_states = self.attention_norm(hidden_states)
|
259 |
+
hidden_states, attn_weights = self.attention(
|
260 |
+
hidden_states=hidden_states,
|
261 |
+
attention_mask=attention_mask,
|
262 |
+
position_embeddings=position_embeddings,
|
263 |
+
output_attentions=output_attentions,
|
264 |
+
)
|
265 |
+
hidden_states = residual + hidden_states
|
266 |
+
|
267 |
+
residual = hidden_states
|
268 |
+
hidden_states = self.ffn_norm(hidden_states)
|
269 |
+
hidden_states = self.feed_forward(hidden_states)
|
270 |
+
hidden_states = residual + hidden_states
|
271 |
+
|
272 |
+
outputs = (hidden_states,)
|
273 |
+
|
274 |
+
if output_attentions:
|
275 |
+
outputs += (attn_weights,)
|
276 |
+
return outputs
|
277 |
+
|
278 |
+
|
279 |
+
class PixtralTransformer(nn.Module):
|
280 |
+
def __init__(self, config):
|
281 |
+
super().__init__()
|
282 |
+
self.config = config
|
283 |
+
self.layers = torch.nn.ModuleList()
|
284 |
+
for _ in range(config.num_hidden_layers):
|
285 |
+
self.layers.append(PixtralAttentionLayer(config))
|
286 |
+
self.gradient_checkpointing = False
|
287 |
+
|
288 |
+
def forward(
|
289 |
+
self,
|
290 |
+
inputs_embeds,
|
291 |
+
attention_mask: Optional[torch.Tensor] = None,
|
292 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
293 |
+
output_attentions: Optional[bool] = None,
|
294 |
+
output_hidden_states: Optional[bool] = None,
|
295 |
+
return_dict: Optional[bool] = None,
|
296 |
+
) -> Union[Tuple, BaseModelOutput]:
|
297 |
+
r"""
|
298 |
+
Args:
|
299 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
300 |
+
Embeddings which serve as input to the Transformer.
|
301 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
302 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
303 |
+
|
304 |
+
- 1 for tokens that are **not masked**,
|
305 |
+
- 0 for tokens that are **masked**.
|
306 |
+
|
307 |
+
[What are attention masks?](../glossary#attention-mask)
|
308 |
+
output_attentions (`bool`, *optional*):
|
309 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
310 |
+
returned tensors for more detail.
|
311 |
+
output_hidden_states (`bool`, *optional*):
|
312 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
313 |
+
for more detail.
|
314 |
+
return_dict (`bool`, *optional*):
|
315 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
316 |
+
"""
|
317 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
318 |
+
output_hidden_states = (
|
319 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
320 |
+
)
|
321 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
322 |
+
|
323 |
+
encoder_states = () if output_hidden_states else None
|
324 |
+
all_attentions = () if output_attentions else None
|
325 |
+
|
326 |
+
hidden_states = inputs_embeds
|
327 |
+
for encoder_layer in self.layers:
|
328 |
+
if output_hidden_states:
|
329 |
+
encoder_states = encoder_states + (hidden_states,)
|
330 |
+
if self.gradient_checkpointing and self.training:
|
331 |
+
layer_outputs = self._gradient_checkpointing_func(
|
332 |
+
encoder_layer.__call__,
|
333 |
+
hidden_states,
|
334 |
+
attention_mask,
|
335 |
+
position_embeddings,
|
336 |
+
output_attentions,
|
337 |
+
)
|
338 |
+
else:
|
339 |
+
layer_outputs = encoder_layer(
|
340 |
+
hidden_states,
|
341 |
+
attention_mask,
|
342 |
+
position_embeddings=position_embeddings,
|
343 |
+
output_attentions=output_attentions,
|
344 |
+
)
|
345 |
+
|
346 |
+
hidden_states = layer_outputs[0]
|
347 |
+
|
348 |
+
if output_attentions:
|
349 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
350 |
+
|
351 |
+
if output_hidden_states:
|
352 |
+
encoder_states = encoder_states + (hidden_states,)
|
353 |
+
|
354 |
+
if not return_dict:
|
355 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
356 |
+
return BaseModelOutput(
|
357 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
358 |
+
)
|
359 |
+
|
360 |
+
|
361 |
+
PIXTRAL_START_DOCSTRING = r"""
|
362 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
363 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
364 |
+
etc.)
|
365 |
+
|
366 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
367 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
368 |
+
and behavior.
|
369 |
+
|
370 |
+
Parameters:
|
371 |
+
config ([`PixtralVisionConfig`]):
|
372 |
+
Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
|
373 |
+
load the weights associated with the model, only the configuration. Check out the
|
374 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
375 |
+
"""
|
376 |
+
|
377 |
+
|
378 |
+
class PixtralPreTrainedModel(PreTrainedModel):
|
379 |
+
config_class = PixtralVisionConfig
|
380 |
+
base_model_prefix = "model"
|
381 |
+
main_input_name = "pixel_values"
|
382 |
+
supports_gradient_checkpointing = True
|
383 |
+
_no_split_modules = ["PixtralAttentionLayer"]
|
384 |
+
|
385 |
+
def _init_weights(self, module):
|
386 |
+
std = self.config.initializer_range
|
387 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
388 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
389 |
+
if module.bias is not None:
|
390 |
+
module.bias.data.zero_()
|
391 |
+
elif isinstance(module, PixtralRMSNorm):
|
392 |
+
module.weight.data.fill_(1.0)
|
393 |
+
|
394 |
+
|
395 |
+
PIXTRAL_INPUTS_DOCSTRING = r"""
|
396 |
+
Args:
|
397 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
398 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`]
|
399 |
+
for details.
|
400 |
+
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
|
401 |
+
The sizes of the images in the batch, being (height, width) for each image.
|
402 |
+
output_attentions (`bool`, *optional*):
|
403 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
404 |
+
tensors for more detail.
|
405 |
+
output_hidden_states (`bool`, *optional*):
|
406 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
407 |
+
more detail.
|
408 |
+
return_dict (`bool`, *optional*):
|
409 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
410 |
+
"""
|
411 |
+
|
412 |
+
|
413 |
+
def generate_block_attention_mask(patch_embeds_list, tensor):
|
414 |
+
dtype = tensor.dtype
|
415 |
+
device = tensor.device
|
416 |
+
seq_len = tensor.shape[1]
|
417 |
+
d_min = torch.finfo(dtype).min
|
418 |
+
causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
|
419 |
+
|
420 |
+
block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
|
421 |
+
block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
|
422 |
+
for start, end in zip(block_start_idx, block_end_idx):
|
423 |
+
causal_mask[start:end, start:end] = 0
|
424 |
+
|
425 |
+
causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
|
426 |
+
return causal_mask
|
427 |
+
|
428 |
+
|
429 |
+
@add_start_docstrings(
|
430 |
+
"The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.",
|
431 |
+
PIXTRAL_START_DOCSTRING,
|
432 |
+
)
|
433 |
+
class PixtralVisionModel(PixtralPreTrainedModel):
|
434 |
+
base_model_prefix = "vision_encoder"
|
435 |
+
|
436 |
+
def __init__(self, config):
|
437 |
+
super().__init__(config)
|
438 |
+
self.config = config
|
439 |
+
self.patch_conv = nn.Conv2d(
|
440 |
+
in_channels=config.num_channels,
|
441 |
+
out_channels=config.hidden_size,
|
442 |
+
kernel_size=config.patch_size,
|
443 |
+
stride=config.patch_size,
|
444 |
+
bias=False,
|
445 |
+
)
|
446 |
+
self.patch_size = config.patch_size
|
447 |
+
self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
|
448 |
+
self.transformer = PixtralTransformer(config)
|
449 |
+
self.patch_positional_embedding = PixtralRotaryEmbedding(config)
|
450 |
+
|
451 |
+
self.post_init()
|
452 |
+
|
453 |
+
def get_input_embeddings(self):
|
454 |
+
return self.patch_conv
|
455 |
+
|
456 |
+
@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
|
457 |
+
def forward(
|
458 |
+
self,
|
459 |
+
pixel_values: torch.Tensor,
|
460 |
+
image_sizes: torch.Tensor,
|
461 |
+
output_hidden_states: Optional[bool] = None,
|
462 |
+
output_attentions: Optional[bool] = None,
|
463 |
+
return_dict: Optional[bool] = None,
|
464 |
+
*args,
|
465 |
+
**kwargs,
|
466 |
+
) -> Union[Tuple, BaseModelOutput]:
|
467 |
+
"""
|
468 |
+
Returns:
|
469 |
+
pixel_values: tensor of token features for
|
470 |
+
all tokens of all images of shape (N_toks, D)
|
471 |
+
"""
|
472 |
+
# pass images through initial convolution independently
|
473 |
+
patch_embeds = self.patch_conv(pixel_values)
|
474 |
+
patch_embeds_list = [
|
475 |
+
embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)]
|
476 |
+
for embed, size in zip(patch_embeds, image_sizes)
|
477 |
+
]
|
478 |
+
|
479 |
+
# flatten to a single sequence
|
480 |
+
patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
|
481 |
+
patch_embeds = self.ln_pre(patch_embeds)
|
482 |
+
|
483 |
+
# positional embeddings
|
484 |
+
position_ids = position_ids_in_meshgrid(
|
485 |
+
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
|
486 |
+
)
|
487 |
+
position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)
|
488 |
+
|
489 |
+
attention_mask = generate_block_attention_mask(
|
490 |
+
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
|
491 |
+
)
|
492 |
+
|
493 |
+
out = self.transformer(
|
494 |
+
patch_embeds,
|
495 |
+
attention_mask=attention_mask,
|
496 |
+
position_embeddings=position_embeddings,
|
497 |
+
output_hidden_states=output_hidden_states,
|
498 |
+
output_attentions=output_attentions,
|
499 |
+
return_dict=return_dict,
|
500 |
+
)
|
501 |
+
|
502 |
+
return out
|
503 |
+
|
504 |
+
|
505 |
+
__all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"]
|