+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+.
\ No newline at end of file
diff --git a/data/EveryDream/README.MD b/data/EveryDream/README.MD
new file mode 100644
index 0000000000000000000000000000000000000000..80440efe63a93317107c46da1763a37f15f2d3a6
--- /dev/null
+++ b/data/EveryDream/README.MD
@@ -0,0 +1,63 @@
+# EveryDream Tools
+
+This repo will contain tools for data engineering efforts for people interested in taking their fine tuning beyond the initial DreamBooth paper implementations for Stable Diffusion, and may be useful for other image projects.
+
+If you are looking for trainers, check out [EveryDream 2.0](https://github.com/victorchall/EveryDream2trainer). This is just a toolkit repo for data work but works in concert with that trainer.
+
+For instance with Stable Diffusion, training quality can be improved by mixing ground truth Laion data in with the training data to replace "regularization" images, together with clip-interrogated captions, the original TEXT caption from Laion, or human-generated labels. These are significant steps towards full fine tuning capabilities.
+
+Captioned training together with regularization has enabled multi-subject and multi-style training at the same time, and can scale to larger training efforts.
+
+As an example project, you can download a large scale model for Final Fantasy 7 Remake here: https://huggingface.co/panopstor/ff7r-stable-diffusion. Be sure to also follow the gist link at the bottom of that page for more information, along with links to example output of a multi-model fine tuning.
+
+Join the EveryDream discord here: https://discord.gg/uheqxU6sXN
+
+## Tools
+
+[Download scrapes using Laion](./doc/LAION_SCRAPE.md) - Web scrapes images off the web using Laion data files (runs on CPU).
+
+[Auto Captioning](./doc/AUTO_CAPTION.md) - Uses BLIP interrogation to caption images for training (includes colab notebook, needs minimal GPU).
+
+[File renaming](./doc/FILE_RENAME.md) - Simple script for replacing generic pronouns that come out of clip in filenames with proper names (ex "a man" -> "john doe", "a person" -> "jane doe").
+
+*See clip_rename.bat for an example to chain captioning and renaming together.*
+
+[Compress images](./doc/COMPRESS_IMG.md) - Compresses images to WEBP with a given size (ex 1.5 megapixels) to reduce disk usage if you've downloaded some massive PNG data sets (ex. FFHQ) and wish to save some disk space.
+
+[Training](https://github.com/victorchall/EveryDream2trainer) (separate repo) - Fine tuning
+
+[Image Caption GUI](./doc/CAPTION_GUI.md) and [Video frame extractor](./doc/VIDEO_EXTRACTOR.md) courtesy of [MStevenson](https://github.com/mstevenson/)
+
+[General Tools Notebook](EveryDream_Tools.ipynb) - Collection of various tools in this codebase by [Nawnie](https://github.com/nawnie), if you prefer to use a notebook GUI instead of the command line.
+
+## Install
+
+You can use conda or venv. This was developed on Python 3.10.5 but may work on older or newer versions.
+
+One step venv setup:
+
+ create_venv.bat
+
+Don't forget to activate every time you open the command prompt later.
+
+ activate_venv.bat
+
+To use conda instead of venv:
+
+ conda env create -f environment.yaml
+
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
+
+ git clone https://github.com/salesforce/BLIP scripts/BLIP
+
+ conda activate edtools
+
+Or, if you wish to configure your own venv, container/WSL, or Linux:
+
+ pip install -r requirements.txt
+
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
+
+ git clone https://github.com/salesforce/BLIP scripts/BLIP
+
+Thanks to the Salesforce team for the [BLIP tool](https://github.com/salesforce/BLIP). It produces sane sentences like you would expect to see in alt-text.
diff --git a/data/EveryDream/activate_venv.bat b/data/EveryDream/activate_venv.bat
new file mode 100644
index 0000000000000000000000000000000000000000..5a7f316bb02b8f8b430ae0ee60a3a6ccd86fec44
--- /dev/null
+++ b/data/EveryDream/activate_venv.bat
@@ -0,0 +1 @@
+call .venv/scripts/activate.bat
\ No newline at end of file
diff --git a/data/EveryDream/clip_rename.bat b/data/EveryDream/clip_rename.bat
new file mode 100644
index 0000000000000000000000000000000000000000..65c5918a6ad729ca025a8a6c638cbc09bf7d4aca
--- /dev/null
+++ b/data/EveryDream/clip_rename.bat
@@ -0,0 +1,6 @@
+python scripts/auto_caption.py --q_factor 1.4
+::python scripts/filename_replace.py --img_dir output --find "a woman" --replace "rihanna"
+::python scripts/filename_replace.py --img_dir output --find "a person" --replace "rihanna"
+::python scripts/filename_replace.py --img_dir output --find "a man" --replace "asap rocky"
+::python scripts/filename_replace.py --img_dir output --replace "Keira Knightley"
+::python scripts/filename_replace.py --img_dir output --append "by Giotto"
diff --git a/data/EveryDream/create_venv.bat b/data/EveryDream/create_venv.bat
new file mode 100644
index 0000000000000000000000000000000000000000..558e1c5e313d95d324c748f752e3bf77960920ea
--- /dev/null
+++ b/data/EveryDream/create_venv.bat
@@ -0,0 +1,15 @@
+python -m venv .venv
+call .venv/scripts/activate.bat
+if %errorlevel% neq 0 goto :error
+pip install -r requirements.txt
+pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
+git clone https://github.com/salesforce/BLIP scripts/BLIP
+if %errorlevel% neq 0 goto :error
+
+goto :done
+
+:error
+echo Error occurred trying to install or activate venv.
+exit /b %errorlevel%
+
+:done
\ No newline at end of file
diff --git a/data/EveryDream/deactivate_venv.bat b/data/EveryDream/deactivate_venv.bat
new file mode 100644
index 0000000000000000000000000000000000000000..8892f4ccc4c1003010b3e5d6aefc4af8004342a6
--- /dev/null
+++ b/data/EveryDream/deactivate_venv.bat
@@ -0,0 +1 @@
+call .venv/scripts/deactivate.bat
\ No newline at end of file
diff --git a/data/EveryDream/demo/beam_min_vs_q.webp b/data/EveryDream/demo/beam_min_vs_q.webp
new file mode 100644
index 0000000000000000000000000000000000000000..6e4a76fa9d301603be0e48cc6eed084e4fa8a4da
Binary files /dev/null and b/data/EveryDream/demo/beam_min_vs_q.webp differ
diff --git a/data/EveryDream/demo/beam_vs_nucleus.webp b/data/EveryDream/demo/beam_vs_nucleus.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5268641d833d2183e0ab8510eec8feba15c90b72
Binary files /dev/null and b/data/EveryDream/demo/beam_vs_nucleus.webp differ
diff --git a/data/EveryDream/demo/beam_vs_nucleus_2.webp b/data/EveryDream/demo/beam_vs_nucleus_2.webp
new file mode 100644
index 0000000000000000000000000000000000000000..3417db7c26e470ca32b6ebeefc5e28abec5b3207
Binary files /dev/null and b/data/EveryDream/demo/beam_vs_nucleus_2.webp differ
diff --git a/data/EveryDream/demo/demo01.png b/data/EveryDream/demo/demo01.png
new file mode 100644
index 0000000000000000000000000000000000000000..1a6542c3ab0b9d89b80eb4f880070f54f40d1fa7
Binary files /dev/null and b/data/EveryDream/demo/demo01.png differ
diff --git a/data/EveryDream/demo/demo02.png b/data/EveryDream/demo/demo02.png
new file mode 100644
index 0000000000000000000000000000000000000000..a3901d19e385b23744fbeea79421d1f1bfe97b9c
Binary files /dev/null and b/data/EveryDream/demo/demo02.png differ
diff --git a/data/EveryDream/demo/demo03.png b/data/EveryDream/demo/demo03.png
new file mode 100644
index 0000000000000000000000000000000000000000..e961de225a5ff7e940a9e14ca338ddcfb10951f4
Binary files /dev/null and b/data/EveryDream/demo/demo03.png differ
diff --git a/data/EveryDream/demo/output_zip.png b/data/EveryDream/demo/output_zip.png
new file mode 100644
index 0000000000000000000000000000000000000000..8579c170d537dea24017894e8462558dd98dfb83
Binary files /dev/null and b/data/EveryDream/demo/output_zip.png differ
diff --git a/data/EveryDream/demo/upload_images_caption.png b/data/EveryDream/demo/upload_images_caption.png
new file mode 100644
index 0000000000000000000000000000000000000000..29501915a07c5f007090b11e850f37e1de0e1c96
Binary files /dev/null and b/data/EveryDream/demo/upload_images_caption.png differ
diff --git a/data/EveryDream/doc/AUTO_CAPTION.md b/data/EveryDream/doc/AUTO_CAPTION.md
new file mode 100644
index 0000000000000000000000000000000000000000..6510a716f8b8aba812ce0242766b0a3e94ce90b6
--- /dev/null
+++ b/data/EveryDream/doc/AUTO_CAPTION.md
@@ -0,0 +1,115 @@
+# Automatic captioning
+
+Automatic captioning uses Salesforce's BLIP to automatically create a clean sentence structure for captioning input images before training.
+
+By default this requires an Nvidia GPU, but the work is not terribly intensive. It should run fine on something like a 1050 Ti 4GB. You can even run this on the CPU by specifying `--torch_device cpu` as an argument. This will be slower than running on an Nvidia GPU, but will work even on Apple Silicon Macs.
+
+[EveryDream trainer](https://github.com/victorchall/EveryDream-trainer) no longer requires cropped images. You only need to crop to exclude stuff you don't want trained, or to increase the proportion of face close-ups in your data. The EveryDream trainer now accepts multiple aspect ratios and can train on them natively.
+
+But if you do wish to crop for other trainers, you can use [Birme](https://www.birme.net/?target_width=512&target_height=512&auto_focal=false&image_format=webp&quality_jpeg=95&quality_webp=99) to crop and resize first. There are various tools out there for this.
+
+
+
+## Execute
+
+Place input files into the /input folder
+
+ python scripts/auto_caption.py
+
+Files will be **copied**, renamed with the caption as the file name, and placed into /output.
+
+## Colab notebook
+
+This will run quite well on a T4 instance on Google Colab. Don't waste credits on more powerful GPUs.
+
+https://colab.research.google.com/github/victorchall/EveryDream/blob/main/AutoCaption.ipynb
+
+It should work on other GPU providers on minimal power Nvidia GPU instances, but you are on your own to upload and download files.
+
+## Additional command line args:
+
+### --img_dir
+
+Changes the default input directory to read for files. Default is /input
+
+ python scripts/auto_caption.py --img_dir x:/data/my_cropped_images
+
+### --out_dir
+
+Changes the default output directory. Default is /output
+
+ python scripts/auto_caption.py --out_dir x:/data/ready_to_train
+
+### --format
+
+The default behavior simply names the file with the caption plus the original extension and, if needed, adds _n at the end to avoid collisions, for use with EveryDream trainer or Kane Wallmann's DreamBooth fork.
+
+ex output: *"a man in a blue suit and a woman in a black dress standing next to each other in front of a table with a potted plant on it.jpg"*
+
+"mrwho" or "joepenna" will add \[number\]@ as a prefix for use with MrWho's captioning system (on the JoePenna DreamBooth fork), which uses that naming standard to avoid file name collisions.
+
+ python scripts/auto_caption.py --format "mrwho"
+
+"txt" or "caption" will create a ".txt" or ".caption" file instead of renaming the image. ".txt" sidecar is another option for EveryDream trainer instead of getting the caption from the filename itself, and ".caption" is an option for other trainers.
+
+ python scripts/auto_caption.py --format "txt"
+
+or
+
+ python scripts/auto_caption.py --format "caption"
+## Tweaks
+
+You may find the following settings useful for dealing with bad auto-captioning. Start with the defaults, and if you have issues with captions that seem inaccurate or repetitious, try some of the following settings.
+
+### --nucleus
+
+Uses an alternative "nucleus" algorithm instead of the default "beam 16" algorithm. Nucleus produces relatively short captions that are reliably free of repeated words and phrases. Beam 16 can be adjusted further, but may need more tweaking.
+
+
+ python scripts/auto_caption.py --nucleus
+
+
+
+See q_factor below. 0.3 to 3 seem to produce sensible prompts, though 0.01 and 2000 will still work fairly well.
+
+Additional caption example for above with nucleus with different q_factor values:
+
+nucleus q_factor 9999: *"a number of kites painted in different colors in a ceiling"*
+
+nucleus q_factor 200: *"a group of people waiting under art hanging from a ceiling"*
+
+nucleus q_factor 1: *"several people standing around with large colorful umbrellas"*
+
+nucleus q_factor 0.01: *"people are standing in an open building with colorful paper decorations"*
+
+nucleus q_factor 0.00001: (same as above)
+
+### --q_factor
+
+A tuning adjustment whose effect depends on the algorithm used.
+
+For the default beam 16 algorithm it limits the ability of words and phrases to be repeated. Higher value reduces repeated words and phrases. 0.6-1.4 are sensible values for beam 16. Default is 1.0 and works well with the defaulted value min_length of 24. Consider using higher values if you use a min_length higher than 24 with beam 16.
+
+For nucleus (--nucleus), it simply changes the opinion on the prompt and does not impact repeats. Values ranging from 0.01 to 200 seem sensible and default of 1.0 usually works well.
+
+
+
+### --min_length
+
+Adjusts the minimum length of prompt, measured in tokens. **Only applies to beam 16.** Useful to adjust along with --q_factor to keep it from repeating.
+
+Default is 22. Sensible values are 15 to 30, max is 48. Larger values are much more prone to repeating phrases and should be accompanied by increasing --q_factor to avoid repeats.
+
+ python scripts/auto_caption.py --min_length 20
+
+ python scripts/auto_caption.py --min_length 34 --q_factor 1.4
+
+
+
+### Note
+
+If you keep increasing both min_length and q_factor with the default beam algorithm in an attempt to get a really long caption without repeats, it will generate oddly specific prompts. For example, using the above image:
+
+--q_factor 1.9 --min_length 48:
+
+*"a painting of a group of people sitting at a table in a room with red drapes on the walls and gold trimmings on the ceiling, while one person is holding a wine glass in front of the other hand"*
diff --git a/data/EveryDream/doc/CAPTION_GUI.md b/data/EveryDream/doc/CAPTION_GUI.md
new file mode 100644
index 0000000000000000000000000000000000000000..b6a9770736a424168b69c92ce72369238e6ba984
--- /dev/null
+++ b/data/EveryDream/doc/CAPTION_GUI.md
@@ -0,0 +1,20 @@
+# MStevenson's tools
+
+## Caption GUI
+
+ python scripts/image_caption_gui.py
+
+Python GUI tool to manually caption images for machine learning.
+
+A sidecar file is created for each image with the same name and a .txt extension. These are compatible with EveryDream Trainer.
+
+### Controls:
+[control/command + o] to open a folder of images.
+
+[page down] and [page up] to go to next and previous images. Hold shift to skip 10 images.
+
+[shift + home] and [shift + end] to go to first and last images.
+
+[shift + delete] to move the current image into a '_deleted' folder.
+
+[escape] to exit the app.
\ No newline at end of file
diff --git a/data/EveryDream/doc/COMPRESS_IMG.md b/data/EveryDream/doc/COMPRESS_IMG.md
new file mode 100644
index 0000000000000000000000000000000000000000..ddcb0d7b985ba7e817823fe7a5c094aff4f55b4c
--- /dev/null
+++ b/data/EveryDream/doc/COMPRESS_IMG.md
@@ -0,0 +1,69 @@
+# Mass compressing images in a folder
+
+## How it Works
+
+This script will sweep a folder and compress all the images to a given total number of megapixels. Aspect ratio is not changed and nothing is cropped.
+
+*Images will only be resized if they exceed the specified megapixel limit.* Images within the limit will not be resized.
+
+This script will also correct issues with images having EXIF directives that rotate the image. Or, in other words, it will make sure the proper orientation is saved native to the output image as trainers may not respect EXIF rotation directives.
+
+EXIF rotation correction will take place regardless of whether images are resized.
+
+Note this script will not attempt to copy or move any ICC color profiles at this time. Trainers likely do not respect this anyway...
+
+Defaults are 1.5 megapixels, and output is WEBP at "quality 95", which affects the compression ratio. 90-99 are sane values for quality. For the purposes of training Stable Diffusion, 1.5 megapixels is a good balance between quality and file size, and a quality of 90-99 is as well.
+
+If you are hoping to use massive training files in the future as tech advances, you may wish to change the --max_mp setting to a higher value, but for now 1.5MP is more than enough to last for a few more advances in the technology. Ultimately this is your choice. EveryDream trainer is built to handle multiple aspects, but if you want to use images for another trainer and will crop square, you may wish to use a higher value to make sure the images remain large after cropping, or consider cropping carefully first *before* running this script.
+
+## Usage
+
+ usage: compress_img.py [-h] [--img_dir IMG_DIR] [--out_dir OUT_DIR]
+ [--max_mp MAX_MP] [--quality QUALITY] [--overwrite]
+ [--noresize] [--delete]
+
+ Compress images in a directory.
+
+ options:
+ -h, --help show this help message and exit
+ --img_dir IMG_DIR path to image directory (default: 'input')
+ --out_dir OUT_DIR path to output directory (default: IMG_DIR)
+ --max_mp MAX_MP maximum megapixels (default: 1.5)
+ --quality QUALITY save quality (default: 95, range: 0-100, suggested: 90+)
+ --overwrite overwrite files in output directory
+ --noresize do not resize, just fix orientation
+ --delete delete original files after processing
+
+The most basic use will load images from the local `input` directory, scale and rotate all the images, then write them back to the same folder. Default size is 1.5 megapixels.
+
+ python scripts/compress_img.py
+
+To specify the image source directory, specify the `--img_dir`:
+
+ python scripts/compress_img.py --img_dir Q:\big_images
+
+To save compressed images to a different path, specify the `--out_dir`:
+
+ python scripts/compress_img.py --img_dir Q:\big_images --out_dir Q:\small_images
+
+If a specific image already exists in the output path, **it will be skipped**. For example, if you run the script twice, existing `.webp` images in the output directory will be skipped entirely. To overwrite existing files, use the `--overwrite` directive:
+
+ python scripts/compress_img.py --img_dir Q:\big_images --overwrite
+
+If you want to ensure no files are skipped, *without overwriting existing images,* use `--out_dir` to specify an empty output folder.
+
+If you want to delete the *original source image* after it has been resized, use the `--delete` directive:
+
+ python scripts/compress_img.py --img_dir Q:\big_images --delete
+
+The `--delete` directive will not delete the original if it was overwritten or skipped.
+
+To change the max megapixels, use the `--max_mp` option. For example, to set max megapixels to 2.0 and overwrite existing images in the output directory, see this example:
+
+ python scripts/compress_img.py --img_dir Q:\big_images --out_dir Q:\small_images --max_mp 2.0 --overwrite
+
+Once you are comfortable with what is going on and OK with removing original images, you can use this to just replace everything in-place (these are my preferred settings):
+
+ python scripts/compress_img.py --img_dir Q:\big_images --max_mp 1.5 --quality 99 --overwrite --delete
+
+This will compress all images in the `Q:\big_images` directory down to a maximum `1.5` megapixels, at quality `99`, and will `overwrite` any existing output images and `delete` the original, un-altered image.
diff --git a/data/EveryDream/doc/FILE_RENAME.md b/data/EveryDream/doc/FILE_RENAME.md
new file mode 100644
index 0000000000000000000000000000000000000000..7c9a5c8013dc1762641c3e75ec5c16d4267b80eb
--- /dev/null
+++ b/data/EveryDream/doc/FILE_RENAME.md
@@ -0,0 +1,36 @@
+# Filename Replace
+
+This is a very simple script to rename generic pronouns in files to proper names after using auto captioning. This script does not create copies. It renames the files in place.
+
+By default, it will replace "a man", "a woman", and "a person" with your supplied proper name. This works well for single subject without tweaking.
+
+
+## Usage
+
+ python scripts/filename_replace.py --img_dir output --replace "john doe"
+
+*"a man standing in a park with birds on his shoulders.jpg"
+->
+"john doe standing in a park with birds on his shoulders.jpg"*
+
+## Append tags only
+
+ python scripts/filename_replace.py --img_dir "x:\myfiles" --append_only " by claude monet"
+
+This will simply append " by claude monet" without replacing anything, useful to add tags or artstyle keywords.
+
+## Chaining with auto caption
+
+You can chain auto_caption.py and filename_replace.py together in a simple shell script (bash or Windows .bat) to help deal with multiple people in photos. Think about the order of your replacements, and use --find to replace a specific pronoun first before falling back to all three default pronouns.
+
+ python scripts/auto_caption.py --q_factor 1.4 --img_dir input --out_dir output
+ python scripts/filename_replace.py --img_dir output --find "a woman" --replace "rihanna"
+ python scripts/filename_replace.py --img_dir output --replace "asap rocky"
+
+"a man and a woman standing next to each other in front of a green wall with leaves on it.webp"
+->
+"asap rocky and rihanna standing next to each other in front of a green wall with leaves on it.webp"
+
+See clip_rename.bat in the root folder, modify it to your needs.
+
+Renaming is nearly instant, as it does not use any AI models or calculations. It is just a dumb find and replace on the filename.
diff --git a/data/EveryDream/doc/LAION_SCRAPE.md b/data/EveryDream/doc/LAION_SCRAPE.md
new file mode 100644
index 0000000000000000000000000000000000000000..198c986b45bfd289a3cf7267adc73941ee86cdc9
--- /dev/null
+++ b/data/EveryDream/doc/LAION_SCRAPE.md
@@ -0,0 +1,51 @@
+# download_laion.py
+
+
+
+This script enables you to webscrape using the Laion parquet files which are available on Huggingface.co.
+
+It has been tested with 2B-en-aesthetic, but may need minor tweaks for some other datasets that contain different columns. Keep in mind some other files are purely sidecar metadata.
+
+https://huggingface.co/datasets/laion/laion2B-en-aesthetic
+
+**This tool does not work unless you download a set of Laion parquet files, above link is suggested.** Download all 128 .parquet files and place them in the /laion folder.
+
+The script will rename downloaded files to the best of its ability to the TEXT (caption) of the image with the original file extension, which can be plugged into the new class of caption-capable DreamBooth apps or the EveryDream trainer that will use the filename as the prompt for training.
+
+One suggested use is to take this data and replace regularization images with ground truth data from the Laion dataset.
+
+It should execute quite quickly as it uses async task gathers for the HTTP and file I/O work.
+
+Default folders are /laion for the parquet files and /output for downloaded images relative to the root folder, but consider disk space and point to another location if needed.
+
+## Examples
+
+Query all the parquet files in ./laion for any image with a caption (TEXT) containing "a man" and attempt to stop after downloading (approximately) 50 files:
+
+ python scripts/download_laion.py --search_text "a man" --limit 50
+
+Query for person with a leading and trailing space:
+
+ python scripts/download_laion.py --search_text " person " --limit 200
+
+Query for both "man" and "photo" anywhere in the caption, read the parquet files from x:/datahoard/laion5b, and write the images to z:/myDumpFolder instead of the default folder. Useful if you need to put them on another drive, NAS, etc. The default limit of 100 images will apply since --limit is omitted:
+
+ python scripts/download_laion.py --search_text "man,photo" --out_dir "z:/myDumpFolder" --laion_dir "x:/datahoard/laion5b"
+
+## Performance
+
+Script should be reasonably fast depending on your internet speed. I'm able to pull 10,000 images in about 3 1/2 minutes on 1 Gbit fiber.
+
+## Other resources
+
+Nvidia has compiled a close up photo set: [ffhq-dataset](https://github.com/NVlabs/ffhq-dataset)
+
+## Batch run
+
+You can throw commands in a shell/cmd script to run several searches, but I will leave this exercise to the user.
+
+ python scripts/download_laion.py --search_text "jan van eyck" --limit 200
+ python scripts/download_laion.py --search_text " hokusai" --limit 200
+ python scripts/download_laion.py --search_text " bernini" --limit 200
+ python scripts/download_laion.py --search_text "Gustav Klimt" --limit 200
+ python scripts/download_laion.py --search_text "Egon Schiele" --limit 200
diff --git a/data/EveryDream/doc/VIDEO_EXTRACTOR.md b/data/EveryDream/doc/VIDEO_EXTRACTOR.md
new file mode 100644
index 0000000000000000000000000000000000000000..43afafe475ac7ca73e835347dc5aba6e253a6b1e
--- /dev/null
+++ b/data/EveryDream/doc/VIDEO_EXTRACTOR.md
@@ -0,0 +1,27 @@
+# Video frame extractor
+
+## Usage
+
+Place video files into the top level of a directory.
+
+Execute `python scripts/extract_video_frames.py --vid_dir path/to/videos` to iterate over all files, extract frames at regular intervals, and save full resolution frame images to disk.
+
+This tool supports a wide variety of input video containers and codecs (via OpenCV), and exports jpg or png files.
+
+## Arguments
+
+### --vid_dir
+
+Required directory path for input video files.
+
+### --out_dir
+
+Optional directory path in which to store extracted frame images. Defaults to a directory named 'output' that will be created inside the specified videos directory.
+
+### --format
+
+The format for image files saved to disk. Defaults to `png`, or optionally `jpg`.
+
+### --interval
+
+The number of seconds between frame captures. Defaults to 10 seconds.
diff --git a/data/EveryDream/environment.yaml b/data/EveryDream/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d439c04b54fdd23f9d73b35d427df2c961d6a8e9
--- /dev/null
+++ b/data/EveryDream/environment.yaml
@@ -0,0 +1,9 @@
+name: edtools
+dependencies:
+ - pandas>=1.4.3
+ - aiofiles>=22.1.0
+ - colorama>=0.4.5
+ - aiohttp>=3.8.3
+ - timm
+ - fairscale==0.4.4
+ - transformers==4.19.2
diff --git a/data/EveryDream/input/.gitkeep b/data/EveryDream/input/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/EveryDream/laion/Put LAION parquets here.txt b/data/EveryDream/laion/Put LAION parquets here.txt
new file mode 100644
index 0000000000000000000000000000000000000000..73372981a9c462c6443563b44a6651a9836e87ba
--- /dev/null
+++ b/data/EveryDream/laion/Put LAION parquets here.txt
@@ -0,0 +1,2 @@
+Suggested set is here: https://huggingface.co/datasets/laion/laion2B-en-aesthetic
+this is the default folder unless specified otherwise
\ No newline at end of file
diff --git a/data/EveryDream/output/.gitkeep b/data/EveryDream/output/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/EveryDream/requirements.txt b/data/EveryDream/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b3da28cd08caf303e97f8fdfb5d8a283e2d24642
--- /dev/null
+++ b/data/EveryDream/requirements.txt
@@ -0,0 +1,10 @@
+pandas>=1.4.3
+pyarrow>=9.0.0
+aiofiles>=22.1.0
+colorama>=0.4.5
+aiohttp>=3.8.3
+#open_clip_torch>=1.26.12
+timm
+fairscale==0.4.4
+transformers==4.19.2
+opencv-python>=4.6.0
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/BLIP.gif b/data/EveryDream/scripts/BLIP/BLIP.gif
new file mode 100644
index 0000000000000000000000000000000000000000..f97959778a4d3a9c1d5c06793c96d96204fe2081
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/BLIP.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7757a1a1133807158ec4e696a8187f289e64c30a86aa470d8e0a93948a02be22
+size 6707660
diff --git a/data/EveryDream/scripts/BLIP/CODEOWNERS b/data/EveryDream/scripts/BLIP/CODEOWNERS
new file mode 100644
index 0000000000000000000000000000000000000000..522fa4a0f715cd0328b9b9dbacae00e060193f43
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/CODEOWNERS
@@ -0,0 +1,2 @@
+# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
+#ECCN:Open Source
diff --git a/data/EveryDream/scripts/BLIP/CODE_OF_CONDUCT.md b/data/EveryDream/scripts/BLIP/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..b6724718c9512d730bb7f1bcc5848cd420241407
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/CODE_OF_CONDUCT.md
@@ -0,0 +1,105 @@
+# Salesforce Open Source Community Code of Conduct
+
+## About the Code of Conduct
+
+Equality is a core value at Salesforce. We believe a diverse and inclusive
+community fosters innovation and creativity, and are committed to building a
+culture where everyone feels included.
+
+Salesforce open-source projects are committed to providing a friendly, safe, and
+welcoming environment for all, regardless of gender identity and expression,
+sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
+race, age, religion, level of experience, education, socioeconomic status, or
+other similar personal characteristics.
+
+The goal of this code of conduct is to specify a baseline standard of behavior so
+that people with different social values and communication styles can work
+together effectively, productively, and respectfully in our open source community.
+It also establishes a mechanism for reporting issues and resolving conflicts.
+
+All questions and reports of abusive, harassing, or otherwise unacceptable behavior
+in a Salesforce open-source project may be reported by contacting the Salesforce
+Open Source Conduct Committee at ossconduct@salesforce.com.
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of gender
+identity and expression, sexual orientation, disability, physical appearance,
+body size, ethnicity, nationality, race, age, religion, level of experience, education,
+socioeconomic status, or other similar personal characteristics.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy toward other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+advances
+* Personal attacks, insulting/derogatory comments, or trolling
+* Public or private harassment
+* Publishing, or threatening to publish, others' private information—such as
+a physical or electronic address—without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+professional setting
+* Advocating for or encouraging any of the above behaviors
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned with this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project email
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the Salesforce Open Source Conduct Committee
+at ossconduct@salesforce.com. All complaints will be reviewed and investigated
+and will result in a response that is deemed necessary and appropriate to the
+circumstances. The committee is obligated to maintain confidentiality with
+regard to the reporter of an incident. Further details of specific enforcement
+policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership and the Salesforce Open Source Conduct
+Committee.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
+version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
+It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
+[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
+
+This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
+
+[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
+[golang-coc]: https://golang.org/conduct
+[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
+[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
+[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
diff --git a/data/EveryDream/scripts/BLIP/LICENSE.txt b/data/EveryDream/scripts/BLIP/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a63e87f4e1e90c96861648a16a7304d97d3c3f7b
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/LICENSE.txt
@@ -0,0 +1,12 @@
+Copyright (c) 2022, Salesforce.com, Inc.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/data/EveryDream/scripts/BLIP/README.md b/data/EveryDream/scripts/BLIP/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7923e2119361be467ef4da36ad167cdcc1344e08
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/README.md
@@ -0,0 +1,116 @@
+## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
+
+## Announcement: BLIP is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications!
+
+
+
+This is the PyTorch code of the BLIP paper [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
+To install the dependencies, run `pip install -r requirements.txt`.
+
+Catalog:
+- [x] Inference demo
+- [x] Pre-trained and finetuned checkpoints
+- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
+- [x] Pre-training code
+- [x] Zero-shot video-text retrieval
+- [x] Download of bootstrapped pre-training datasets
+
+
+### Inference demo:
+Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
+The demo includes code for:
+1. Image captioning
+2. Open-ended visual question answering
+3. Multimodal / unimodal feature extraction
+4. Image-text matching
+
+Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
+
+A Replicate web demo and Docker image are also available at [replicate.com/salesforce/blip](https://replicate.com/salesforce/blip).
+
+### Pre-trained checkpoints:
+Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
+--- | :---: | :---: | :---:
+14M | Download| - | -
+129M | Download| Download | Download
+
+### Finetuned checkpoints:
+Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
+--- | :---: | :---: | :---:
+Image-Text Retrieval (COCO) | Download| - | Download
+Image-Text Retrieval (Flickr30k) | Download| - | Download
+Image Captioning (COCO) | - | Download| Download |
+VQA | Download| Download | -
+NLVR2 | Download| - | -
+
+
+### Image-Text Retrieval:
+1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
+2. To evaluate the finetuned BLIP model on COCO, run:
+python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
+--config ./configs/retrieval_coco.yaml \
+--output_dir output/retrieval_coco \
+--evaluate
+3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
+python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
+--config ./configs/retrieval_coco.yaml \
+--output_dir output/retrieval_coco
+
+### Image-Text Captioning:
+1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
+2. To evaluate the finetuned BLIP model on COCO, run:
+python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate
+3. To evaluate the finetuned BLIP model on NoCaps, generate results with the command below (the evaluation itself must be performed on the official server):
+python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py
+4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
+python -m torch.distributed.run --nproc_per_node=8 train_caption.py
+
+### VQA:
+1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
+2. To evaluate the finetuned BLIP model, generate results with the command below (the evaluation itself must be performed on the official server):
+python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate
+3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
+python -m torch.distributed.run --nproc_per_node=16 train_vqa.py
+
+### NLVR2:
+1. Download the NLVR2 dataset from the original website, and set 'image_root' in configs/nlvr.yaml.
+2. To evaluate the finetuned BLIP model, run:
+python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate
+3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
+python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py
+
+### Finetune with ViT-L:
+To finetune a model with ViT-L, change the config file to set 'vit' as large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). Gradient checkpointing can also be activated in the config file to reduce GPU memory usage.
+
+### Pre-train:
+1. Prepare training JSON files, where each file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
+2. In configs/pretrain.yaml, set 'train_file' to the paths of the JSON files.
+3. Pre-train the model using 8 A100 GPUs:
+python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/pretrain.yaml --output_dir output/Pretrain
+
+### Zero-shot video-text retrieval:
+1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
+2. Install [decord](https://github.com/dmlc/decord) with `pip install decord`.
+3. To perform zero-shot evaluation, run:
+python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py
+
+### Pre-training datasets download:
+We provide bootstrapped pre-training datasets as JSON files. Each JSON file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.
+
+Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
+--- | :---: | :---: | :---:
+CC3M+CC12M+SBU | Download| Download| Download
+LAION115M | Download| Download| Download
+
+### Citation
+If you find this code useful for your research, please consider citing:
+
+@inproceedings{li2022blip,
+ title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
+ author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
+ year={2022},
+ booktitle={ICML},
+}
+
+### Acknowledgement
+The implementation of BLIP relies on resources from ALBEF, Huggingface Transformers, and timm. We thank the original authors for open-sourcing their work.
diff --git a/data/EveryDream/scripts/BLIP/SECURITY.md b/data/EveryDream/scripts/BLIP/SECURITY.md
new file mode 100644
index 0000000000000000000000000000000000000000..8249025739809035264e7776583b2f3ec100553c
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/SECURITY.md
@@ -0,0 +1,7 @@
+## Security
+
+Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
+as soon as it is discovered. This library limits its runtime dependencies in
+order to reduce the total cost of ownership as much as can be, but all consumers
+should remain vigilant and have their security stakeholders review all third-party
+products (3PP) like this one and their dependencies.
diff --git a/data/EveryDream/scripts/BLIP/cog.yaml b/data/EveryDream/scripts/BLIP/cog.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c1dfcc430a4cab0fdd2a60a682336219a61c4a4f
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/cog.yaml
@@ -0,0 +1,17 @@
+build:
+ gpu: true
+ cuda: "11.1"
+ python_version: "3.8"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_packages:
+ - "ipython==7.30.1"
+ - "torchvision==0.11.1"
+ - "torch==1.10.0"
+ - "timm==0.4.12"
+ - "transformers==4.15.0"
+ - "fairscale==0.4.4"
+ - "pycocoevalcap==1.2"
+
+predict: "predict.py:Predictor"
diff --git a/data/EveryDream/scripts/BLIP/configs/bert_config.json b/data/EveryDream/scripts/BLIP/configs/bert_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3ef38aabc7f966b53079e9d559dc59e459cc0051
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/bert_config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
+ "vocab_size": 30522,
+ "encoder_width": 768,
+ "add_cross_attention": true
+}
diff --git a/data/EveryDream/scripts/BLIP/configs/caption_coco.yaml b/data/EveryDream/scripts/BLIP/configs/caption_coco.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..42eab7030c0310ba2f265baf36fa1400aa6e5846
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/caption_coco.yaml
@@ -0,0 +1,33 @@
+image_root: '/export/share/datasets/vision/coco/images/'
+ann_root: 'annotation'
+coco_gt_root: 'annotation/coco_gt'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
+
+# size of vit model; base or large
+vit: 'base'
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+batch_size: 32
+init_lr: 1e-5
+
+# vit: 'large'
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 5
+# batch_size: 16
+# init_lr: 2e-6
+
+image_size: 384
+
+# generation configs
+max_length: 20
+min_length: 5
+num_beams: 3
+prompt: 'a picture of '
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 5
+
diff --git a/data/EveryDream/scripts/BLIP/configs/med_config.json b/data/EveryDream/scripts/BLIP/configs/med_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0ffad0a6f3c2f9f11b8faa84529d9860bb70327a
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/med_config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
+ "vocab_size": 30524,
+ "encoder_width": 768,
+ "add_cross_attention": true
+}
diff --git a/data/EveryDream/scripts/BLIP/configs/nlvr.yaml b/data/EveryDream/scripts/BLIP/configs/nlvr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d1122aadb1a776bd347068233096b0c984f648b
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/nlvr.yaml
@@ -0,0 +1,21 @@
+image_root: '/export/share/datasets/vision/NLVR2/'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
+
+#size of vit model; base or large
+vit: 'base'
+batch_size_train: 16
+batch_size_test: 64
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+max_epoch: 15
+
+image_size: 384
+
+# optimizer
+weight_decay: 0.05
+init_lr: 3e-5
+min_lr: 0
+
diff --git a/data/EveryDream/scripts/BLIP/configs/nocaps.yaml b/data/EveryDream/scripts/BLIP/configs/nocaps.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9028135859b94aef5324c85c80e376c609d8a089
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/nocaps.yaml
@@ -0,0 +1,15 @@
+image_root: '/export/share/datasets/vision/nocaps/'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
+
+vit: 'base'
+batch_size: 32
+
+image_size: 384
+
+max_length: 20
+min_length: 5
+num_beams: 3
+prompt: 'a picture of '
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/configs/pretrain.yaml b/data/EveryDream/scripts/BLIP/configs/pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..02355ee0228932803c661616485bf315e862b826
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/pretrain.yaml
@@ -0,0 +1,27 @@
+train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
+ '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
+ ]
+laion_path: ''
+
+# size of vit model; base or large
+vit: 'base'
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+
+image_size: 224
+batch_size: 75
+
+queue_size: 57600
+alpha: 0.4
+
+# optimizer
+weight_decay: 0.05
+init_lr: 3e-4
+min_lr: 1e-6
+warmup_lr: 1e-6
+lr_decay_rate: 0.9
+max_epoch: 20
+warmup_steps: 3000
+
+
+
diff --git a/data/EveryDream/scripts/BLIP/configs/retrieval_coco.yaml b/data/EveryDream/scripts/BLIP/configs/retrieval_coco.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a8569e9b67112fe3605ac25e4fdc0231f7975378
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/retrieval_coco.yaml
@@ -0,0 +1,34 @@
+image_root: '/export/share/datasets/vision/coco/images/'
+ann_root: 'annotation'
+dataset: 'coco'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+# size of vit model; base or large
+
+vit: 'base'
+batch_size_train: 32
+batch_size_test: 64
+vit_grad_ckpt: True
+vit_ckpt_layer: 4
+init_lr: 1e-5
+
+# vit: 'large'
+# batch_size_train: 16
+# batch_size_test: 32
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 12
+# init_lr: 5e-6
+
+image_size: 384
+queue_size: 57600
+alpha: 0.4
+k_test: 256
+negative_all_rank: True
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 6
+
diff --git a/data/EveryDream/scripts/BLIP/configs/retrieval_flickr.yaml b/data/EveryDream/scripts/BLIP/configs/retrieval_flickr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d75ea4eed87c9a001523c5e5914998c5e737594d
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/retrieval_flickr.yaml
@@ -0,0 +1,34 @@
+image_root: '/export/share/datasets/vision/flickr30k/'
+ann_root: 'annotation'
+dataset: 'flickr'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
+
+# size of vit model; base or large
+
+vit: 'base'
+batch_size_train: 32
+batch_size_test: 64
+vit_grad_ckpt: True
+vit_ckpt_layer: 4
+init_lr: 1e-5
+
+# vit: 'large'
+# batch_size_train: 16
+# batch_size_test: 32
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 10
+# init_lr: 5e-6
+
+image_size: 384
+queue_size: 57600
+alpha: 0.4
+k_test: 128
+negative_all_rank: False
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 6
+
diff --git a/data/EveryDream/scripts/BLIP/configs/retrieval_msrvtt.yaml b/data/EveryDream/scripts/BLIP/configs/retrieval_msrvtt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..395f62542bb22d706b8e19e2455d2c7298984d0b
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/retrieval_msrvtt.yaml
@@ -0,0 +1,12 @@
+video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+# size of vit model; base or large
+vit: 'base'
+batch_size: 64
+k_test: 128
+image_size: 384
+num_frm_test: 8
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/configs/vqa.yaml b/data/EveryDream/scripts/BLIP/configs/vqa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..74327e6d0a34672023b44569558fe8beeb052548
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/configs/vqa.yaml
@@ -0,0 +1,25 @@
+vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
+vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
+train_files: ['vqa_train','vqa_val','vg_qa']
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
+
+# size of vit model; base or large
+vit: 'base'
+batch_size_train: 16
+batch_size_test: 32
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+init_lr: 2e-5
+
+image_size: 480
+
+k_test: 128
+inference: 'rank'
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 10
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/data/__init__.py b/data/EveryDream/scripts/BLIP/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be209acf415855ea6ef753efedf903b5decb6b9
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/__init__.py
@@ -0,0 +1,101 @@
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
+from data.nocaps_dataset import nocaps_eval
+from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
+from data.vqa_dataset import vqa_dataset
+from data.nlvr_dataset import nlvr_dataset
+from data.pretrain_dataset import pretrain_dataset
+from transform.randaugment import RandomAugment
+
+def create_dataset(dataset, config, min_scale=0.5):
+
+    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+    transform_train = transforms.Compose([
+        transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
+        transforms.RandomHorizontalFlip(),
+        RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
+                                          'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
+        transforms.ToTensor(),
+        normalize,
+    ])
+    transform_test = transforms.Compose([
+        transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        normalize,
+    ])
+
+    if dataset=='pretrain':
+        dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
+        return dataset
+
+    elif dataset=='caption_coco':
+        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
+        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='nocaps':
+        val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return val_dataset, test_dataset
+
+    elif dataset=='retrieval_coco':
+        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
+        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='retrieval_flickr':
+        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
+        val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='vqa':
+        train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
+                                    train_files = config['train_files'], split='train')
+        test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
+        return train_dataset, test_dataset
+
+    elif dataset=='nlvr':
+        train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
+        val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
+        test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
+        return train_dataset, val_dataset, test_dataset
+
+
+def create_sampler(datasets, shuffles, num_tasks, global_rank):
+    samplers = []
+    for dataset,shuffle in zip(datasets,shuffles):
+        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
+        samplers.append(sampler)
+    return samplers
+
+
+def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
+    loaders = []
+    for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
+        if is_train:
+            shuffle = (sampler is None)
+            drop_last = True
+        else:
+            shuffle = False
+            drop_last = False
+        loader = DataLoader(
+            dataset,
+            batch_size=bs,
+            num_workers=n_worker,
+            pin_memory=True,
+            sampler=sampler,
+            shuffle=shuffle,
+            collate_fn=collate_fn,
+            drop_last=drop_last,
+        )
+        loaders.append(loader)
+    return loaders
+
diff --git a/data/EveryDream/scripts/BLIP/data/coco_karpathy_dataset.py b/data/EveryDream/scripts/BLIP/data/coco_karpathy_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a34d29205f42aa09695b160ac9c91958ba041bb3
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/coco_karpathy_dataset.py
@@ -0,0 +1,126 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class coco_karpathy_train(Dataset):
+    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+        '''
+        image_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        '''
+        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
+        filename = 'coco_karpathy_train.json'
+
+        download_url(url,ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
+        self.transform = transform
+        self.image_root = image_root
+        self.max_words = max_words
+        self.prompt = prompt
+
+        self.img_ids = {}
+        n = 0
+        for ann in self.annotation:
+            img_id = ann['image_id']
+            if img_id not in self.img_ids.keys():
+                self.img_ids[img_id] = n
+                n += 1
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.image_root,ann['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        caption = self.prompt+pre_caption(ann['caption'], self.max_words)
+
+        return image, caption, self.img_ids[ann['image_id']]
+
+
+class coco_karpathy_caption_eval(Dataset):
+    def __init__(self, transform, image_root, ann_root, split):
+        '''
+        image_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        split (string): val or test
+        '''
+        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+
+        download_url(urls[split],ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+        self.transform = transform
+        self.image_root = image_root
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.image_root,ann['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        img_id = os.path.splitext(os.path.basename(ann['image']))[0].split('_')[-1]  # drop the extension, keep the numeric id
+
+        return image, int(img_id)
+
+
+class coco_karpathy_retrieval_eval(Dataset):
+    def __init__(self, transform, image_root, ann_root, split, max_words=30):
+        '''
+        image_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        split (string): val or test
+        '''
+        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+
+        download_url(urls[split],ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+        self.transform = transform
+        self.image_root = image_root
+
+        self.text = []
+        self.image = []
+        self.txt2img = {}
+        self.img2txt = {}
+
+        txt_id = 0
+        for img_id, ann in enumerate(self.annotation):
+            self.image.append(ann['image'])
+            self.img2txt[img_id] = []
+            for i, caption in enumerate(ann['caption']):
+                self.text.append(pre_caption(caption,max_words))
+                self.img2txt[img_id].append(txt_id)
+                self.txt2img[txt_id] = img_id
+                txt_id += 1
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        return image, index
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/data/flickr30k_dataset.py b/data/EveryDream/scripts/BLIP/data/flickr30k_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..018ab387014ddaf554c4d3184cfc0e2ba8b2d487
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/flickr30k_dataset.py
@@ -0,0 +1,93 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class flickr30k_train(Dataset):
+    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+        '''
+        image_root (string): Root directory of images (e.g. flickr30k/)
+        ann_root (string): directory to store the annotation file
+        '''
+        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
+        filename = 'flickr30k_train.json'
+
+        download_url(url,ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
+        self.transform = transform
+        self.image_root = image_root
+        self.max_words = max_words
+        self.prompt = prompt
+
+        self.img_ids = {}
+        n = 0
+        for ann in self.annotation:
+            img_id = ann['image_id']
+            if img_id not in self.img_ids.keys():
+                self.img_ids[img_id] = n
+                n += 1
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.image_root,ann['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        caption = self.prompt+pre_caption(ann['caption'], self.max_words)
+
+        return image, caption, self.img_ids[ann['image_id']]
+
+
+class flickr30k_retrieval_eval(Dataset):
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
+ '''
+ image_root (string): Root directory of images (e.g. flickr30k/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ '''
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
+ filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
+
+ download_url(urls[split],ann_root)
+
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+ self.transform = transform
+ self.image_root = image_root
+
+ self.text = []
+ self.image = []
+ self.txt2img = {}
+ self.img2txt = {}
+
+ txt_id = 0
+ for img_id, ann in enumerate(self.annotation):
+ self.image.append(ann['image'])
+ self.img2txt[img_id] = []
+ for i, caption in enumerate(ann['caption']):
+ self.text.append(pre_caption(caption,max_words))
+ self.img2txt[img_id].append(txt_id)
+ self.txt2img[txt_id] = img_id
+ txt_id += 1
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ return image, index
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/data/nlvr_dataset.py b/data/EveryDream/scripts/BLIP/data/nlvr_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8d6b2d7cd8d3260bd279c7dca80de53bacc691a
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/nlvr_dataset.py
@@ -0,0 +1,78 @@
+import os
+import json
+import random
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class nlvr_dataset(Dataset):
+ def __init__(self, transform, image_root, ann_root, split):
+ '''
+ image_root (string): Root directory of images
+ ann_root (string): directory to store the annotation file
+ split (string): train, val or test
+ '''
+ urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
+ 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
+ filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
+
+ download_url(urls[split],ann_root)
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+
+ self.transform = transform
+ self.image_root = image_root
+
+
+ def __len__(self):
+ return len(self.annotation)
+
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image0_path = os.path.join(self.image_root,ann['images'][0])
+        image0 = Image.open(image0_path).convert('RGB')
+        image0 = self.transform(image0)
+
+        image1_path = os.path.join(self.image_root,ann['images'][1])
+        image1 = Image.open(image1_path).convert('RGB')
+        image1 = self.transform(image1)
+
+        sentence = pre_caption(ann['sentence'], 40)
+
+        if ann['label']=='True':
+            label = 1
+        else:
+            label = 0
+
+        words = sentence.split(' ')
+
+        if 'left' not in words and 'right' not in words:
+            if random.random()<0.5:
+                return image0, image1, sentence, label
+            else:
+                return image1, image0, sentence, label
+        else:
+            if random.random()<0.5:
+                return image0, image1, sentence, label
+            else:
+                new_words = []
+                for word in words:
+                    if word=='left':
+                        new_words.append('right')
+                    elif word=='right':
+                        new_words.append('left')
+                    else:
+                        new_words.append(word)
+
+                sentence = ' '.join(new_words)
+                return image1, image0, sentence, label
+
+
+
\ No newline at end of file
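Review note on nlvr_dataset.py: the pair augmentation in `__getitem__` swaps the two images half of the time and, when it does, exchanges the words "left" and "right" so the sentence still describes the pair. The sketch below isolates that idea. `swap_pair`, `FixedRandom`, and the `rng` parameter are hypothetical names introduced only so the behaviour can be checked deterministically; they are not part of the patch.

```python
import random


def swap_pair(image0, image1, sentence, rng=random):
    """Hypothetical helper: randomly swap an image pair and keep the sentence consistent."""
    if rng.random() < 0.5:
        # Keep the original order and the original sentence.
        return image0, image1, sentence
    # Swapped order: exchange 'left' and 'right'; other words pass through unchanged.
    flip = {"left": "right", "right": "left"}
    words = [flip.get(word, word) for word in sentence.split(" ")]
    return image1, image0, " ".join(words)


class FixedRandom:
    """Stand-in for the random module that always returns one value (for illustration)."""

    def __init__(self, value):
        self.value = value

    def random(self):
        return self.value


kept = swap_pair("img_a", "img_b", "the left image has two dogs", rng=FixedRandom(0.1))
swapped = swap_pair("img_a", "img_b", "the left image has two dogs", rng=FixedRandom(0.9))
print(kept)
print(swapped)
```

A sentence without "left" or "right" comes back unchanged in either order, which matches the first branch of the method in the patch.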
diff --git a/data/EveryDream/scripts/BLIP/data/nocaps_dataset.py b/data/EveryDream/scripts/BLIP/data/nocaps_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0bed06d8af3dbaccf18a56e725f101e585503e
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/nocaps_dataset.py
@@ -0,0 +1,32 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+class nocaps_eval(Dataset):
+ def __init__(self, transform, image_root, ann_root, split):
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
+ filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
+
+ download_url(urls[split],ann_root)
+
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+ self.transform = transform
+ self.image_root = image_root
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.image_root,ann['image'])
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ return image, int(ann['img_id'])
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/data/pretrain_dataset.py b/data/EveryDream/scripts/BLIP/data/pretrain_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..703d543ab5267fdc6fe2b7c84ef6a631d8af90ad
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/pretrain_dataset.py
@@ -0,0 +1,59 @@
+import json
+import os
+import random
+
+from torch.utils.data import Dataset
+
+from PIL import Image
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+Image.MAX_IMAGE_PIXELS = None
+
+from data.utils import pre_caption
+import glob
+
+class pretrain_dataset(Dataset):
+ def __init__(self, ann_file, laion_path, transform):
+
+ self.ann_pretrain = []
+ for f in ann_file:
+ print('loading '+f)
+ ann = json.load(open(f,'r'))
+ self.ann_pretrain += ann
+
+ self.laion_path = laion_path
+ if self.laion_path:
+ self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
+
+ print('loading '+self.laion_files[0])
+ with open(self.laion_files[0],'r') as f:
+ self.ann_laion = json.load(f)
+
+ self.annotation = self.ann_pretrain + self.ann_laion
+ else:
+ self.annotation = self.ann_pretrain
+
+ self.transform = transform
+
+
+ def reload_laion(self, epoch):
+ n = epoch%len(self.laion_files)
+ print('loading '+self.laion_files[n])
+ with open(self.laion_files[n],'r') as f:
+ self.ann_laion = json.load(f)
+
+ self.annotation = self.ann_pretrain + self.ann_laion
+
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image = Image.open(ann['image']).convert('RGB')
+ image = self.transform(image)
+ caption = pre_caption(ann['caption'],30)
+
+ return image, caption
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/data/utils.py b/data/EveryDream/scripts/BLIP/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..628894844becd462d444584b8b2b01a84ee4b8f7
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/utils.py
@@ -0,0 +1,112 @@
+import re
+import json
+import os
+
+import torch
+import torch.distributed as dist
+
+import utils
+
+def pre_caption(caption,max_words=50):
+    caption = re.sub(
+        r"([.!\"()*#:;~])",
+        ' ',
+        caption.lower(),
+    )
+    caption = re.sub(
+        r"\s{2,}",
+        ' ',
+        caption,
+    )
+    caption = caption.rstrip('\n')
+    caption = caption.strip(' ')
+
+    #truncate caption
+    caption_words = caption.split(' ')
+    if len(caption_words)>max_words:
+        caption = ' '.join(caption_words[:max_words])
+
+    return caption
+
+def pre_question(question,max_ques_words=50):
+    question = re.sub(
+        r"([.!\"()*#:;~])",
+        '',
+        question.lower(),
+    )
+    question = question.rstrip(' ')
+
+    #truncate question
+    question_words = question.split(' ')
+    if len(question_words)>max_ques_words:
+        question = ' '.join(question_words[:max_ques_words])
+
+    return question
+
+
+def save_result(result, result_dir, filename, remove_duplicate=''):
+ result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
+ final_result_file = os.path.join(result_dir, '%s.json'%filename)
+
+ json.dump(result,open(result_file,'w'))
+
+ dist.barrier()
+
+ if utils.is_main_process():
+ # combine results from all processes
+ result = []
+
+ for rank in range(utils.get_world_size()):
+ result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
+ res = json.load(open(result_file,'r'))
+ result += res
+
+ if remove_duplicate:
+ result_new = []
+ id_list = []
+ for res in result:
+ if res[remove_duplicate] not in id_list:
+ id_list.append(res[remove_duplicate])
+ result_new.append(res)
+ result = result_new
+
+ json.dump(result,open(final_result_file,'w'))
+ print('result file saved to %s'%final_result_file)
+
+ return final_result_file
+
+
+
+from pycocotools.coco import COCO
+from pycocoevalcap.eval import COCOEvalCap
+from torchvision.datasets.utils import download_url
+
+def coco_caption_eval(coco_gt_root, results_file, split):
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
+ filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
+
+ download_url(urls[split],coco_gt_root)
+ annotation_file = os.path.join(coco_gt_root,filenames[split])
+
+ # create coco object and coco_result object
+ coco = COCO(annotation_file)
+ coco_result = coco.loadRes(results_file)
+
+ # create coco_eval object by taking coco and coco_result
+ coco_eval = COCOEvalCap(coco, coco_result)
+
+ # evaluate on a subset of images by setting
+ # coco_eval.params['image_id'] = coco_result.getImgIds()
+ # please remove this line when evaluating the full validation set
+ # coco_eval.params['image_id'] = coco_result.getImgIds()
+
+ # evaluate results
+ # SPICE will take a few minutes the first time, but speeds up due to caching
+ coco_eval.evaluate()
+
+ # print output evaluation scores
+ for metric, score in coco_eval.eval.items():
+ print(f'{metric}: {score:.3f}')
+
+ return coco_eval
\ No newline at end of file
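Review note on data/utils.py: `pre_caption` lowercases the text, replaces a fixed set of punctuation characters with spaces, collapses repeated whitespace, and truncates to `max_words` tokens. The sketch below copies that logic into a standalone function so the effect can be seen without the `torch` and `pycocotools` imports of the module. The sample strings are made up for illustration.

```python
import re


def pre_caption(caption, max_words=50):
    # Lowercase, then replace the listed punctuation characters with spaces.
    caption = re.sub(r"([.!\"()*#:;~])", " ", caption.lower())
    # Collapse runs of two or more whitespace characters into one space.
    caption = re.sub(r"\s{2,}", " ", caption)
    caption = caption.rstrip("\n").strip(" ")
    # Keep at most max_words space-separated tokens.
    words = caption.split(" ")
    if len(words) > max_words:
        caption = " ".join(words[:max_words])
    return caption


normalized = pre_caption("A Dog!  Running (fast): on the beach.")
truncated = pre_caption("one two three four five", max_words=3)
print(normalized)
print(truncated)
```

Commas and hyphens are not in the character class, so they survive normalization. `pre_question` differs in that it deletes the punctuation instead of replacing it with a space and does not collapse whitespace.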
diff --git a/data/EveryDream/scripts/BLIP/data/video_dataset.py b/data/EveryDream/scripts/BLIP/data/video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6f8a61105bbd4285f98b3abe9445b73fd4c7ef
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/video_dataset.py
@@ -0,0 +1,110 @@
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+import torch
+import numpy as np
+import random
+import decord
+from decord import VideoReader
+import json
+import os
+from data.utils import pre_caption
+
+decord.bridge.set_bridge("torch")
+
+class ImageNorm(object):
+ """Apply Normalization to Image Pixels on GPU
+ """
+ def __init__(self, mean, std):
+ self.mean = torch.tensor(mean).view(1, 3, 1, 1)
+ self.std = torch.tensor(std).view(1, 3, 1, 1)
+
+ def __call__(self, img):
+
+ if torch.max(img) > 1 and self.mean.max() <= 1:
+ img.div_(255.)
+ return img.sub_(self.mean).div_(self.std)
+
+def load_jsonl(filename):
+ with open(filename, "r") as f:
+ return [json.loads(l.strip("\n")) for l in f.readlines()]
+
+
+class VideoDataset(Dataset):
+
+ def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
+ '''
+        video_root (string): Root directory of videos
+ ann_root (string): directory to store the annotation file
+ '''
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
+ filename = 'msrvtt_test.jsonl'
+
+ download_url(url,ann_root)
+ self.annotation = load_jsonl(os.path.join(ann_root,filename))
+
+ self.num_frm = num_frm
+ self.frm_sampling_strategy = frm_sampling_strategy
+ self.max_img_size = max_img_size
+ self.video_root = video_root
+ self.video_fmt = video_fmt
+ self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
+
+ self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
+ self.txt2video = [i for i in range(len(self.annotation))]
+ self.video2txt = self.txt2video
+
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
+
+ vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
+
+ video = self.img_norm(vid_frm_array.float())
+
+ return video, ann['clip_name']
+
+
+
+    def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
+        try:
+            if not height or not width:
+                vr = VideoReader(video_path)
+            else:
+                vr = VideoReader(video_path, width=width, height=height)
+
+            vlen = len(vr)
+
+            if start_time or end_time:
+                assert fps > 0, 'must provide video fps if specifying start and end time.'
+
+                start_idx = min(int(start_time * fps), vlen)
+                end_idx = min(int(end_time * fps), vlen)
+            else:
+                start_idx, end_idx = 0, vlen
+
+            if self.frm_sampling_strategy == 'uniform':
+                frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
+            elif self.frm_sampling_strategy == 'rand':
+                frame_indices = sorted(random.sample(range(vlen), self.num_frm))
+            elif self.frm_sampling_strategy == 'headtail':
+                frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
+                frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
+                frame_indices = frame_indices_head + frame_indices_tail
+            else:
+                raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
+
+            raw_sample_frms = vr.get_batch(frame_indices)
+        except Exception as e:
+            raise RuntimeError('failed to load video {}'.format(video_path)) from e  # returning None would only crash later in __getitem__
+
+        raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
+
+        return raw_sample_frms
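Review note on video_dataset.py: the three frame sampling strategies ('uniform', 'rand', 'headtail') can be sketched without decord as below. `sample_frame_indices` and its `rng` parameter are hypothetical names, and the uniform branch uses plain integer arithmetic as a stand-in for the `np.arange` call in the patch, so the two may not return identical indices for every input.

```python
import random


def sample_frame_indices(vlen, num_frm, strategy="uniform", rng=random):
    """Hypothetical helper: pick frame indices out of a clip with vlen frames."""
    if strategy == "uniform":
        # Evenly spaced start points over the whole clip (stand-in for np.arange).
        return [i * vlen // num_frm for i in range(num_frm)]
    if strategy == "rand":
        # num_frm distinct frames anywhere in the clip, in temporal order.
        return sorted(rng.sample(range(vlen), num_frm))
    if strategy == "headtail":
        # Half of the frames from the first half of the clip, half from the second.
        half = num_frm // 2
        head = sorted(rng.sample(range(vlen // 2), half))
        tail = sorted(rng.sample(range(vlen // 2, vlen), half))
        return head + tail
    raise NotImplementedError("Invalid sampling strategy {}".format(strategy))


uniform = sample_frame_indices(100, 4, "uniform")
rng = random.Random(0)
rand = sample_frame_indices(100, 4, "rand", rng=rng)
headtail = sample_frame_indices(100, 4, "headtail", rng=rng)
print(uniform, rand, headtail)
```

With an odd `num_frm`, the 'headtail' strategy returns `num_frm - 1` indices because each half draws `num_frm // 2` frames. That holds for the code in the patch as well.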
diff --git a/data/EveryDream/scripts/BLIP/data/vqa_dataset.py b/data/EveryDream/scripts/BLIP/data/vqa_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ec1df429b3910316ddd554bfea01c6e7922cae
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/data/vqa_dataset.py
@@ -0,0 +1,88 @@
+import os
+import json
+import random
+from PIL import Image
+
+import torch
+from torch.utils.data import Dataset
+from data.utils import pre_question
+
+from torchvision.datasets.utils import download_url
+
+class vqa_dataset(Dataset):
+ def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
+ self.split = split
+
+ self.transform = transform
+ self.vqa_root = vqa_root
+ self.vg_root = vg_root
+
+ if split=='train':
+ urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
+ 'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
+ 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
+
+ self.annotation = []
+ for f in train_files:
+ download_url(urls[f],ann_root)
+ self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
+ else:
+ download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
+ self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
+
+ download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
+ self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
+
+
+ def __len__(self):
+ return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        if ann['dataset']=='vqa':
+            image_path = os.path.join(self.vqa_root,ann['image'])
+        elif ann['dataset']=='vg':
+            image_path = os.path.join(self.vg_root,ann['image'])
+
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        if self.split == 'test':
+            question = pre_question(ann['question'])
+            question_id = ann['question_id']
+            return image, question, question_id
+
+
+        elif self.split=='train':
+
+            question = pre_question(ann['question'])
+
+            if ann['dataset']=='vqa':
+                answer_weight = {}
+                for answer in ann['answer']:
+                    if answer in answer_weight.keys():
+                        answer_weight[answer] += 1/len(ann['answer'])
+                    else:
+                        answer_weight[answer] = 1/len(ann['answer'])
+
+                answers = list(answer_weight.keys())
+                weights = list(answer_weight.values())
+
+            elif ann['dataset']=='vg':
+                answers = [ann['answer']]
+                weights = [0.2]
+
+            return image, question, answers, weights
+
+
+def vqa_collate_fn(batch):
+    image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
+    for image, question, answer, weights in batch:
+        image_list.append(image)
+        question_list.append(question)
+        weight_list += weights
+        answer_list += answer
+        n.append(len(answer))
+    return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
\ No newline at end of file
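Review note on vqa_dataset.py: for VQA samples each distinct answer is weighted by the fraction of annotators who gave it, and `vqa_collate_fn` flattens the per-sample answer lists while recording how many answers each sample contributed. The sketch below shows both steps on plain Python lists, without images or tensors. `answer_weights` and `flatten_answers` are hypothetical names, and `Counter` is used here in place of the explicit dictionary loop in the patch, so the floating-point results can differ in the last bits for some inputs.

```python
from collections import Counter


def answer_weights(answers):
    """Hypothetical helper: weight each distinct answer by its share of the annotations."""
    counts = Counter(answers)  # keeps first-seen order of the answers
    total = len(answers)
    return list(counts.keys()), [count / total for count in counts.values()]


def flatten_answers(batch):
    """Hypothetical helper: concatenate per-sample answers and weights, keeping per-sample counts."""
    flat_answers, flat_weights, n = [], [], []
    for answers, weights in batch:
        flat_answers += answers
        flat_weights += weights
        n.append(len(answers))
    return flat_answers, flat_weights, n


answers, weights = answer_weights(["yes", "yes", "no", "yes"])
# A Visual Genome style sample has a single answer with the fixed weight 0.2.
flat = flatten_answers([(answers, weights), (["red"], [0.2])])
print(answers, weights)
print(flat)
```

The list `n` is what lets downstream code map the flattened answers back to their questions.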
diff --git a/data/EveryDream/scripts/BLIP/demo.ipynb b/data/EveryDream/scripts/BLIP/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..3077a1a42c584f3ef535903f64cf6b3bb722490e
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/demo.ipynb
@@ -0,0 +1,301 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "2b949f9f",
+ "metadata": {},
+ "source": [
+ "# BLIP: Inference Demo\n",
+ " - [Image Captioning](#Image-Captioning)\n",
+ " - [VQA](#VQA)\n",
+ " - [Feature Extraction](#Feature-Extraction)\n",
+ " - [Image Text Matching](#Image-Text-Matching)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "cbcb066b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# install requirements\n",
+ "import sys\n",
+ "if 'google.colab' in sys.modules:\n",
+ " print('Running in Colab.')\n",
+ " !pip3 install transformers==4.15.0 timm==0.4.12 fairscale==0.4.4\n",
+ " !git clone https://github.com/salesforce/BLIP\n",
+ " %cd BLIP"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "a811a65f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "import requests\n",
+ "import torch\n",
+ "from torchvision import transforms\n",
+ "from torchvision.transforms.functional import InterpolationMode\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "def load_demo_image(image_size,device):\n",
+ " img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' \n",
+ " raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB') \n",
+ "\n",
+ " w,h = raw_image.size\n",
+ " display(raw_image.resize((w//5,h//5)))\n",
+ " \n",
+ " transform = transforms.Compose([\n",
+ " transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
+ " ]) \n",
+ " image = transform(raw_image).unsqueeze(0).to(device) \n",
+ " return image"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f72f4406",
+ "metadata": {},
+ "source": [
+ "# Image Captioning\n",
+ "Perform image captioning using finetuned BLIP model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6835daef",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth\n",
+ "caption: a woman sitting on the beach with a dog\n"
+ ]
+ }
+ ],
+ "source": [
+ "from models.blip import blip_decoder\n",
+ "\n",
+ "image_size = 384\n",
+ "image = load_demo_image(image_size=image_size, device=device)\n",
+ "\n",
+ "model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'\n",
+ " \n",
+ "model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')\n",
+ "model.eval()\n",
+ "model = model.to(device)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " # beam search\n",
+ " caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5) \n",
+ " # nucleus sampling\n",
+ " # caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5) \n",
+ " print('caption: '+caption[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fac320a2",
+ "metadata": {},
+ "source": [
+ "# VQA\n",
+ "Perform visual question answering using finetuned BLIP model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "5e6f3fb1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth\n",
+ "answer: on beach\n"
+ ]
+ }
+ ],
+ "source": [
+ "from models.blip_vqa import blip_vqa\n",
+ "\n",
+ "image_size = 480\n",
+ "image = load_demo_image(image_size=image_size, device=device) \n",
+ "\n",
+ "model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'\n",
+ " \n",
+ "model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')\n",
+ "model.eval()\n",
+ "model = model.to(device)\n",
+ "\n",
+ "question = 'where is the woman sitting?'\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " answer = model(image, question, train=False, inference='generate') \n",
+ " print('answer: '+answer[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6100e519",
+ "metadata": {},
+ "source": [
+ "# Feature Extraction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "4f8f21ed",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth\n"
+ ]
+ }
+ ],
+ "source": [
+ "from models.blip import blip_feature_extractor\n",
+ "\n",
+ "image_size = 224\n",
+ "image = load_demo_image(image_size=image_size, device=device) \n",
+ "\n",
+ "model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'\n",
+ " \n",
+ "model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')\n",
+ "model.eval()\n",
+ "model = model.to(device)\n",
+ "\n",
+ "caption = 'a woman sitting on the beach with a dog'\n",
+ "\n",
+ "multimodal_feature = model(image, caption, mode='multimodal')[0,0]\n",
+ "image_feature = model(image, caption, mode='image')[0,0]\n",
+ "text_feature = model(image, caption, mode='text')[0,0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "201e1146",
+ "metadata": {},
+ "source": [
+ "# Image-Text Matching"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "49ba5906",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth\n",
+ "text: a woman sitting on the beach with a dog\n",
+ "The image and text is matched with a probability of 0.9960\n",
+ "The image feature and text feature has a cosine similarity of 0.5262\n"
+ ]
+ }
+ ],
+ "source": [
+ "from models.blip_itm import blip_itm\n",
+ "\n",
+ "image_size = 384\n",
+ "image = load_demo_image(image_size=image_size,device=device)\n",
+ "\n",
+ "model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'\n",
+ " \n",
+ "model = blip_itm(pretrained=model_url, image_size=image_size, vit='base')\n",
+ "model.eval()\n",
+    "model = model.to(device)\n",
+ "\n",
+ "caption = 'a woman sitting on the beach with a dog'\n",
+ "\n",
+ "print('text: %s' %caption)\n",
+ "\n",
+ "itm_output = model(image,caption,match_head='itm')\n",
+ "itm_score = torch.nn.functional.softmax(itm_output,dim=1)[:,1]\n",
+ "print('The image and text is matched with a probability of %.4f'%itm_score)\n",
+ "\n",
+ "itc_score = model(image,caption,match_head='itc')\n",
+ "print('The image feature and text feature has a cosine similarity of %.4f'%itc_score)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/data/EveryDream/scripts/BLIP/eval_nocaps.py b/data/EveryDream/scripts/BLIP/eval_nocaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbb09a8cc7771605c013583d721aa95d9413b42
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/eval_nocaps.py
@@ -0,0 +1,118 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip import blip_decoder
+import utils
+from data import create_dataset, create_sampler, create_loader
+from data.utils import save_result
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, config):
+ # evaluate
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Evaluation:'
+ print_freq = 10
+
+ result = []
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
+
+ image = image.to(device)
+
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
+ min_length=config['min_length'], repetition_penalty=1.1)
+
+ for caption, img_id in zip(captions, image_id):
+ result.append({"image_id": img_id.item(), "caption": caption})
+
+ return result
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating captioning dataset")
+ val_dataset, test_dataset = create_dataset('nocaps', config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
+ else:
+ samplers = [None,None]
+
+ val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
+ batch_size=[config['batch_size']]*2,num_workers=[4,4],
+ is_trains=[False, False], collate_fns=[None,None])
+
+ #### Model ####
+ print("Creating model")
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+ prompt=config['prompt'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
+ val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
+ test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/nocaps.yaml')
+ parser.add_argument('--output_dir', default='output/NoCaps')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+    parser.add_argument('--distributed', default=True, type=lambda s: str(s).lower() in ('1', 'true', 'yes'))
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ args.result_dir = os.path.join(args.output_dir, 'result')
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
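Review note on eval_nocaps.py: boolean command-line flags need care with argparse, because `bool('False')` is `True` and a `type=bool` argument therefore treats every non-empty string as true. The sketch below shows one possible stricter parser. `str2bool` is a hypothetical helper, not something provided by argparse or by this repository.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false spellings to a bool."""
    text = str(value).strip().lower()
    if text in ("1", "true", "yes", "y"):
        return True
    if text in ("0", "false", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


parser = argparse.ArgumentParser()
parser.add_argument("--distributed", default=True, type=str2bool)

off = parser.parse_args(["--distributed", "False"]).distributed
on = parser.parse_args([]).distributed
naive = bool("False")  # what type=bool would compute for the same input
print(off, on, naive)
```

Rejecting unknown spellings with `ArgumentTypeError` makes argparse print a usage error instead of silently picking a value.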
diff --git a/data/EveryDream/scripts/BLIP/eval_retrieval_video.py b/data/EveryDream/scripts/BLIP/eval_retrieval_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ebab7f41f6466f6f46130002e2e0df1266486a
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/eval_retrieval_video.py
@@ -0,0 +1,250 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_retrieval import blip_retrieval
+import utils
+from data.video_dataset import VideoDataset
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, tokenizer, device, config):
+ # test
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Evaluation:'
+
+ print('Computing features for evaluation...')
+ start_time = time.time()
+
+ texts = data_loader.dataset.text
+ num_text = len(texts)
+ text_bs = 256
+ text_ids = []
+ text_embeds = []
+ text_atts = []
+ for i in range(0, num_text, text_bs):
+ text = texts[i: min(num_text, i+text_bs)]
+ text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
+ text_embeds.append(text_embed)
+ text_ids.append(text_input.input_ids)
+ text_atts.append(text_input.attention_mask)
+
+ text_embeds = torch.cat(text_embeds,dim=0)
+ text_ids = torch.cat(text_ids,dim=0)
+ text_atts = torch.cat(text_atts,dim=0)
+ text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
+
+ video_feats = []
+ video_embeds = []
+ for video, video_id in data_loader:
+
+ B,N,C,W,H = video.size()
+ video = video.view(-1,C,W,H)
+ video = video.to(device,non_blocking=True)
+ video_feat = model.visual_encoder(video)
+ video_embed = model.vision_proj(video_feat[:,0,:])
+ video_embed = video_embed.view(B,N,-1).mean(dim=1)
+ video_embed = F.normalize(video_embed,dim=-1)
+
+ video_feat = video_feat.view(B,-1,video_feat.shape[-1])
+ video_feats.append(video_feat.cpu())
+ video_embeds.append(video_embed)
+
+ video_feats = torch.cat(video_feats,dim=0)
+ video_embeds = torch.cat(video_embeds,dim=0)
+
+ sims_matrix = video_embeds @ text_embeds.t()
+ score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
+
+ num_tasks = utils.get_world_size()
+ rank = utils.get_rank()
+ step = sims_matrix.size(0)//num_tasks + 1
+ start = rank*step
+ end = min(sims_matrix.size(0),start+step)
+
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+
+ encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
+ output = model.text_encoder(text_ids[topk_idx],
+ attention_mask = text_atts[topk_idx],
+ encoder_hidden_states = encoder_output,
+ encoder_attention_mask = encoder_att,
+ return_dict = True,
+ )
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+ score_matrix_v2t[start+i,topk_idx] = score + topk_sim
+
+ sims_matrix = sims_matrix.t()
+ score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
+
+ step = sims_matrix.size(0)//num_tasks + 1
+ start = rank*step
+ end = min(sims_matrix.size(0),start+step)
+
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+ encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
+ encoder_hidden_states = encoder_output,
+ encoder_attention_mask = encoder_att,
+ return_dict = True,
+ )
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+ score_matrix_t2v[start+i,topk_idx] = score + topk_sim
+
+ if dist.is_available() and dist.is_initialized():
+ dist.barrier()
+ torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
+ torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Evaluation time {}'.format(total_time_str))
+
+ return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
+
+
+
+@torch.no_grad()
+def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
+
+ #Video->Text
+ ranks = np.zeros(scores_v2t.shape[0])
+ for index,score in enumerate(scores_v2t):
+ inds = np.argsort(score)[::-1]
+ ranks[index] = np.where(inds == vid2txt[index])[0][0]
+
+ # Compute metrics
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ #Text->Video
+ ranks = np.zeros(scores_t2v.shape[0])
+
+ for index,score in enumerate(scores_t2v):
+ inds = np.argsort(score)[::-1]
+ ranks[index] = np.where(inds == txt2vmg[index])[0][0]
+
+ mdR = np.median(ranks+1)
+
+ # Compute metrics
+ vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ tr_mean = (tr1 + tr5 + tr10) / 3
+ vr_mean = (vr1 + vr5 + vr10) / 3
+ r_mean = (tr_mean + vr_mean) / 2
+
+ eval_result = {'txt_r1': tr1,
+ 'txt_r5': tr5,
+ 'txt_r10': tr10,
+ 'txt_r_mean': tr_mean,
+ 'vid_r1': vr1,
+ 'vid_r5': vr5,
+ 'vid_r10': vr10,
+ 'vid_r_mean': vr_mean,
+ 'vid_mdR': mdR,
+ 'r_mean': r_mean}
+ return eval_result
+
+
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating retrieval dataset")
+ test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
+ max_img_size=config['image_size'], frm_sampling_strategy='uniform')
+
+ test_loader = DataLoader(
+ test_dataset,
+ batch_size=config['batch_size'],
+ num_workers=4,
+ pin_memory=True,
+ drop_last=False,
+ shuffle=False,
+ )
+
+ #### Model ####
+ print("Creating model")
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ score_v2t, score_t2v = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
+
+ if utils.is_main_process():
+
+ test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
+ print(test_result)
+
+ log_stats = dict(test_result)
+ with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
+ parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=lambda s: str(s).lower() in ('1', 'true', 'yes'))
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/models/__init__.py b/data/EveryDream/scripts/BLIP/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/EveryDream/scripts/BLIP/models/__pycache__/__init__.cpython-310.pyc b/data/EveryDream/scripts/BLIP/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b3abd1ee6f58787aace7012fe4892c6f4eaf8f5
Binary files /dev/null and b/data/EveryDream/scripts/BLIP/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/data/EveryDream/scripts/BLIP/models/__pycache__/blip.cpython-310.pyc b/data/EveryDream/scripts/BLIP/models/__pycache__/blip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2dcc0af80321bd9ef05154371948bdb7584c3510
Binary files /dev/null and b/data/EveryDream/scripts/BLIP/models/__pycache__/blip.cpython-310.pyc differ
diff --git a/data/EveryDream/scripts/BLIP/models/__pycache__/med.cpython-310.pyc b/data/EveryDream/scripts/BLIP/models/__pycache__/med.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf87fef2eb713b2ba5f288848e12e81994d72d91
Binary files /dev/null and b/data/EveryDream/scripts/BLIP/models/__pycache__/med.cpython-310.pyc differ
diff --git a/data/EveryDream/scripts/BLIP/models/__pycache__/vit.cpython-310.pyc b/data/EveryDream/scripts/BLIP/models/__pycache__/vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f35814715089e37f100efc69dea76a888c42a41b
Binary files /dev/null and b/data/EveryDream/scripts/BLIP/models/__pycache__/vit.cpython-310.pyc differ
diff --git a/data/EveryDream/scripts/BLIP/models/blip.py b/data/EveryDream/scripts/BLIP/models/blip.py
new file mode 100644
index 0000000000000000000000000000000000000000..38678f65ea2c276b351c2c97d429ebc2525ddcf7
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/blip.py
@@ -0,0 +1,238 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import warnings
+warnings.filterwarnings("ignore")
+
+from models.vit import VisionTransformer, interpolate_pos_embed
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import os
+from urllib.parse import urlparse
+from timm.models.hub import download_cached_file
+
+class BLIP_Base(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 224,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+
+ def forward(self, image, caption, mode):
+
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
+
+ if mode=='image':
+ # return image features
+ image_embeds = self.visual_encoder(image)
+ return image_embeds
+
+ elif mode=='text':
+ # return text features
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ return text_output.last_hidden_state
+
+ elif mode=='multimodal':
+ # return multimodal features
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
+ output = self.text_encoder(text.input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True,
+ )
+ return output.last_hidden_state
+
+
+
+class BLIP_Decoder(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 384,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ prompt = 'a picture of ',
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_decoder = BertLMHeadModel(config=med_config)
+
+ self.prompt = prompt
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
+
+
+ def forward(self, image, caption):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
+
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
+
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
+ decoder_targets[:,:self.prompt_length] = -100
+
+ decoder_output = self.text_decoder(text.input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ labels = decoder_targets,
+ return_dict = True,
+ )
+ loss_lm = decoder_output.loss
+
+ return loss_lm
+
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
+ image_embeds = self.visual_encoder(image)
+
+ if not sample:
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
+
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
+
+ prompt = [self.prompt] * image.size(0)
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
+ input_ids[:,0] = self.tokenizer.bos_token_id
+ input_ids = input_ids[:, :-1]
+
+ if sample:
+ #nucleus sampling
+ outputs = self.text_decoder.generate(input_ids=input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=True,
+ top_p=top_p,
+ num_return_sequences=1,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ repetition_penalty=1.1,
+ **model_kwargs)
+ else:
+ #beam search
+ outputs = self.text_decoder.generate(input_ids=input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ num_beams=num_beams,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ repetition_penalty=repetition_penalty,
+ **model_kwargs)
+
+ captions = []
+ for output in outputs:
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
+ captions.append(caption[len(self.prompt):])
+ return captions
+
+
+def blip_decoder(pretrained='',**kwargs):
+ model = BLIP_Decoder(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ assert(len(msg.missing_keys)==0)
+ return model
+
+def blip_feature_extractor(pretrained='',**kwargs):
+ model = BLIP_Base(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ assert(len(msg.missing_keys)==0)
+ return model
+
+def init_tokenizer():
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
+ return tokenizer
+
+
+def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
+
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
+ if vit=='base':
+ vision_width = 768
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+ drop_path_rate=drop_path_rate
+ )
+ elif vit=='large':
+ vision_width = 1024
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+ drop_path_rate=drop_path_rate or 0.1
+ )
+ return visual_encoder, vision_width
+
+def is_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+def load_checkpoint(model,url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+
+ state_dict = checkpoint['model']
+
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
+ model.visual_encoder_m)
+ for key in model.state_dict().keys():
+ if key in state_dict.keys():
+ if state_dict[key].shape!=model.state_dict()[key].shape:
+ del state_dict[key]
+
+ msg = model.load_state_dict(state_dict,strict=False)
+ print('load checkpoint from %s'%url_or_filename)
+ return model,msg
+
diff --git a/data/EveryDream/scripts/BLIP/models/blip_itm.py b/data/EveryDream/scripts/BLIP/models/blip_itm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf354c829564bf5a1f56089a2d745093d51e0fa2
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/blip_itm.py
@@ -0,0 +1,76 @@
+from models.med import BertConfig, BertModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_ITM(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 384,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ embed_dim = 256,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+ text_width = self.text_encoder.config.hidden_size
+
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
+ self.text_proj = nn.Linear(text_width, embed_dim)
+
+ self.itm_head = nn.Linear(text_width, 2)
+
+
+ def forward(self, image, caption, match_head='itm'):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
+ return_tensors="pt").to(image.device)
+
+
+ if match_head=='itm':
+ output = self.text_encoder(text.input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True,
+ )
+ itm_output = self.itm_head(output.last_hidden_state[:,0,:])
+ return itm_output
+
+ elif match_head=='itc':
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
+
+ sim = image_feat @ text_feat.t()
+ return sim
+
+
+def blip_itm(pretrained='',**kwargs):
+ model = BLIP_ITM(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ assert(len(msg.missing_keys)==0)
+ return model
+
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/models/blip_nlvr.py b/data/EveryDream/scripts/BLIP/models/blip_nlvr.py
new file mode 100644
index 0000000000000000000000000000000000000000..84837167bfa6874d3c3e41fb9b37271113910b7f
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/blip_nlvr.py
@@ -0,0 +1,105 @@
+import os
+
+from models.med import BertConfig
+from models.nlvr_encoder import BertModel
+from models.vit import interpolate_pos_embed
+from models.blip import create_vit, init_tokenizer, is_url
+
+from timm.models.hub import download_cached_file
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import BertTokenizer
+import numpy as np
+
+class BLIP_NLVR(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 480,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+ self.cls_head = nn.Sequential(
+ nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
+ nn.ReLU(),
+ nn.Linear(self.text_encoder.config.hidden_size, 2)
+ )
+
+ def forward(self, image, text, targets, train=True):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+ image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
+
+ text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
+
+ output = self.text_encoder(text.input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = [image0_embeds,image1_embeds],
+ encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
+ image_atts[image0_embeds.size(0):]],
+ return_dict = True,
+ )
+ hidden_state = output.last_hidden_state[:,0,:]
+ prediction = self.cls_head(hidden_state)
+
+ if train:
+ loss = F.cross_entropy(prediction, targets)
+ return loss
+ else:
+ return prediction
+
+def blip_nlvr(pretrained='',**kwargs):
+ model = BLIP_NLVR(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ print("missing keys:")
+ print(msg.missing_keys)
+ return model
+
+
+def load_checkpoint(model,url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+ state_dict = checkpoint['model']
+
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
+
+ for key in list(state_dict.keys()):
+ if 'crossattention.self.' in key:
+ new_key0 = key.replace('self','self0')
+ new_key1 = key.replace('self','self1')
+ state_dict[new_key0] = state_dict[key]
+ state_dict[new_key1] = state_dict[key]
+ elif 'crossattention.output.dense.' in key:
+ new_key0 = key.replace('dense','dense0')
+ new_key1 = key.replace('dense','dense1')
+ state_dict[new_key0] = state_dict[key]
+ state_dict[new_key1] = state_dict[key]
+
+ msg = model.load_state_dict(state_dict,strict=False)
+ print('load checkpoint from %s'%url_or_filename)
+ return model,msg
+
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/models/blip_pretrain.py b/data/EveryDream/scripts/BLIP/models/blip_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42ce5f998b0a51e6f731ee6b5c8bae6d02a8664
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/blip_pretrain.py
@@ -0,0 +1,339 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from transformers import BertTokenizer
+import transformers
+transformers.logging.set_verbosity_error()
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_Pretrain(nn.Module):
+ def __init__(self,
+ med_config = 'configs/bert_config.json',
+ image_size = 224,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ embed_dim = 256,
+ queue_size = 57600,
+ momentum = 0.995,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
+
+ if vit=='base':
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+ map_location="cpu", check_hash=True)
+ state_dict = checkpoint["model"]
+ msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
+ elif vit=='large':
+ from timm.models.helpers import load_custom_pretrained
+ from timm.models.vision_transformer import default_cfgs
+ load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
+
+ self.tokenizer = init_tokenizer()
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = vision_width
+ self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+
+ text_width = self.text_encoder.config.hidden_size
+
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
+ self.text_proj = nn.Linear(text_width, embed_dim)
+
+ self.itm_head = nn.Linear(text_width, 2)
+
+ # create momentum encoders
+ self.visual_encoder_m, vision_width = create_vit(vit,image_size)
+ self.vision_proj_m = nn.Linear(vision_width, embed_dim)
+ self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
+ self.text_proj_m = nn.Linear(text_width, embed_dim)
+
+ self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
+ [self.vision_proj,self.vision_proj_m],
+ [self.text_encoder,self.text_encoder_m],
+ [self.text_proj,self.text_proj_m],
+ ]
+ self.copy_params()
+
+ # create the queue
+ self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
+ self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
+ self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
+
+ self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
+ self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
+
+ self.queue_size = queue_size
+ self.momentum = momentum
+ self.temp = nn.Parameter(0.07*torch.ones([]))
+
+ # create the decoder
+ decoder_config = BertConfig.from_json_file(med_config)
+ decoder_config.encoder_width = vision_width
+ self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
+ self.text_decoder.resize_token_embeddings(len(self.tokenizer))
+ tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
+
+
+ def forward(self, image, caption, alpha):
+ with torch.no_grad():
+ self.temp.clamp_(0.001,0.5)
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
+
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
+ return_tensors="pt").to(image.device)
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
+
+ # get momentum features
+ with torch.no_grad():
+ self._momentum_update()
+ image_embeds_m = self.visual_encoder_m(image)
+ image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
+ image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
+
+ text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
+ text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
+
+ sim_i2t_m = image_feat_m @ text_feat_all / self.temp
+ sim_t2i_m = text_feat_m @ image_feat_all / self.temp
+
+ sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
+ sim_targets.fill_diagonal_(1)
+
+ sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
+ sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
+
+ sim_i2t = image_feat @ text_feat_all / self.temp
+ sim_t2i = text_feat @ image_feat_all / self.temp
+
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
+
+ loss_ita = (loss_i2t+loss_t2i)/2
+
+ self._dequeue_and_enqueue(image_feat_m, text_feat_m)
+
+ ###============== Image-text Matching ===================###
+ encoder_input_ids = text.input_ids.clone()
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
+
+ # forward the positive image-text pair
+ bs = image.size(0)
+ output_pos = self.text_encoder(encoder_input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True,
+ )
+ with torch.no_grad():
+ weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
+ weights_t2i.fill_diagonal_(0)
+ weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
+ weights_i2t.fill_diagonal_(0)
+
+ # select a negative image for each text
+ image_embeds_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+ image_embeds_neg.append(image_embeds[neg_idx])
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
+
+ # select a negative text for each image
+ text_ids_neg = []
+ text_atts_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+ text_ids_neg.append(encoder_input_ids[neg_idx])
+ text_atts_neg.append(text.attention_mask[neg_idx])
+
+ text_ids_neg = torch.stack(text_ids_neg,dim=0)
+ text_atts_neg = torch.stack(text_atts_neg,dim=0)
+
+ text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
+ text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
+
+ image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
+ image_atts_all = torch.cat([image_atts,image_atts],dim=0)
+
+ output_neg = self.text_encoder(text_ids_all,
+ attention_mask = text_atts_all,
+ encoder_hidden_states = image_embeds_all,
+ encoder_attention_mask = image_atts_all,
+ return_dict = True,
+ )
+
+ vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
+ vl_output = self.itm_head(vl_embeddings)
+
+ itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
+ dim=0).to(image.device)
+ loss_itm = F.cross_entropy(vl_output, itm_labels)
+
+ ##================= LM ========================##
+ decoder_input_ids = text.input_ids.clone()
+ decoder_input_ids[:,0] = self.tokenizer.bos_token_id
+ decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
+
+ decoder_output = self.text_decoder(decoder_input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ labels = decoder_targets,
+ return_dict = True,
+ )
+
+ loss_lm = decoder_output.loss
+ return loss_ita, loss_itm, loss_lm
+
+
+
+ @torch.no_grad()
+ def copy_params(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+ param_m.data.copy_(param.data) # initialize
+ param_m.requires_grad = False # not updated by gradient
+
+
+ @torch.no_grad()
+ def _momentum_update(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
+
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, image_feat, text_feat):
+ # gather keys before updating queue
+ image_feats = concat_all_gather(image_feat)
+ text_feats = concat_all_gather(text_feat)
+
+ batch_size = image_feats.shape[0]
+
+ ptr = int(self.queue_ptr)
+ assert self.queue_size % batch_size == 0 # for simplicity
+
+ # replace the keys at ptr (dequeue and enqueue)
+ self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
+ self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
+
+ self.queue_ptr[0] = ptr
+
+
+def blip_pretrain(**kwargs):
+ model = BLIP_Pretrain(**kwargs)
+ return model
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [torch.ones_like(tensor)
+ for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+from typing import List
+def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
+ uninitialized_encoder_weights: List[str] = []
+ if decoder.__class__ != encoder.__class__:
+ logger.info(
+ f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
+ )
+
+ def tie_encoder_to_decoder_recursively(
+ decoder_pointer: nn.Module,
+ encoder_pointer: nn.Module,
+ module_name: str,
+ uninitialized_encoder_weights: List[str],
+ skip_key: str,
+ depth=0,
+ ):
+ assert isinstance(decoder_pointer, nn.Module) and isinstance(
+ encoder_pointer, nn.Module
+ ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
+ if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
+ assert hasattr(encoder_pointer, "weight")
+ encoder_pointer.weight = decoder_pointer.weight
+ if hasattr(decoder_pointer, "bias"):
+ assert hasattr(encoder_pointer, "bias")
+ encoder_pointer.bias = decoder_pointer.bias
+ print(module_name+' is tied')
+ return
+
+ encoder_modules = encoder_pointer._modules
+ decoder_modules = decoder_pointer._modules
+ if len(decoder_modules) > 0:
+ assert (
+ len(encoder_modules) > 0
+ ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
+
+ all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
+ encoder_layer_pos = 0
+ for name, module in decoder_modules.items():
+ if name.isdigit():
+ encoder_name = str(int(name) + encoder_layer_pos)
+ decoder_name = name
+ if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
+ encoder_modules
+ ) != len(decoder_modules):
+ # this can happen if the name corresponds to the position in a module list of layers
+ # in this case the decoder has added a cross-attention that the encoder does not have
+ # thus skip this step and subtract one layer pos from encoder
+ encoder_layer_pos -= 1
+ continue
+ elif name not in encoder_modules:
+ continue
+ elif depth > 500:
+ raise ValueError(
+ "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
+ )
+ else:
+ decoder_name = encoder_name = name
+ tie_encoder_to_decoder_recursively(
+ decoder_modules[decoder_name],
+ encoder_modules[encoder_name],
+ module_name + "/" + name,
+ uninitialized_encoder_weights,
+ skip_key,
+ depth=depth + 1,
+ )
+ all_encoder_weights.remove(module_name + "/" + encoder_name)
+
+ uninitialized_encoder_weights += list(all_encoder_weights)
+
+ # tie weights recursively
+ tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
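Note on `_momentum_update` and `_dequeue_and_enqueue` above: the momentum encoders follow an exponential moving average of the trained encoders, and the feature queue is a fixed-size ring buffer whose write pointer wraps around. The following is a minimal pure-Python sketch of those two ideas, not the PyTorch implementation itself. The names `momentum_update` and `FeatureQueue` are hypothetical, plain lists stand in for tensors, and the distributed gather step is left out.

```python
def momentum_update(params, params_m, momentum):
    """EMA step: p_m <- momentum * p_m + (1 - momentum) * p, element-wise."""
    return [pm * momentum + p * (1.0 - momentum) for p, pm in zip(params, params_m)]


class FeatureQueue:
    """Fixed-size ring buffer (hypothetical stand-in for image_queue/text_queue)."""

    def __init__(self, queue_size):
        self.queue_size = queue_size
        self.slots = [None] * queue_size
        self.ptr = 0

    def dequeue_and_enqueue(self, feats):
        batch_size = len(feats)
        # Same simplifying assumption as the diff: a batch never straddles the end.
        assert self.queue_size % batch_size == 0
        # Overwrite the oldest entries in place, then advance the pointer with wrap-around.
        self.slots[self.ptr:self.ptr + batch_size] = feats
        self.ptr = (self.ptr + batch_size) % self.queue_size


queue = FeatureQueue(queue_size=4)
queue.dequeue_and_enqueue(["a0", "a1"])
queue.dequeue_and_enqueue(["b0", "b1"])
queue.dequeue_and_enqueue(["c0", "c1"])  # wraps around and replaces a0, a1
print(queue.slots, queue.ptr)
```

With a momentum close to 1 (the retrieval model's constructor defaults to 0.995), the momentum copy changes slowly, which keeps the queued features roughly consistent across steps.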
diff --git a/data/EveryDream/scripts/BLIP/models/blip_retrieval.py b/data/EveryDream/scripts/BLIP/models/blip_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1debe7e2e664f8dd603f8d4c537e3599c68638d7
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/blip_retrieval.py
@@ -0,0 +1,319 @@
+from models.med import BertConfig, BertModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_Retrieval(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 384,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ embed_dim = 256,
+ queue_size = 57600,
+ momentum = 0.995,
+ negative_all_rank = False,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+ text_width = self.text_encoder.config.hidden_size
+
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
+ self.text_proj = nn.Linear(text_width, embed_dim)
+
+ self.itm_head = nn.Linear(text_width, 2)
+
+ # create momentum encoders
+ self.visual_encoder_m, vision_width = create_vit(vit,image_size)
+ self.vision_proj_m = nn.Linear(vision_width, embed_dim)
+ self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
+ self.text_proj_m = nn.Linear(text_width, embed_dim)
+
+ self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
+ [self.vision_proj,self.vision_proj_m],
+ [self.text_encoder,self.text_encoder_m],
+ [self.text_proj,self.text_proj_m],
+ ]
+ self.copy_params()
+
+ # create the queue
+ self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
+ self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
+ self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
+ self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
+
+ self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
+ self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
+
+ self.queue_size = queue_size
+ self.momentum = momentum
+ self.temp = nn.Parameter(0.07*torch.ones([]))
+
+ self.negative_all_rank = negative_all_rank
+
+
+ def forward(self, image, caption, alpha, idx):
+ with torch.no_grad():
+ self.temp.clamp_(0.001,0.5)
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
+
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
+ return_tensors="pt").to(image.device)
+
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
+
+ ###============== Image-text Contrastive Learning ===================###
+ idx = idx.view(-1,1)
+ idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
+ pos_idx = torch.eq(idx, idx_all).float()
+ sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
+
+ # get momentum features
+ with torch.no_grad():
+ self._momentum_update()
+ image_embeds_m = self.visual_encoder_m(image)
+ image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
+ image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
+
+ text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
+ text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
+
+ sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
+ sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
+
+ sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
+ sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
+
+ sim_i2t = image_feat @ text_feat_m_all / self.temp
+ sim_t2i = text_feat @ image_feat_m_all / self.temp
+
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
+
+ loss_ita = (loss_i2t+loss_t2i)/2
+
+ idxs = concat_all_gather(idx)
+ self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
+
+ ###============== Image-text Matching ===================###
+ encoder_input_ids = text.input_ids.clone()
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
+
+ # forward the positive image-text pair
+ bs = image.size(0)
+ output_pos = self.text_encoder(encoder_input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True,
+ )
+
+
+ if self.negative_all_rank:
+ # compute sample similarity
+ with torch.no_grad():
+ mask = torch.eq(idx, idxs.t())
+
+ image_feat_world = concat_all_gather(image_feat)
+ text_feat_world = concat_all_gather(text_feat)
+
+ sim_i2t = image_feat @ text_feat_world.t() / self.temp
+ sim_t2i = text_feat @ image_feat_world.t() / self.temp
+
+ weights_i2t = F.softmax(sim_i2t,dim=1)
+ weights_i2t.masked_fill_(mask, 0)
+
+ weights_t2i = F.softmax(sim_t2i,dim=1)
+ weights_t2i.masked_fill_(mask, 0)
+
+ image_embeds_world = all_gather_with_grad(image_embeds)
+
+ # select a negative image (from all ranks) for each text
+ image_embeds_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+ image_embeds_neg.append(image_embeds_world[neg_idx])
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
+
+ # select a negative text (from all ranks) for each image
+ input_ids_world = concat_all_gather(encoder_input_ids)
+ att_mask_world = concat_all_gather(text.attention_mask)
+
+ text_ids_neg = []
+ text_atts_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+ text_ids_neg.append(input_ids_world[neg_idx])
+ text_atts_neg.append(att_mask_world[neg_idx])
+
+ else:
+ with torch.no_grad():
+ mask = torch.eq(idx, idx.t())
+
+ sim_i2t = image_feat @ text_feat.t() / self.temp
+ sim_t2i = text_feat @ image_feat.t() / self.temp
+
+ weights_i2t = F.softmax(sim_i2t,dim=1)
+ weights_i2t.masked_fill_(mask, 0)
+
+ weights_t2i = F.softmax(sim_t2i,dim=1)
+ weights_t2i.masked_fill_(mask, 0)
+
+ # select a negative image (from same rank) for each text
+ image_embeds_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+ image_embeds_neg.append(image_embeds[neg_idx])
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
+
+ # select a negative text (from same rank) for each image
+ text_ids_neg = []
+ text_atts_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+ text_ids_neg.append(encoder_input_ids[neg_idx])
+ text_atts_neg.append(text.attention_mask[neg_idx])
+
+ text_ids_neg = torch.stack(text_ids_neg,dim=0)
+ text_atts_neg = torch.stack(text_atts_neg,dim=0)
+
+ text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
+ text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
+
+ image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
+ image_atts_all = torch.cat([image_atts,image_atts],dim=0)
+
+ output_neg = self.text_encoder(text_ids_all,
+ attention_mask = text_atts_all,
+ encoder_hidden_states = image_embeds_all,
+ encoder_attention_mask = image_atts_all,
+ return_dict = True,
+ )
+
+
+ vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
+ vl_output = self.itm_head(vl_embeddings)
+
+ itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
+ dim=0).to(image.device)
+ loss_itm = F.cross_entropy(vl_output, itm_labels)
+
+ return loss_ita, loss_itm
+
+
+ @torch.no_grad()
+ def copy_params(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+ param_m.data.copy_(param.data) # initialize
+ param_m.requires_grad = False # not updated by gradient
+
+
+ @torch.no_grad()
+ def _momentum_update(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
+
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
+ # gather keys before updating queue
+ image_feats = concat_all_gather(image_feat)
+ text_feats = concat_all_gather(text_feat)
+
+
+ batch_size = image_feats.shape[0]
+
+ ptr = int(self.ptr_queue)
+ assert self.queue_size % batch_size == 0 # for simplicity
+
+ # replace the keys at ptr (dequeue and enqueue)
+ self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
+ self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
+ self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
+
+ self.ptr_queue[0] = ptr
+
+
+def blip_retrieval(pretrained='',**kwargs):
+ model = BLIP_Retrieval(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ print("missing keys:")
+ print(msg.missing_keys)
+ return model
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [torch.ones_like(tensor)
+ for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+class GatherLayer(torch.autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ torch.distributed.all_reduce(all_gradients)
+ return all_gradients[torch.distributed.get_rank()]
+
+
+def all_gather_with_grad(tensors):
+ """
+ Performs all_gather operation on the provided tensors.
+ Graph remains connected for backward grad computation.
+ """
+ # Number of participating processes
+ world_size = torch.distributed.get_world_size()
+ # There is no need for reduction in the single-proc case
+ if world_size == 1:
+ return tensors
+
+ tensor_all = GatherLayer.apply(tensors)
+
+ return torch.cat(tensor_all, dim=0)
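Note on the image-text matching branch above: hard negatives are drawn by taking a softmax over the similarity row, zeroing the entries that share the anchor's `idx` (the positives), and sampling one index from what remains. The sketch below reproduces that logic in pure Python as an illustration only. `softmax` and `sample_hard_negatives` are hypothetical helpers, `random.Random.choices` stands in for `torch.multinomial`, and the similarity values are made up.

```python
import math
import random


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def sample_hard_negatives(sim, idx, rng):
    """For each row b, sample one column j with idx[j] != idx[b].

    Columns are weighted by softmax(sim[b]), so more similar (harder)
    negatives are picked more often. Raises ValueError if a row has no
    valid negative, because all of its weights are then zero.
    """
    negatives = []
    for b, row in enumerate(sim):
        weights = softmax(row)
        # Mask out positives: entries that share the anchor's sample id.
        weights = [0.0 if idx[j] == idx[b] else w for j, w in enumerate(weights)]
        negatives.append(rng.choices(range(len(row)), weights=weights, k=1)[0])
    return negatives


# Made-up text-to-image similarities for a batch of three pairs.
sim_t2i = [[5.0, 4.0, 0.1],
           [3.9, 5.0, 0.2],
           [0.1, 0.3, 5.0]]
idx = [10, 11, 12]
neg = sample_hard_negatives(sim_t2i, idx, random.Random(0))
print(neg)
```

Each sampled index is then used to pick the negative image (or text) that is paired with the anchor, and those pairs receive label 0 in `itm_labels`.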
diff --git a/data/EveryDream/scripts/BLIP/models/blip_vqa.py b/data/EveryDream/scripts/BLIP/models/blip_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4cb3688fad03888f8568ec65437ee20452c6cb8
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/blip_vqa.py
@@ -0,0 +1,186 @@
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import BertTokenizer
+import numpy as np
+
+class BLIP_VQA(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 480,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
+ self.tokenizer = init_tokenizer()
+
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
+
+ decoder_config = BertConfig.from_json_file(med_config)
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
+
+
+ def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+ question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
+ return_tensors="pt").to(image.device)
+ question.input_ids[:,0] = self.tokenizer.enc_token_id
+
+ if train:
+ '''
+ n: number of answers for each question
+ weights: weight for each answer
+ '''
+ answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
+ answer.input_ids[:,0] = self.tokenizer.bos_token_id
+ answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
+
+ question_output = self.text_encoder(question.input_ids,
+ attention_mask = question.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True)
+
+ question_states = []
+ question_atts = []
+ for b, num_answers in enumerate(n):
+ question_states += [question_output.last_hidden_state[b]]*num_answers
+ question_atts += [question.attention_mask[b]]*num_answers
+ question_states = torch.stack(question_states,0)
+ question_atts = torch.stack(question_atts,0)
+
+ answer_output = self.text_decoder(answer.input_ids,
+ attention_mask = answer.attention_mask,
+ encoder_hidden_states = question_states,
+ encoder_attention_mask = question_atts,
+ labels = answer_targets,
+ return_dict = True,
+ reduction = 'none',
+ )
+
+ loss = weights * answer_output.loss
+ loss = loss.sum()/image.size(0)
+
+ return loss
+
+
+ else:
+ question_output = self.text_encoder(question.input_ids,
+ attention_mask = question.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True)
+
+ if inference=='generate':
+ num_beams = 3
+ question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
+ question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
+ model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
+
+ bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
+
+ outputs = self.text_decoder.generate(input_ids=bos_ids,
+ max_length=10,
+ min_length=1,
+ num_beams=num_beams,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ **model_kwargs)
+
+ answers = []
+ for output in outputs:
+ answer = self.tokenizer.decode(output, skip_special_tokens=True)
+ answers.append(answer)
+ return answers
+
+ elif inference=='rank':
+ max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
+ answer.input_ids, answer.attention_mask, k_test)
+ return max_ids
+
+
+
+ def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
+
+ num_ques = question_states.size(0)
+ start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
+
+ start_output = self.text_decoder(start_ids,
+ encoder_hidden_states = question_states,
+ encoder_attention_mask = question_atts,
+ return_dict = True,
+ reduction = 'none')
+ logits = start_output.logits[:,0,:] # first token's logit
+
+ # topk_probs: top-k probability
+ # topk_ids: [num_question, k]
+ answer_first_token = answer_ids[:,1]
+ prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
+ topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
+
+ # answer input: [num_question*k, answer_len]
+ input_ids = []
+ input_atts = []
+ for b, topk_id in enumerate(topk_ids):
+ input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
+ input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
+ input_ids = torch.cat(input_ids,dim=0)
+ input_atts = torch.cat(input_atts,dim=0)
+
+ targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
+
+ # repeat encoder's output for top-k answers
+ question_states = tile(question_states, 0, k)
+ question_atts = tile(question_atts, 0, k)
+
+ output = self.text_decoder(input_ids,
+ attention_mask = input_atts,
+ encoder_hidden_states = question_states,
+ encoder_attention_mask = question_atts,
+ labels = targets_ids,
+ return_dict = True,
+ reduction = 'none')
+
+ log_probs_sum = -output.loss
+ log_probs_sum = log_probs_sum.view(num_ques,k)
+
+ max_topk_ids = log_probs_sum.argmax(dim=1)
+ max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
+
+ return max_ids
+
+
+def blip_vqa(pretrained='',**kwargs):
+ model = BLIP_VQA(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+# assert(len(msg.missing_keys)==0)
+ return model
+
+
+def tile(x, dim, n_tile):
+ init_dim = x.size(dim)
+ repeat_idx = [1] * x.dim()
+ repeat_idx[dim] = n_tile
+ x = x.repeat(*(repeat_idx))
+ order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
+ return torch.index_select(x, dim, order_index.to(x.device))
+
+
\ No newline at end of file
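Note on `tile` above: it repeats the tensor `n_tile` times along `dim` and then reorders with `index_select`, so the copies of each original row end up adjacent (an interleaved repeat rather than a block repeat). `rank_answer` relies on this so that each question's state lines up with its own top-k candidate answers. A minimal pure-Python sketch of the same index arithmetic, using a hypothetical `tile_rows` helper on a list instead of a tensor:

```python
def tile_rows(rows, n_tile):
    """Pure-Python analogue of tile(x, 0, n_tile) for a list of rows."""
    init_dim = len(rows)
    # Mirrors x.repeat(...): block repeat [r0, r1, ..., r0, r1, ...].
    repeated = rows * n_tile
    # Mirrors order_index: for each original row i, pick its copy from every block.
    order_index = [init_dim * j + i for i in range(init_dim) for j in range(n_tile)]
    return [repeated[k] for k in order_index]


questions = ["q0", "q1", "q2"]
tiled = tile_rows(questions, 2)
print(tiled)  # ['q0', 'q0', 'q1', 'q1', 'q2', 'q2']
```

This is the same ordering that `repeat_interleave(n_tile, dim=0)` produces, which the `generate` branch of `forward` already uses for beam search.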
diff --git a/data/EveryDream/scripts/BLIP/models/med.py b/data/EveryDream/scripts/BLIP/models/med.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b00a35450b736180a805d4f4664b4fb95aeba01
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/med.py
@@ -0,0 +1,955 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+'''
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ if mode=='multimodal':
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode='multimodal',
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # Bind `mode` here: custom_forward only accepts positional inputs, so it cannot be passed via checkpoint().
+ return module(*inputs, past_key_value, output_attentions, mode=mode)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. For cross-attention, the config needs
+ :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
+ device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, list):
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if isinstance(encoder_attention_mask, list):
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction='mean',
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
+ ignored (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ if reduction=='none':
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
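
Note on `get_extended_attention_mask` in the file above: the decoder branch combines a causal mask with the padding mask and then turns the result into an additive bias (0.0 where attention is allowed, -10000.0 where it is blocked). The following is only a minimal plain-Python sketch of that arithmetic, not the code from the patch: the helper name `extended_decoder_mask` is hypothetical, it works on nested lists instead of tensors, and it leaves out the broadcast head dimension and the `past_key_values` prefix handling.

```python
# Illustrative sketch only: plain-Python version of the decoder mask arithmetic.
# `extended_decoder_mask` is a hypothetical helper, not part of the patch.
def extended_decoder_mask(padding_mask):
    """padding_mask: rows of 0/1 values (1 = attend), shape [batch, seq_len].

    Returns nested lists of shape [batch, seq_len, seq_len] that hold 0.0
    where attention is allowed and -10000.0 where it is blocked.
    """
    out = []
    for row in padding_mask:
        seq_len = len(row)
        per_example = []
        for q in range(seq_len):  # query position
            biases = []
            for k in range(seq_len):  # key position
                causal = 1.0 if k <= q else 0.0  # never attend to future keys
                keep = causal * row[k]  # also drop padded keys
                biases.append((1.0 - keep) * -10000.0)
            per_example.append(biases)
        out.append(per_example)
    return out


# One sequence of length 3 whose last token is padding.
mask = extended_decoder_mask([[1, 1, 0]])
```

Because the bias is added to the raw scores before the softmax, a value of -10000.0 drives the corresponding probability to effectively zero, which is why the patch can treat it as removing those positions.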
diff --git a/data/EveryDream/scripts/BLIP/models/nlvr_encoder.py b/data/EveryDream/scripts/BLIP/models/nlvr_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1946bb4a300f75afa4848f6622839445903c34a9
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/nlvr_encoder.py
@@ -0,0 +1,843 @@
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config, twin=False, merge=False):
+ super().__init__()
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ if twin:
+ self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
+ else:
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if merge:
+ self.act = ACT2FN[config.hidden_act]
+ self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
+ self.merge = True
+ else:
+ self.merge = False
+
+ def forward(self, hidden_states, input_tensor):
+ if isinstance(hidden_states, list):
+ hidden_states0 = self.dense0(hidden_states[0])
+ hidden_states1 = self.dense1(hidden_states[1])
+ if self.merge:
+ #hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))
+ hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1))
+ else:
+ hidden_states = (hidden_states0+hidden_states1)/2
+ else:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False, layer_num=-1):
+ super().__init__()
+ if is_cross_attention:
+ self.self0 = BertSelfAttention(config, is_cross_attention)
+ self.self1 = BertSelfAttention(config, is_cross_attention)
+ else:
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6))
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ if isinstance(encoder_hidden_states, list):
+ self_outputs0 = self.self0(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states[0],
+ encoder_attention_mask[0],
+ past_key_value,
+ output_attentions,
+ )
+ self_outputs1 = self.self1(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states[1],
+ encoder_attention_mask[1],
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states)
+
+ outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them
+ else:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ if mode=='multimodal':
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode='multimodal',
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # `mode` is bound here because checkpoint() should only receive positional tensor inputs
+ return module(*inputs, past_key_value, output_attentions, mode=mode)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need` by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ To enable cross-attention, initialize the model with :obj:`add_cross_attention` set to :obj:`True` in the
+ configuration; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ dim=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
+ device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, list):
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if isinstance(encoder_attention_mask, list):
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
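Note on `get_extended_attention_mask` in the file above: it converts a keep-mask (1 = attend, 0 = ignore) into an additive mask that is 0.0 for kept positions and -10000.0 for masked ones, and that result is added to the raw attention scores before the softmax. The following is a minimal, dependency-free sketch of that idea, using plain Python lists instead of tensors. The helper names `to_additive_mask`, `causal_mask`, and `softmax` are illustrative assumptions and are not part of this patch.

```python
import math

MASK_VALUE = -10000.0  # same constant the patch adds to masked positions


def to_additive_mask(keep_mask):
    """Map a keep-mask (1.0 = attend, 0.0 = ignore) to 0.0 / -10000.0."""
    return [(1.0 - keep) * MASK_VALUE for keep in keep_mask]


def causal_mask(seq_length):
    """Row i may attend to column j only when j <= i."""
    return [[1.0 if j <= i else 0.0 for j in range(seq_length)]
            for i in range(seq_length)]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


raw_scores = [2.0, 1.0, 3.0]    # one query scored against three keys
padding_mask = [1.0, 1.0, 0.0]  # the third key is padding
additive = to_additive_mask(padding_mask)
probs = softmax([s + a for s, a in zip(raw_scores, additive)])
```

In this sketch the padded key had the highest raw score, yet after the additive mask its probability is effectively zero, which is why the code can add the mask to the scores instead of removing masked positions outright.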
diff --git a/data/EveryDream/scripts/BLIP/models/vit.py b/data/EveryDream/scripts/BLIP/models/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..cec3d8e08ed4451d65392feb2e9f4848d1ef3899
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/models/vit.py
@@ -0,0 +1,305 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on timm code base
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.models.vision_transformer import _cfg, PatchEmbed
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.helpers import named_apply, adapt_input_conv
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.attn_gradients = None
+ self.attention_map = None
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def forward(self, x, register_hook=False):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ if register_hook:
+ self.save_attention_map(attn)
+ attn.register_hook(self.save_attn_gradients)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if use_grad_checkpointing:
+ self.attn = checkpoint_wrapper(self.attn)
+ self.mlp = checkpoint_wrapper(self.mlp)
+
+ def forward(self, x, register_hook=False):
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
+ use_grad_checkpointing=False, ckpt_layer=0):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (float): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer: (nn.Module): normalization layer
+ """
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
+ )
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward(self, x, register_blk=-1):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + self.pos_embed[:,:x.size(1),:]
+ x = self.pos_drop(x)
+
+ for i,blk in enumerate(self.blocks):
+ x = blk(x, register_blk==i)
+ x = self.norm(x)
+
+ return x
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix=''):
+ _load_weights(self, checkpoint_path, prefix)
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+ """
+ import numpy as np
+
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ w = np.load(checkpoint_path)
+ if not prefix and 'opt/target/embedding/kernel' in w:
+ prefix = 'opt/target/'
+
+ if hasattr(model.patch_embed, 'backbone'):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, 'stem')
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+ for r in range(3):
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+ block.attn.qkv.weight.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+ block.attn.qkv.bias.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+ for r in range(2):
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+
+def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
+ # interpolate position embedding
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = visual_encoder.patch_embed.num_patches
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+
+ if orig_size!=new_size:
+ # class_token and dist_token are kept unchanged
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
+
+ return new_pos_embed
+ else:
+ return pos_embed_checkpoint
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/predict.py b/data/EveryDream/scripts/BLIP/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..35426cadcbb3bf8c3d8cb9c910511c154e451f4e
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/predict.py
@@ -0,0 +1,98 @@
+"""
+Download the weights in ./checkpoints beforehand for fast inference
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
+"""
+
+from pathlib import Path
+
+from PIL import Image
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import cog
+
+from models.blip import blip_decoder
+from models.blip_vqa import blip_vqa
+from models.blip_itm import blip_itm
+
+
+class Predictor(cog.Predictor):
+ def setup(self):
+ self.device = "cuda:0"
+
+ self.models = {
+ 'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth',
+ image_size=384, vit='base'),
+ 'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth',
+ image_size=480, vit='base'),
+ 'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
+ image_size=384, vit='base')
+ }
+
+ @cog.input(
+ "image",
+ type=Path,
+ help="input image",
+ )
+ @cog.input(
+ "task",
+ type=str,
+ default='image_captioning',
+ options=['image_captioning', 'visual_question_answering', 'image_text_matching'],
+ help="Choose a task.",
+ )
+ @cog.input(
+ "question",
+ type=str,
+ default=None,
+ help="Type question for the input image for visual question answering task.",
+ )
+ @cog.input(
+ "caption",
+ type=str,
+ default=None,
+ help="Type caption for the input image for image text matching task.",
+ )
+ def predict(self, image, task, question, caption):
+ if task == 'visual_question_answering':
+ assert question is not None, 'Please type a question for visual question answering task.'
+ if task == 'image_text_matching':
+ assert caption is not None, 'Please type a caption for image text matching task.'
+
+ im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device)
+ model = self.models[task]
+ model.eval()
+ model = model.to(self.device)
+
+ if task == 'image_captioning':
+ with torch.no_grad():
+ caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5)
+ return 'Caption: ' + caption[0]
+
+ if task == 'visual_question_answering':
+ with torch.no_grad():
+ answer = model(im, question, train=False, inference='generate')
+ return 'Answer: ' + answer[0]
+
+ # image_text_matching
+ itm_output = model(im, caption, match_head='itm')
+ itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
+ itc_score = model(im, caption, match_head='itc')
+ return f'The image and text are matched with a probability of {itm_score.item():.4f}.\n' \
+ f'The image feature and text feature have a cosine similarity of {itc_score.item():.4f}.'
+
+
+def load_image(image, image_size, device):
+ raw_image = Image.open(str(image)).convert('RGB')
+
+ w, h = raw_image.size
+
+ transform = transforms.Compose([
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+ image = transform(raw_image).unsqueeze(0).to(device)
+ return image
diff --git a/data/EveryDream/scripts/BLIP/pretrain.py b/data/EveryDream/scripts/BLIP/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9490ec8eb8ff5f074b5772ada55cd27ec673a12
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/pretrain.py
@@ -0,0 +1,173 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_pretrain import blip_pretrain
+import utils
+from utils import warmup_lr_schedule, step_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+
+def train(model, data_loader, optimizer, epoch, device, config):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+ metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+
+ header = 'Train Epoch: [{}]'.format(epoch)
+ print_freq = 50
+
+ if config['laion_path']:
+ data_loader.dataset.reload_laion(epoch)
+
+ data_loader.sampler.set_epoch(epoch)
+
+ for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+ if epoch==0:
+ warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
+
+ optimizer.zero_grad()
+
+ image = image.to(device,non_blocking=True)
+
+ # ramp up alpha in the first 2 epochs
+ alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader)))
+
+ loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha)
+ loss = loss_ita + loss_itm + loss_lm
+
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(loss_ita=loss_ita.item())
+ metric_logger.update(loss_itm=loss_itm.item())
+ metric_logger.update(loss_lm=loss_lm.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating dataset")
+ datasets = [create_dataset('pretrain', config, min_scale=0.2)]
+ print('number of training samples: %d'%len(datasets[0]))
+
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler(datasets, [True], num_tasks, global_rank)
+
+ data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
+
+ #### Model ####
+ print("Creating model")
+ model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
+ vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'])
+
+ model = model.to(device)
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ start_epoch = 0
+ if args.checkpoint:
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ state_dict = checkpoint['model']
+ model.load_state_dict(state_dict)
+
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ start_epoch = checkpoint['epoch']+1
+ print('resume checkpoint from %s'%args.checkpoint)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ print("Start training")
+ start_time = time.time()
+ for epoch in range(start_epoch, config['max_epoch']):
+
+ step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate'])
+
+ train_stats = train(model, data_loader, optimizer, epoch, device, config)
+ if utils.is_main_process():
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ }
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
+
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ if args.distributed: dist.barrier()  # barrier requires an initialized process group
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/pretrain.yaml')
+ parser.add_argument('--output_dir', default='output/Pretrain')
+ parser.add_argument('--checkpoint', default='')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'))  # type=bool would parse "False" as True
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/requirements.txt b/data/EveryDream/scripts/BLIP/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d897bc6a08712f4beb2f78ca2592dcbe06a3e2db
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/requirements.txt
@@ -0,0 +1,4 @@
+timm==0.4.12
+transformers==4.15.0
+fairscale==0.4.4
+pycocoevalcap
diff --git a/data/EveryDream/scripts/BLIP/train_caption.py b/data/EveryDream/scripts/BLIP/train_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c639ac646b9a1b8074b6e9c2343b961de76db05
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/train_caption.py
@@ -0,0 +1,206 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip import blip_decoder
+import utils
+from utils import cosine_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+from data.utils import save_result, coco_caption_eval
+
+def train(model, data_loader, optimizer, epoch, device):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+ header = 'Train Caption Epoch: [{}]'.format(epoch)
+ print_freq = 50
+
+ for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ image = image.to(device)
+
+ loss = model(image, caption)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(loss=loss.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, config):
+ # evaluate
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Caption generation:'
+ print_freq = 10
+
+ result = []
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
+
+ image = image.to(device)
+
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
+ min_length=config['min_length'])
+
+ for caption, img_id in zip(captions, image_id):
+ result.append({"image_id": img_id.item(), "caption": caption})
+
+ return result
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating captioning dataset")
+ train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)
+ else:
+ samplers = [None, None, None]
+
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
+ batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
+ is_trains=[True, False, False], collate_fns=[None,None,None])
+
+ #### Model ####
+ print("Creating model")
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
+ prompt=config['prompt'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ best = 0
+ best_epoch = 0
+
+ print("Start training")
+ start_time = time.time()
+ for epoch in range(0, config['max_epoch']):
+ if not args.evaluate:
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+ train_stats = train(model, train_loader, optimizer, epoch, device)
+
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
+ val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')
+
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
+ test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')
+
+ if utils.is_main_process():
+ coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
+ coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
+
+ if args.evaluate:
+ log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
+ }
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ else:
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+
+ if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
+ best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
+ best_epoch = epoch
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'val_{k}': v for k, v in coco_val.eval.items()},
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
+ 'epoch': epoch,
+ 'best_epoch': best_epoch,
+ }
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ if args.evaluate:
+ break
+ if args.distributed: dist.barrier()  # barrier requires an initialized process group
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/caption_coco.yaml')
+ parser.add_argument('--output_dir', default='output/Caption_coco')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'))  # type=bool would parse "False" as True
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ args.result_dir = os.path.join(args.output_dir, 'result')
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/train_nlvr.py b/data/EveryDream/scripts/BLIP/train_nlvr.py
new file mode 100644
index 0000000000000000000000000000000000000000..84b247bda2334c1fd894b6c11d33ef48c8e7df28
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/train_nlvr.py
@@ -0,0 +1,213 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+import json
+import pickle
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from models.blip_nlvr import blip_nlvr
+
+import utils
+from utils import cosine_lr_schedule, warmup_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+
+def train(model, data_loader, optimizer, epoch, device, config):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+
+ header = 'Train Epoch: [{}]'.format(epoch)
+ print_freq = 50
+ step_size = 10
+
+ for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+ images = torch.cat([image0, image1], dim=0)
+ images, targets = images.to(device), targets.to(device)
+
+ loss = model(images, text, targets=targets, train=True)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(loss=loss.item())
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, config):
+ # test
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+
+ header = 'Evaluation:'
+ print_freq = 50
+
+ for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
+ images = torch.cat([image0, image1], dim=0)
+ images, targets = images.to(device), targets.to(device)
+
+ prediction = model(images, text, targets=targets, train=False)
+
+ _, pred_class = prediction.max(1)
+ accuracy = (targets==pred_class).sum() / targets.size(0)
+
+ metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating dataset")
+ datasets = create_dataset('nlvr', config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank)
+ else:
+ samplers = [None, None, None]
+
+ batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']]
+ train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size,
+ num_workers=[4,4,4],is_trains=[True,False,False],
+ collate_fns=[None,None,None])
+
+ #### Model ####
+ print("Creating model")
+ model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
+ vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ print("Start training")
+ start_time = time.time()
+ best = 0
+ best_epoch = 0
+
+ for epoch in range(0, config['max_epoch']):
+ if not args.evaluate:
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+ train_stats = train(model, train_loader, optimizer, epoch, device, config)
+
+ val_stats = evaluate(model, val_loader, device, config)
+ test_stats = evaluate(model, test_loader, device, config)
+
+ if utils.is_main_process():
+ if args.evaluate:
+ log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ }
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'val_{k}': v for k, v in val_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch,
+ }
+
+ if float(val_stats['acc'])>best:
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+ best = float(val_stats['acc'])
+ best_epoch = epoch
+
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ if args.evaluate:
+ break
+
+ if args.distributed: dist.barrier()  # barrier requires an initialized process group
+
+ if utils.is_main_process():
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write("best epoch: %d\n"%best_epoch)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/nlvr.yaml')
+ parser.add_argument('--output_dir', default='output/NLVR')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'))  # type=bool would parse "False" as True
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/train_retrieval.py b/data/EveryDream/scripts/BLIP/train_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..574f03382cc8197b97971a11ae54b632bcfe6655
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/train_retrieval.py
@@ -0,0 +1,345 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_retrieval import blip_retrieval
+import utils
+from utils import cosine_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+
+
+def train(model, data_loader, optimizer, epoch, device, config):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+ header = 'Train Epoch: [{}]'.format(epoch)
+ print_freq = 50
+
+ for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ image = image.to(device,non_blocking=True)
+ idx = idx.to(device,non_blocking=True)
+
+ if epoch>0:
+ alpha = config['alpha']
+ else:
+ alpha = config['alpha']*min(1,i/len(data_loader))
+
+ loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
+ loss = loss_ita + loss_itm
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(loss_itm=loss_itm.item())
+ metric_logger.update(loss_ita=loss_ita.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, device, config):
+ # test
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Evaluation:'
+
+ print('Computing features for evaluation...')
+ start_time = time.time()
+
+ texts = data_loader.dataset.text
+ num_text = len(texts)
+ text_bs = 256
+ text_ids = []
+ text_embeds = []
+ text_atts = []
+ for i in range(0, num_text, text_bs):
+ text = texts[i: min(num_text, i+text_bs)]
+ text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
+ text_embeds.append(text_embed)
+ text_ids.append(text_input.input_ids)
+ text_atts.append(text_input.attention_mask)
+
+ text_embeds = torch.cat(text_embeds,dim=0)
+ text_ids = torch.cat(text_ids,dim=0)
+ text_atts = torch.cat(text_atts,dim=0)
+ text_ids[:,0] = model.tokenizer.enc_token_id
+
+ image_feats = []
+ image_embeds = []
+ for image, img_id in data_loader:
+ image = image.to(device)
+ image_feat = model.visual_encoder(image)
+ image_embed = model.vision_proj(image_feat[:,0,:])
+ image_embed = F.normalize(image_embed,dim=-1)
+
+ image_feats.append(image_feat.cpu())
+ image_embeds.append(image_embed)
+
+ image_feats = torch.cat(image_feats,dim=0)
+ image_embeds = torch.cat(image_embeds,dim=0)
+
+ sims_matrix = image_embeds @ text_embeds.t()
+ score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)
+
+ num_tasks = utils.get_world_size()
+ rank = utils.get_rank()
+ step = sims_matrix.size(0)//num_tasks + 1
+ start = rank*step
+ end = min(sims_matrix.size(0),start+step)
+
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+
+ encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device)
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
+ output = model.text_encoder(text_ids[topk_idx],
+ attention_mask = text_atts[topk_idx],
+ encoder_hidden_states = encoder_output,
+ encoder_attention_mask = encoder_att,
+ return_dict = True,
+ )
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+ score_matrix_i2t[start+i,topk_idx] = score + topk_sim
+
+ sims_matrix = sims_matrix.t()
+ score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device)
+
+ step = sims_matrix.size(0)//num_tasks + 1
+ start = rank*step
+ end = min(sims_matrix.size(0),start+step)
+
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+ encoder_output = image_feats[topk_idx].to(device)
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
+ encoder_hidden_states = encoder_output,
+ encoder_attention_mask = encoder_att,
+ return_dict = True,
+ )
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+ score_matrix_t2i[start+i,topk_idx] = score + topk_sim
+
+ if utils.is_dist_avail_and_initialized():  # avoid relying on the module-level `args`, which is not a parameter here
+ dist.barrier()
+ torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
+ torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Evaluation time {}'.format(total_time_str))
+
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
+
+
+
+@torch.no_grad()
+def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
+
+ #Images->Text
+ ranks = np.zeros(scores_i2t.shape[0])
+ for index,score in enumerate(scores_i2t):
+ inds = np.argsort(score)[::-1]
+ # Score
+ rank = 1e20
+ for i in img2txt[index]:
+ tmp = np.where(inds == i)[0][0]
+ if tmp < rank:
+ rank = tmp
+ ranks[index] = rank
+
+ # Compute metrics
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ #Text->Images
+ ranks = np.zeros(scores_t2i.shape[0])
+
+ for index,score in enumerate(scores_t2i):
+ inds = np.argsort(score)[::-1]
+ ranks[index] = np.where(inds == txt2img[index])[0][0]
+
+ # Compute metrics
+ ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ tr_mean = (tr1 + tr5 + tr10) / 3
+ ir_mean = (ir1 + ir5 + ir10) / 3
+ r_mean = (tr_mean + ir_mean) / 2
+
+ eval_result = {'txt_r1': tr1,
+ 'txt_r5': tr5,
+ 'txt_r10': tr10,
+ 'txt_r_mean': tr_mean,
+ 'img_r1': ir1,
+ 'img_r5': ir5,
+ 'img_r10': ir10,
+ 'img_r_mean': ir_mean,
+ 'r_mean': r_mean}
+ return eval_result
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating retrieval dataset")
+ train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
+ else:
+ samplers = [None, None, None]
+
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
+ batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
+ num_workers=[4,4,4],
+ is_trains=[True, False, False],
+ collate_fns=[None,None,None])
+
+
+ #### Model ####
+ print("Creating model")
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
+ queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ best = 0
+ best_epoch = 0
+
+ print("Start training")
+ start_time = time.time()
+
+ for epoch in range(0, config['max_epoch']):
+ if not args.evaluate:
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+ train_stats = train(model, train_loader, optimizer, epoch, device, config)
+
+ score_val_i2t, score_val_t2i = evaluation(model_without_ddp, val_loader, device, config)
+ score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
+
+ if utils.is_main_process():
+
+ val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
+ print(val_result)
+
+ if val_result['r_mean']>best:
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+ best = val_result['r_mean']
+ best_epoch = epoch
+
+ test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
+ print(test_result)
+
+ if args.evaluate:
+ log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
+ **{f'test_{k}': v for k, v in test_result.items()},
+ }
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'val_{k}': v for k, v in val_result.items()},
+ **{f'test_{k}': v for k, v in test_result.items()},
+ 'epoch': epoch,
+ 'best_epoch': best_epoch,
+ }
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ if args.evaluate:
+ break
+
+ if utils.is_dist_avail_and_initialized(): dist.barrier()  # barrier raises when no process group is initialized
+ torch.cuda.empty_cache()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
+ parser.add_argument('--output_dir', default='output/Retrieval_flickr')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
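+ parser.add_argument('--distributed', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'))  # type=bool would parse any non-empty string, including 'False', as True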
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/train_vqa.py b/data/EveryDream/scripts/BLIP/train_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..89eb7490862e517cc660f842396033c21d441a20
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/train_vqa.py
@@ -0,0 +1,202 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from models.blip_vqa import blip_vqa
+import utils
+from utils import cosine_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+from data.vqa_dataset import vqa_collate_fn
+from data.utils import save_result
+
+
+def train(model, data_loader, optimizer, epoch, device):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+
+ header = 'Train Epoch: [{}]'.format(epoch)
+ print_freq = 50
+
+ for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True)
+
+ loss = model(image, question, answer, train=True, n=n, weights=weights)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(loss=loss.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, device, config) :
+ # test
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Generate VQA test result:'
+ print_freq = 50
+
+ result = []
+
+ if config['inference']=='rank':
+ answer_list = data_loader.dataset.answer_list
+ answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
+ answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id
+
+ for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ image = image.to(device,non_blocking=True)
+
+ if config['inference']=='generate':
+ answers = model(image, question, train=False, inference='generate')
+
+ for answer, ques_id in zip(answers, question_id):
+ ques_id = int(ques_id.item())
+ result.append({"question_id":ques_id, "answer":answer})
+
+ elif config['inference']=='rank':
+ answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])
+
+ for ques_id, answer_id in zip(question_id, answer_ids):
+ result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]})
+
+ return result
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating vqa datasets")
+ datasets = create_dataset('vqa', config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
+ else:
+ samplers = [None, None]
+
+ train_loader, test_loader = create_loader(datasets,samplers,
+ batch_size=[config['batch_size_train'],config['batch_size_test']],
+ num_workers=[4,4],is_trains=[True, False],
+ collate_fns=[vqa_collate_fn,None])
+ #### Model ####
+ print("Creating model")
+ model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
+ vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ best = 0
+ best_epoch = 0
+
+ print("Start training")
+ start_time = time.time()
+ for epoch in range(0, config['max_epoch']):
+ if not args.evaluate:
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+ train_stats = train(model, train_loader, optimizer, epoch, device)
+
+ else:
+ break
+
+ if utils.is_main_process():
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ }
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
+
+ if utils.is_dist_avail_and_initialized(): dist.barrier()  # barrier raises when no process group is initialized
+
+ vqa_result = evaluation(model_without_ddp, test_loader, device, config)
+ result_file = save_result(vqa_result, args.result_dir, 'vqa_result')
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/vqa.yaml')
+ parser.add_argument('--output_dir', default='output/VQA')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
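+ parser.add_argument('--distributed', default=True, type=lambda s: s.lower() in ('true', '1', 'yes'))  # type=bool would parse any non-empty string, including 'False', as True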
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ args.result_dir = os.path.join(args.output_dir, 'result')
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/transform/randaugment.py b/data/EveryDream/scripts/BLIP/transform/randaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..094d9f4cacc93146d2bab7311d9dc04feb07032c
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/transform/randaugment.py
@@ -0,0 +1,340 @@
+import cv2
+import numpy as np
+
+
+## aug functions
+def identity_func(img):
+ return img
+
+
+def autocontrast_func(img, cutoff=0):
+ '''
+ same output as PIL.ImageOps.autocontrast
+ '''
+ n_bins = 256
+
+ def tune_channel(ch):
+ n = ch.size
+ cut = cutoff * n // 100
+ if cut == 0:
+ high, low = ch.max(), ch.min()
+ else:
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ low = np.argwhere(np.cumsum(hist) > cut)
+ low = 0 if low.shape[0] == 0 else low[0]
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+ if high <= low:
+ table = np.arange(n_bins)
+ else:
+ scale = (n_bins - 1) / (high - low)
+ offset = -low * scale
+ table = np.arange(n_bins) * scale + offset
+ table[table < 0] = 0
+ table[table > n_bins - 1] = n_bins - 1
+ table = table.clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def equalize_func(img):
+ '''
+ same output as PIL.ImageOps.equalize
+ PIL's implementation is different from cv2.equalizeHist
+ '''
+ n_bins = 256
+
+ def tune_channel(ch):
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ non_zero_hist = hist[hist != 0].reshape(-1)
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+ if step == 0: return ch
+ n = np.empty_like(hist)
+ n[0] = step // 2
+ n[1:] = hist[:-1]
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def rotate_func(img, degree, fill=(0, 0, 0)):
+ '''
+ like PIL, rotate by degree, not radians
+ '''
+ H, W = img.shape[0], img.shape[1]
+ center = W / 2, H / 2
+ M = cv2.getRotationMatrix2D(center, degree, 1)
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+ return out
+
+
+def solarize_func(img, thresh=128):
+ '''
+ same output as PIL.ImageOps.solarize
+ '''
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
+ table = table.clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def color_func(img, factor):
+ '''
+ same output as PIL.ImageEnhance.Color
+ '''
+ ## implementation according to PIL definition, quite slow
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+ # out = blend(degenerate, img, factor)
+ # M = (
+ # np.eye(3) * factor
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+ # )[np.newaxis, np.newaxis, :]
+ M = (
+ np.float32([
+ [0.886, -0.114, -0.114],
+ [-0.587, 0.413, -0.587],
+ [-0.299, -0.299, 0.701]]) * factor
+ + np.float32([[0.114], [0.587], [0.299]])
+ )
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+ return out
+
+
+def contrast_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Contrast
+ """
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+ table = np.array([(
+ el - mean) * factor + mean
+ for el in range(256)
+ ]).clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def brightness_func(img, factor):
+ '''
+ same output as PIL.ImageEnhance.Brightness
+ '''
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def sharpness_func(img, factor):
+ '''
+ The differences between this result and PIL's are all on the 4 boundaries; the center
+ areas are the same
+ '''
+ kernel = np.ones((3, 3), dtype=np.float32)
+ kernel[1][1] = 5
+ kernel /= 13
+ degenerate = cv2.filter2D(img, -1, kernel)
+ if factor == 0.0:
+ out = degenerate
+ elif factor == 1.0:
+ out = img
+ else:
+ out = img.astype(np.float32)
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+ out = out.astype(np.uint8)
+ return out
+
+
+def shear_x_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+ return out
+
+
+def translate_x_func(img, offset, fill=(0, 0, 0)):
+ '''
+ same output as PIL.Image.transform
+ '''
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+ return out
+
+
+def translate_y_func(img, offset, fill=(0, 0, 0)):
+ '''
+ same output as PIL.Image.transform
+ '''
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+ return out
+
+
+def posterize_func(img, bits):
+ '''
+ same output as PIL.ImageOps.posterize
+ '''
+ out = np.bitwise_and(img, np.uint8((255 << (8 - bits)) & 0xFF))  # mask to 8 bits so the shifted value fits in uint8
+ return out
+
+
+def shear_y_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+ return out
+
+
+def cutout_func(img, pad_size, replace=(0, 0, 0)):
+ replace = np.array(replace, dtype=np.uint8)
+ H, W = img.shape[0], img.shape[1]
+ rh, rw = np.random.random(2)
+ pad_size = pad_size // 2
+ ch, cw = int(rh * H), int(rw * W)
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+ out = img.copy()
+ out[x1:x2, y1:y2, :] = replace
+ return out
+
+
+### level to args
+def enhance_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+ return level_to_args
+
+
+def shear_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 0.3
+ if np.random.random() > 0.5: level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * float(translate_const)
+ if np.random.random() > 0.5: level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * cutout_const)
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def solarize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 256)
+ return (level, )
+ return level_to_args
+
+
+def none_level_to_args(level):
+ return ()
+
+
+def posterize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 4)
+ return (level, )
+ return level_to_args
+
+
+def rotate_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 30
+ if np.random.random() < 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+func_dict = {
+ 'Identity': identity_func,
+ 'AutoContrast': autocontrast_func,
+ 'Equalize': equalize_func,
+ 'Rotate': rotate_func,
+ 'Solarize': solarize_func,
+ 'Color': color_func,
+ 'Contrast': contrast_func,
+ 'Brightness': brightness_func,
+ 'Sharpness': sharpness_func,
+ 'ShearX': shear_x_func,
+ 'TranslateX': translate_x_func,
+ 'TranslateY': translate_y_func,
+ 'Posterize': posterize_func,
+ 'ShearY': shear_y_func,
+}
+
+translate_const = 10
+MAX_LEVEL = 10
+replace_value = (128, 128, 128)
+arg_dict = {
+ 'Identity': none_level_to_args,
+ 'AutoContrast': none_level_to_args,
+ 'Equalize': none_level_to_args,
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
+ 'Color': enhance_level_to_args(MAX_LEVEL),
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
+ 'TranslateX': translate_level_to_args(
+ translate_const, MAX_LEVEL, replace_value
+ ),
+ 'TranslateY': translate_level_to_args(
+ translate_const, MAX_LEVEL, replace_value
+ ),
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
+}
+
+
+class RandomAugment(object):
+
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+ self.N = N
+ self.M = M
+ self.isPIL = isPIL
+ if augs:
+ self.augs = augs
+ else:
+ self.augs = list(arg_dict.keys())
+
+ def get_random_ops(self):
+ sampled_ops = np.random.choice(self.augs, self.N)
+ return [(op, 0.5, self.M) for op in sampled_ops]
+
+ def __call__(self, img):
+ if self.isPIL:
+ img = np.array(img)
+ ops = self.get_random_ops()
+ for name, prob, level in ops:
+ if np.random.random() > prob:
+ continue
+ args = arg_dict[name](level)
+ img = func_dict[name](img, *args)
+ return img
+
+
+if __name__ == '__main__':
+ a = RandomAugment()
+ img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # lookup-table ops index with pixel values, so the image must be uint8
+ a(img)
\ No newline at end of file
diff --git a/data/EveryDream/scripts/BLIP/utils.py b/data/EveryDream/scripts/BLIP/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe0e1dc2f5d200156d5dd1acc305a8b7b7b98da
--- /dev/null
+++ b/data/EveryDream/scripts/BLIP/utils.py
@@ -0,0 +1,278 @@
+import math
+def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+ """Decay the learning rate"""
+ lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+ """Warmup the learning rate"""
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+ """Decay the learning rate"""
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+import numpy as np
+import io
+import os
+import time
+from collections import defaultdict, deque
+import datetime
+
+import torch
+import torch.distributed as dist
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def global_avg(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {:.4f}".format(name, meter.global_avg)
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def compute_acc(logits, label, reduction='mean'):
+ ret = (torch.argmax(logits, dim=1) == label).float()
+ if reduction == 'none':
+ return ret.detach()
+ elif reduction == 'mean':
+ return ret.mean().item()
+
+def compute_n_params(model, return_str=True):
+ tot = 0
+ for p in model.parameters():
+ w = 1
+ for x in p.shape:
+ w *= x
+ tot += w
+ if return_str:
+ if tot >= 1e6:
+ return '{:.1f}M'.format(tot / 1e6)
+ else:
+ return '{:.1f}K'.format(tot / 1e3)
+ else:
+ return tot
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}, world {}): {}'.format(
+ args.rank, args.world_size, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
\ No newline at end of file
diff --git a/data/EveryDream/scripts/auto_caption.py b/data/EveryDream/scripts/auto_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c2f8eb6c84a290f20b76600995c5788ec6cefd
--- /dev/null
+++ b/data/EveryDream/scripts/auto_caption.py
@@ -0,0 +1,217 @@
+import argparse
+import glob
+import os
+from PIL import Image
+import sys
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import torch
+import aiohttp
+import asyncio
+import subprocess
+import numpy as np
+import io
+import aiofiles
+
+SIZE = 384
+BLIP_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
+
+def get_parser(**parser_kwargs):
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ parser.add_argument(
+ "--img_dir",
+ type=str,
+ nargs="?",
+ const=True,
+ default="input",
+ help="directory with images to be captioned",
+ ),
+ parser.add_argument(
+ "--out_dir",
+ type=str,
+ nargs="?",
+ const=True,
+ default="output",
+ help="directory to put captioned images",
+ ),
+ parser.add_argument(
+ "--format",
+ type=str,
+ nargs="?",
+ const=True,
+ default="filename",
+ help="'filename', 'mrwho', 'txt', or 'caption'",
+ ),
+ parser.add_argument(
+ "--nucleus",
+ type=bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="use nucleus sampling instead of beam",
+ ),
+ parser.add_argument(
+ "--q_factor",
+ type=float,
+ nargs="?",
+ const=True,
+ default=1.0,
+ help="adjusts the likelihood of a word being repeated",
+ ),
+ parser.add_argument(
+ "--min_length",
+ type=int,
+ nargs="?",
+ const=True,
+ default=22,
+ help="minimum length of the generated caption, in tokens",
+ ),
+ parser.add_argument(
+ "--torch_device",
+ type=str,
+ nargs="?",
+ const=False,
+ default="cuda",
+ help="specify a different torch device, e.g. 'cpu'",
+ ),
+
+ return parser
+
+def load_image(raw_image, device):
+ transform = transforms.Compose([
+ #transforms.CenterCrop(SIZE),
+ transforms.Resize((SIZE, SIZE), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+ ])
+ image = transform(raw_image).unsqueeze(0).to(device)
+ return image
+
+def get_out_file_name(out_dir, base_name, ext):
+ return os.path.join(out_dir, f"{base_name}{ext}")
+
+async def main(opt):
+ print("starting")
+ import models.blip
+
+ sample = False
+ if opt.nucleus:
+ sample = True
+
+ input_dir = opt.img_dir
+ print("input_dir: ", input_dir)
+
+ config_path = "scripts/BLIP/configs/med_config.json"
+
+ cache_folder = ".cache"
+ model_cache_path = ".cache/model_base_caption_capfilt_large.pth"
+
+ if not os.path.exists(cache_folder):
+ os.makedirs(cache_folder)
+
+ if not os.path.exists(opt.out_dir):
+ os.makedirs(opt.out_dir)
+
+ if not os.path.exists(model_cache_path):
+ print(f"Downloading model to {model_cache_path}... please wait")
+
+ async with aiohttp.ClientSession() as session:
+ async with session.get(BLIP_MODEL_URL) as res:
+ with open(model_cache_path, 'wb') as f:
+ async for chunk in res.content.iter_chunked(1024):
+ f.write(chunk)
+ print(f"Model cached to: {model_cache_path}")
+ else:
+ print(f"Model already cached to: {model_cache_path}")
+
+ blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
+ blip_decoder.eval()
+
+ print(f"loading model to {opt.torch_device}")
+
+ blip_decoder = blip_decoder.to(torch.device(opt.torch_device))
+
+ ext = ('.jpg', '.jpeg', '.png', '.webp', '.tif', '.tga', '.tiff', '.bmp', '.gif')
+
+ i = 0
+
+ for idx, img_file_name in enumerate(glob.iglob(os.path.join(opt.img_dir, "*.*"))):
+ if img_file_name.endswith(ext):
+ caption = None
+ file_ext = os.path.splitext(img_file_name)[1]
+ if (file_ext in ext):
+ async with aiofiles.open(img_file_name, "rb") as input_file:
+ print("working image: ", img_file_name)
+
+ image_bin = await input_file.read()
+ image = Image.open(io.BytesIO(image_bin))
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ image = load_image(image, device=torch.device(opt.torch_device))
+
+ if opt.nucleus:
+ captions = blip_decoder.generate(image, sample=True, top_p=opt.q_factor)
+ else:
+ captions = blip_decoder.generate(image, sample=sample, num_beams=16, min_length=opt.min_length, \
+ max_length=48, repetition_penalty=opt.q_factor)
+
+ caption = captions[0]
+
+ if opt.format in ["mrwho","joepenna"]:
+ prefix = f"{i:05}@"
+ i += 1
+ caption = prefix+caption
+ elif opt.format == "filename":
+ postfix = f"_{i}"
+ i += 1
+ caption = caption+postfix
+
+ if opt.format in ["txt","text","caption"]:
+ out_base_name = os.path.splitext(os.path.basename(img_file_name))[0]
+
+ if opt.format in ["txt","text"]:
+ out_file = get_out_file_name(opt.out_dir, out_base_name, ".txt")
+
+ if opt.format in ["caption"]:
+ out_file = get_out_file_name(opt.out_dir, out_base_name, ".caption")
+
+ if opt.format in ["txt","text","caption"]:
+ print("writing caption to: ", out_file)
+ async with aiofiles.open(out_file, "w") as out_file:
+ await out_file.write(caption)
+
+ if opt.format in ["filename", "mrwho", "joepenna"]:
+ caption = caption.replace("/", "").replace("\\", "") # must clean slashes using filename
+ out_file = get_out_file_name(opt.out_dir, caption, file_ext)
+ async with aiofiles.open(out_file, "wb") as out_file:
+ await out_file.write(image_bin)
+ elif opt.format == "json":
+ raise NotImplementedError
+ elif opt.format == "parquet":
+ raise NotImplementedError
+
+def isWindows():
+ return sys.platform.startswith("win")
+
+if __name__ == "__main__":
+ parser = get_parser()
+ opt = parser.parse_args()
+
+ if opt.format not in ["filename", "mrwho", "joepenna", "txt", "text", "caption"]:
+ raise ValueError("format must be 'filename', 'mrwho', 'txt', or 'caption'")
+
+ if (isWindows()):
+ print("Windows detected, using asyncio.WindowsSelectorEventLoopPolicy")
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+ else:
+ print("Unix detected, using default asyncio event loop policy")
+
+ if not os.path.exists("scripts/BLIP"):
+ print("BLIP not found, cloning BLIP repo")
+ subprocess.run(["git", "clone", "https://github.com/salesforce/BLIP", "scripts/BLIP"])
+ blip_path = "scripts/BLIP"
+ sys.path.append(blip_path)
+
+ asyncio.run(main(opt))
diff --git a/data/EveryDream/scripts/compress_img.py b/data/EveryDream/scripts/compress_img.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1a264134a90daf9f83772f0b49f215a44e3afae
--- /dev/null
+++ b/data/EveryDream/scripts/compress_img.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+
+"""Compress images in a folder to a maximum megapixel size."""
+
+import argparse
+import asyncio
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from glob import iglob
+from multiprocessing import cpu_count
+from queue import Queue
+
+from PIL import Image, ImageFile, ImageOps
+
+# Prevent errors from halting the script.
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+Image.warnings.simplefilter("error", Image.DecompressionBombWarning)
+
+VERSION = "2.0"
+SHORT_DESCRIPTION = "Compress images in a directory."
+SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp"]
+
+
+def get_args(**parser_kwargs):
+ """Get command-line options."""
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ parser.add_argument(
+ "--img_dir",
+ type=str,
+ default="input",
+ help="path to image directory (default: 'input')",
+ )
+ parser.add_argument(
+ "--out_dir",
+ type=str,
+ default=None,
+ help="path to output directory (default: IMG_DIR)",
+ )
+ parser.add_argument(
+ "--max_mp",
+ type=float,
+ default=1.5,
+ help="maximum megapixels (default: 1.5)",
+ )
+ parser.add_argument(
+ "--quality",
+ type=int,
+ default=95,
+ help="save quality (default: 95, range: 0-100, suggested: 90+)",
+ )
+ parser.add_argument(
+ "--overwrite",
+ action="store_true",
+ default=False,
+ help="overwrite files in output directory",
+ )
+ parser.add_argument(
+ "--noresize",
+ action="store_true",
+ default=False,
+ help="do not resize, just fix orientation",
+ )
+ parser.add_argument(
+ "--delete",
+ action="store_true",
+ default=False,
+ help="delete original files after processing",
+ )
+ args = parser.parse_args()
+ args.out_dir = args.out_dir or args.img_dir
+ args.max_mp = args.max_mp * 1_000_000  # 1 megapixel = 1,000,000 pixels
+ return args
+
+
+def images(img_dir):
+ """Return each image in the input directory."""
+ for file in iglob(f"{img_dir}/*.*"):
+ if file.lower().endswith(tuple(SUPPORTED_EXTENSIONS)):
+ yield file
+
+
+def inline(msg, newline=False):
+ """Print a message on the same line."""
+ msg = f"\r{msg}"
+ msg += " " * (79 - len(msg))
+ print(msg, end="\n" if newline else "", flush=True)
+
+
+def launch_workers(queue, args):
+ """Launch a pool of workers."""
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ tasks = [loop.create_task(worker(queue, args)) for _ in range(10)]
+ loop.run_until_complete(asyncio.wait(tasks))
+
+
+async def open_img(path):
+ """Open an image."""
+ loop = asyncio.get_running_loop()
+ try:
+ return await loop.run_in_executor(None, Image.open, path)
+ except Exception as err:
+ inline(f"[!] Error Opening: {path} - {err}", True)
+ return None
+
+
+def oversize(img, max_mp):
+ """Check if an image is larger than the maximum size."""
+ return (img.width * img.height) > max_mp
+
+
+async def process(image, args):
+ """Process an image."""
+ outfile = image.replace(args.img_dir, args.out_dir).replace(
+ os.path.splitext(image)[1], ".webp"
+ )
+ if args.overwrite or not os.path.exists(outfile):
+ img = await open_img(image)
+ if img:
+ newimg = transpose(img)
+ if not args.noresize and oversize(newimg, args.max_mp):
+ newimg = shrink(newimg, args)
+ if newimg != img:
+ await save_img(newimg, outfile, args)
+ if args.delete and outfile != image:
+ os.remove(image)
+
+
+def slow_save(path, args, img):
+ """Save an image."""
+ try:
+ img.save(path, "webp", quality=args.quality)
+ inline(f"[+] Compressed: {path}")
+ except Exception as err:
+ inline(f"[!] Error Saving: {path} - {err}", True)
+
+
+async def save_img(img, path, args):
+ """Save an image."""
+ loop = asyncio.get_running_loop()
+ await loop.run_in_executor(None, slow_save, path, args, img)
+
+
+def scan_path(queue, args):
+ """Scan the input directory for images."""
+ inline("[*] Scanning for images...", True)
+ for image in images(args.img_dir):
+ inline(f"[+] {image}")
+ queue.put(image)
+
+
+def shrink(img, args):
+ """Shrink an image."""
+ hw = img.size
+ ratio = args.max_mp / (hw[0]*hw[1])
+ newhw = (int(hw[0]*ratio**0.5), int(hw[1]*ratio**0.5))
+
+ try:
+ return img.resize(newhw, Image.BICUBIC)
+ except Exception as err:
+ inline(f"[!] Error Shrinking: {img.filename} - {err}", True)
+ return img
+
+
+def start_compression(queue, args):
+ """Start the compression process."""
+ inline("[*] Compressing images...", True)
+ inline("[-] (scanning...)")
+ with ThreadPoolExecutor() as executor:
+ workers = {
+ executor.submit(launch_workers, queue, args): None
+ for _ in range(cpu_count())
+ }
+ for _ in as_completed(workers):
+ pass
+ inline("[!] Done!", True)
+
+
+def transpose(img):
+ """Transpose an image."""
+ try:
+ return ImageOps.exif_transpose(img)
+ except Exception as err:
+ inline(f"[!] Error Transposing: {img.filename} - {err}", True)
+ return img
+
+
+async def worker(queue, args):
+ """Handle images from the queue until they're gone."""
+ while not queue.empty():
+ image = queue.get()
+ await process(image, args)
+
+
+def main():
+ """Run the program."""
+ queue = Queue()
+ args = get_args(description=SHORT_DESCRIPTION)
+ inline(f"[>] Image Compression Utility v{VERSION}", True)
+ scan_path(queue, args)
+ start_compression(queue, args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/data/EveryDream/scripts/createtxtfromfilename.py b/data/EveryDream/scripts/createtxtfromfilename.py
new file mode 100644
index 0000000000000000000000000000000000000000..e420d9fb21f061e996d056b14f8aa42eaba612b0
--- /dev/null
+++ b/data/EveryDream/scripts/createtxtfromfilename.py
@@ -0,0 +1,27 @@
+import glob
+import os
+import argparse
+
+def create_txt_from_filename(path):
+ """
+ create a .txt file for each file in the path so you can lengthen the caption
+ """
+ print(f"Creating .txt files from filenames in {path}")
+ for idx, f in enumerate(glob.iglob(f"{path}/**", recursive=True)):
+ print(f"Checking {f}")
+ if not os.path.isfile(f) or os.path.splitext(f)[1].lower() not in ['.jpg', '.png', '.jpeg', '.webp', '.bmp']:
+ continue
+
+ path_without_filename = os.path.dirname(f)
+ base_name = os.path.splitext(os.path.basename(f))[0]
+ caption = os.path.splitext(base_name)[0].split("_")[0]
+ target = f"{path_without_filename}/{base_name}.txt"
+ print (f"Creating file: {target} from {f}")
+ with open(target, "w") as text_file:
+ text_file.write(caption)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--path", type=str, required=True, help="path to folder")
+ args = parser.parse_args()
+ create_txt_from_filename(args.path)
diff --git a/data/EveryDream/scripts/download_laion.py b/data/EveryDream/scripts/download_laion.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd3e04bc1714f280545c7a0e102b015cdefbcc35
--- /dev/null
+++ b/data/EveryDream/scripts/download_laion.py
@@ -0,0 +1,377 @@
+import sys
+import os
+import pandas as pd
+import pyarrow as pa
+import argparse
+import glob
+#import requests_async as requests
+import asyncio
+import aiohttp
+from typing import IO
+import aiofiles
+import re
+from colorama import Fore, Style
+from PIL import Image
+import io
+
+# can tweak these if you feel like it, but it shouldn't be needed
+unsafe_threshhold = 0.1 # higher values are more likely to be nsfw; rows scoring above this are skipped
+aesthetic_threshhold = 5 # higher is more aesthetic, note laion2B-aesthetic already >7?
+http_timeout = 10
+
+# don't touch
+downloaded_count = 0
+current_parquet_file_downloaded_count = 0
+logger_sp = None
+
+def get_base_prefix_compat():
+ """Get base/real prefix, or sys.prefix if there is none."""
+ return getattr(sys, "base_prefix", None) or getattr(sys, "real_prefix", None) or sys.prefix
+
+def in_virtualenv():
+ return get_base_prefix_compat() != sys.prefix
+
+def get_parser(**parser_kwargs):
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ do_not_download = True
+ parser.add_argument(
+ "--laion_dir",
+ type=str,
+ nargs="?",
+ const=True,
+ default="./laion",
+ help="directory with laion parquet files, default is ./laion",
+ )
+ parser.add_argument(
+ "--search_text",
+ type=str,
+ const=True,
+ nargs="?",
+ default=None,
+ help="csv of words with AND logic, ex \"photo,man,dog\"",
+ ),
+ parser.add_argument(
+ "--out_dir",
+ type=str,
+ nargs="?",
+ const=True,
+ default="./output",
+ help="directory to download files to, default is ./output",
+ ),
+ parser.add_argument(
+ "--log_dir",
+ type=str,
+ nargs="?",
+ const=True,
+ default=None,
+ help="directory for logs, if omitted will not log, logs may be large!",
+ ),
+ parser.add_argument(
+ "--column",
+ type=str,
+ nargs="?",
+ const=True,
+ default="TEXT",
+ help="column to search for matches, default is 'TEXT', but you could use 'URL' if you wanted",
+ ),
+ parser.add_argument(
+ "--limit",
+ type=int,
+ nargs="?",
+ const=True,
+ default=100,
+ help="max number of matching images to download, warning: may be slightly imprecise due to concurrency and http errors, default is 100",
+ ),
+ parser.add_argument(
+ "--min_hw",
+ type=int,
+ nargs="?",
+ const=True,
+ default=512,
+ help="min height AND width of image to download, default is 512",
+ ),
+ parser.add_argument(
+ "--force",
+ type=bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="forces a full download of all images, even if no search is provided, USE CAUTION!",
+ ),
+ parser.add_argument(
+ "--parquet_skip",
+ type=int,
+ nargs="?",
+ const=True,
+ default=0,
+ help="skips the first n parquet files on disk, useful to resume",
+ ),
+ parser.add_argument(
+ "--verbose",
+ type=bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="additional logging of URL and TEXT prefiltering",
+ ),
+ parser.add_argument(
+ "--test",
+ action='store_const',
+ const=do_not_download,
+ default=not(do_not_download),
+ help="skips downloading, for checking filters, use with --verbose",
+ )
+
+ return parser
+
+def cleanup_text(file_name: str):
+ # TODO: can be improved
+
+ file_name = re.sub("", "", file_name)
+ file_name = re.sub("", "", file_name)
+ file_name = re.sub("", "", file_name)
+ file_name = file_name.replace('', '').replace("", "")
+ file_name = file_name.replace('', '').replace("", "")
+ file_name = file_name.replace('', '').replace("", "")
+
+ file_name = re.sub(r'[^\x00-\x7F]+', '', file_name) # remove non-ascii
+
+ file_name = file_name.replace(' & ', ' and ').replace(' &', ' and').replace('& ', 'and ') \
+ .replace(" + ", " and ").replace(" +", " and").replace("+ ", "and ")
+
+ file_name = file_name.replace('\t', ' ').replace('\n', ' ').replace('\r', ' ')
+
+ file_name = file_name.replace('\"t"', ' ')
+
+ file_name = file_name.replace(" ♥ ","love").replace("♥ ","love ").replace(" ♥"," love") \
+ .replace("♥"," love ")
+
+ # remove bad chars
+ file_name = file_name.replace('\"', '').replace('?', '') \
+ .replace('<', '').replace('>', '').replace('/', '').replace('*', '') \
+ .replace('!', '').replace('#', '').replace('$', '').replace('%', '') \
+ .replace('^', '').replace('(', '').replace(')', '')
+
+ # replace with space
+ file_name = file_name.replace(':',' ').replace('|',' ').replace('@', '') \
+ .replace("/", " ").replace("\\'", "\'").replace("\\", " ").replace('\\', ' ') \
+ .replace('_', ' ').replace("=", " ")
+
+ # replace foreign chars
+ file_name = file_name.replace('é', 'e').replace('è', 'e').replace('ê', 'e') \
+ .replace('ë', 'e').replace('à', 'a').replace('â', 'a').replace('ä', 'a') \
+ .replace('ç', 'c').replace('ù', 'u').replace('û', 'u').replace('ü', 'u') \
+ .replace('ô', 'o').replace('ö', 'o').replace('ï', 'i').replace('î', 'i') \
+ .replace('í', 'i').replace('ì', 'i').replace('ñ', 'n').replace('ß', 'ss') \
+ .replace('á', 'a').replace('ã', 'a').replace('å', 'a').replace('æ', 'ae') \
+ .replace('œ', 'oe').replace('ø', 'o').replace('ð', 'd').replace('þ', 'th') \
+ .replace('ý', 'y').replace('ÿ', 'y').replace('ž', 'z').replace('ž', 'z') \
+ .replace('š', 's').replace('đ', 'd').replace('ď', 'd').replace('č', 'c') \
+ .replace('ć', 'c').replace('ř', 'r').replace('ŕ', 'r').replace('ľ', 'l') \
+ .replace('ĺ', 'l').replace('ť', 't').replace('ň', 'n').replace('ņ', 'n') \
+ .replace('ď', 'd').replace('Ď', 'D').replace('Ť', 'T').replace('Ň', 'N')
+
+ _MAX_LENGTH = 240
+ if (len(file_name) > _MAX_LENGTH):
+ file_name = file_name[:_MAX_LENGTH]
+
+ return file_name
+
+async def call_http(image_url: str, session: aiohttp.ClientSession):
+ #print(f"calling http and save to: {out_file_name}")
+ global downloaded_count
+ global http_timeout
+ global current_parquet_file_downloaded_count
+ try:
+ res = await session.request(method="GET", url=image_url, timeout=http_timeout)
+
+ if (res.status == 200):
+ return await res.content.read()
+ else:
+ print(f"{Fore.YELLOW}Failed to download image, HTTP response code: {res.status} for {Fore.LIGHTWHITE_EX}{image_url}{Style.RESET_ALL}")
+ downloaded_count -= 1
+ except Exception as e:
+ print(f"{Fore.YELLOW} *** Error downloading image: {Fore.LIGHTWHITE_EX}{image_url}{Fore.YELLOW}, ex: {str(e)}{Style.RESET_ALL}")
+ downloaded_count -= 1
+ pass
+ return None
+
+async def save_img(buffer: io.BytesIO, full_outpath: str):
+ try:
+ async with aiofiles.open(full_outpath, "wb") as f:
+ await f.write(buffer.getbuffer())
+ except Exception as e:
+ print(f"{Fore.RED} *** Unable to write to disk: {Fore.LIGHTWHITE_EX}{full_outpath}{Style.RESET_ALL}")
+ print(f"{Fore.RED} *** ex: {Fore.LIGHTWHITE_EX}{str(e)}{Style.RESET_ALL}")
+ pass
+
+def get_outpath_filename(data: any, full_outpath_noext: str, clean_text: str):
+ ext = "jpg"
+ full_outpath = None
+ buffer = None
+ try:
+ buffer = io.BytesIO(data)
+ image = Image.open(buffer)
+ ext = image.format.lower()
+
+ if (ext == "jpeg"):
+ ext = "jpg"
+
+ full_outpath = f"{full_outpath_noext}.{ext}"
+ except Exception as e:
+ print(f"{Fore.YELLOW} *** Possible corrupt image for text: {Fore.LIGHTWHITE_EX}{clean_text}{Style.RESET_ALL}")
+ print(f"{Fore.YELLOW} *** ex: {Fore.LIGHTWHITE_EX}{str(e)}{Style.RESET_ALL}")
+ pass
+ return full_outpath, buffer
+
+async def download_image(image_url: str, clean_text: str, full_outpath_noext: str, session: aiohttp.ClientSession):
+ http_content = await call_http(image_url=image_url, session=session)
+
+ buffer = None
+
+ if (http_content is not None):
+ full_outpath, buffer = get_outpath_filename(data=http_content, full_outpath_noext=full_outpath_noext, clean_text=clean_text)
+
+ if buffer is not None and full_outpath is not None:
+ global downloaded_count
+ downloaded_count += 1
+ await save_img(buffer, full_outpath)
+
+async def download_set_dict(opt, matches_dict: dict):
+ async with aiohttp.ClientSession() as session:
+ global downloaded_count
+ current_parquet_file_downloaded_count = 0
+ tasks = []
+ for row in matches_dict:
+ if downloaded_count < opt.limit:
+ current_parquet_file_downloaded_count += 1
+ pre_text=row["TEXT"]
+ image_url=row["URL"]
+
+ clean_text = cleanup_text(pre_text)
+
+ full_outpath_noext = os.path.join(opt.out_dir, clean_text)
+
+ if (opt.verbose):
+ print(f"{Fore.LIGHTGREEN_EX}***** Verbose log: ***** {Style.RESET_ALL}")
+ print(f"{Fore.LIGHTGREEN_EX} url: {image_url}{Style.RESET_ALL}")
+ print(f"{Fore.LIGHTGREEN_EX} text: {pre_text}{Style.RESET_ALL}")
+ print(f"{Fore.LIGHTGREEN_EX} captn: {clean_text}{Style.RESET_ALL}")
+
+ if any(glob.glob(full_outpath_noext + ".*")):
+ print(f"{Fore.YELLOW} already exists: {Fore.LIGHTWHITE_EX}{full_outpath_noext}{Fore.YELLOW}, skipping{Style.RESET_ALL}")
+ continue
+
+ if not opt.test:
+ tasks.append(
+ download_image(image_url=image_url, clean_text=clean_text, full_outpath_noext=full_outpath_noext, session=session)
+ )
+ else:
+ current_parquet_file_downloaded_count += 1
+ downloaded_count += 1
+ if len(tasks) > 63:
+ await asyncio.gather(*tasks)
+ tasks = []
+ else:
+ print(f"{Fore.YELLOW} Limit reached: {opt.limit}, exiting...{Style.RESET_ALL}")
+ break
+ if not opt.test and len(tasks) > 0:
+ await asyncio.gather(*tasks)
+ print(f"{Fore.LIGHTBLUE_EX} Downloaded chunk of {current_parquet_file_downloaded_count} images{Style.RESET_ALL}")
+
+def query_parquet(df: pd.DataFrame, opt):
+ # TODO: efficiency, expression tree?
+ matches = df
+
+ matches = matches[(matches.HEIGHT >= opt.min_hw) & (matches.WIDTH >= opt.min_hw)]
+
+ if 'punsafe' in matches.columns:
+ matches = matches[(matches.punsafe < unsafe_threshhold)]
+
+ if ('aesthetic' in matches):
+ matches = matches[(matches.aesthetic > aesthetic_threshhold)]
+
+ if opt.search_text:
+ for word in opt.search_text.split(","):
+ matches = matches[matches[opt.column].str.contains(word, case=False, na=False)]
+
+ matches = matches[~matches["URL"].str.contains("dreamstime.com", case=False)] # watermarks
+ matches = matches[~matches["URL"].str.contains("alamy.com", case=False)] # watermarks
+ matches = matches[~matches["URL"].str.contains("123rf.com", case=False)] # watermarks
+ matches = matches[~matches["URL"].str.contains("colourbox.com", case=False)] # watermarks
+ matches = matches[~matches["URL"].str.contains("envato.com", case=False)] # watermarks
+ matches = matches[~matches["URL"].str.contains("stockfresh.com", case=False)] # watermarks
+ matches = matches[~matches["URL"].str.contains("depositphotos.com", case=False)] # watermarks
+ matches = matches[~matches["URL"].str.contains("istockphoto.com", case=False)] # watermarks
+
+ return matches
+
+async def download_laion_matches(opt):
+ print(f"{Fore.LIGHTBLUE_EX} Searching for {opt.search_text} in column: {opt.column} in {opt.laion_dir}/*.parquet{Style.RESET_ALL}")
+
+ for idx, file in enumerate(glob.iglob(os.path.join(opt.laion_dir, "*.parquet"))):
+ if idx < opt.parquet_skip:
+ print(f"{Fore.YELLOW} Skipping file {idx+1}/{opt.parquet_skip}: {file}{Style.RESET_ALL}")
+ continue
+
+ global downloaded_count
+ if downloaded_count < opt.limit:
+ print(f"{Fore.CYAN} reading file: {file}{Style.RESET_ALL}")
+
+ df = pd.read_parquet(file, engine="auto")
+ matches = query_parquet(df, opt)
+ # print(f"{Fore.CYAN} matches in current parquet file:{ Style.RESET_ALL}")
+ # print(matches)
+
+ match_dict = matches.to_dict('records') # TODO: pandas problems later in script... needs revisiting
+
+ await download_set_dict(opt, match_dict)
+ else:
+ print(f"{Fore.YELLOW}limit reached before reading next parquet file. idx: {idx}, filename: {file}{Style.RESET_ALL}")
+ break
+
+def isWindows():
+ return sys.platform.startswith('win')
+
+def ensure_path_exists(path: str):
+ if not os.path.exists(path):
+ print(f"{Fore.LIGHTBLUE_EX}creating path: {path}{Style.RESET_ALL}")
+ os.makedirs(path)
+
+if __name__ == '__main__':
+ print(f"{Fore.CYAN}Launching...{Style.RESET_ALL}")
+ inVenv = in_virtualenv()
+ print(f"is running in venv: {inVenv}")
+ #assert inVenv, "Error loading venv. Please run 'source everydream-venv/bin/activate', or in windows 'everydream-venv/bin/activate.bat'"
+
+ parser = get_parser()
+ opt = parser.parse_args()
+
+ print(f"Test only mode: {opt.test}")
+
+ if(opt.search_text is None and opt.force is False):
+ print(f"{Fore.YELLOW}** No search terms provided, exiting...")
+ print(f"** Use --force to bypass safety to dump entire DB{Style.RESET_ALL}")
+ sys.exit(2)
+
+ ensure_path_exists(opt.out_dir)
+
+ if (opt.laion_dir[-1] != "/" and opt.laion_dir[-1] != "\\"):
+ opt.laion_dir += "/"
+
+ if (isWindows()):
+ print(f"{Fore.CYAN}Windows detected, using asyncio.WindowsSelectorEventLoopPolicy{Style.RESET_ALL}")
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+ else:
+ print(f"{Fore.CYAN}Unix detected, using default asyncio event loop policy{Style.RESET_ALL}")
+
+ import time
+ s = time.perf_counter()
+
+ result = asyncio.run(download_laion_matches(opt))
+
+ elapsed = time.perf_counter() - s
+ print(f"{Fore.CYAN} **** Job Complete ****")
+ print(f" search_text: \"{opt.search_text}\", force: {opt.force}")
+ print(f" Total downloaded {downloaded_count} images")
+ print(f"{__file__} executed in {elapsed:0.2f} seconds.{Style.RESET_ALL}")
\ No newline at end of file
diff --git a/data/EveryDream/scripts/extract_video_frames.py b/data/EveryDream/scripts/extract_video_frames.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6e794d131c2f783ce3bcbe369ed03e84991a93d
--- /dev/null
+++ b/data/EveryDream/scripts/extract_video_frames.py
@@ -0,0 +1,90 @@
+import argparse
+from pathlib import Path
+import cv2
+
+def get_parser(**parser_kwargs):
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ parser.add_argument(
+ "--vid_dir",
+ required=True,
+ type=str,
+ nargs="?",
+ const=True,
+ help="directory with videos to extract frames",
+ )
+ parser.add_argument(
+ "--out_dir",
+ type=str,
+ nargs="?",
+ const=True,
+ help="directory to put extracted images",
+ )
+ parser.add_argument(
+ "--format",
+ type=str,
+ nargs="?",
+ const=True,
+ default="png",
+ choices=["png", "jpg"],
+ help="image file format of the extracted frames",
+ )
+ parser.add_argument(
+ "--interval",
+ type=int,
+ nargs="?",
+ const=True,
+ default=10,
+ help="number of seconds between frame captures",
+ )
+ return parser
+
+def get_videos(input_dir):
+ for f in input_dir.iterdir():
+ file_path = Path(f)
+ if file_path.suffix.lower() in [".mp4", ".avi", ".mov", ".mpeg", ".mpg", ".mkv"]:
+ yield file_path
+
+def capture_frames(input_dir, output_dir):
+ print (f'Capturing video frames in {opt.interval} second intervals.\n')
+
+ for video_path in get_videos(input_dir):
+ print(f'Extracting {video_path}')
+ cap = cv2.VideoCapture(str(video_path))
+ if not cap.isOpened():
+ print(f'Could not open video: {video_path}')
+ continue
+
+ output = output_dir / video_path.stem
+ output.mkdir(exist_ok=True, parents=True)
+
+ current_frame = 0
+ count = 0
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ count_str = str(count).zfill(4)
+ cv2.imwrite(str(output / f'frame_{count_str}.{opt.format}'), frame)
+ current_frame += fps * opt.interval
+ cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
+ count += 1
+ else:
+ cap.release()
+ break
+
+ print(f'\nFinished extracting frames to {output_dir}\n')
+
+if __name__ == "__main__":
+ parser = get_parser()
+ opt = parser.parse_args()
+
+ if (not Path(opt.vid_dir).exists()):
+ print("Video directory does not exist.")
+ exit(1)
+
+ if (opt.out_dir is None):
+ output = Path(opt.vid_dir) / "output"
+ print(f"No output directory specified, using default: {output}")
+ else:
+ output = Path(opt.out_dir)
+ capture_frames(Path(opt.vid_dir), output)
diff --git a/data/EveryDream/scripts/filename_replace.py b/data/EveryDream/scripts/filename_replace.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab0f32b0ef1b6f31f685fd51174641be3313e60
--- /dev/null
+++ b/data/EveryDream/scripts/filename_replace.py
@@ -0,0 +1,104 @@
+import sys
+import os
+import glob
+import argparse
+
+def get_parser(**parser_kwargs):
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ parser.add_argument(
+ "--img_dir",
+ type=str,
+ nargs="?",
+ const=True,
+ default="input",
+ help="directory with images to be renamed",
+ ),
+ parser.add_argument(
+ "--find",
+ type=str,
+ nargs="?",
+ const=True,
+ default=None,
+ help="what strings to replace, in csv format, default: 'a man,a woman,a person'",
+ )
+ parser.add_argument(
+ "--replace",
+ type=str,
+ nargs="?",
+ required=False,
+ const=True,
+ default=None,
+ help="string to replace with, ex. 'john doe'",
+ )
+ parser.add_argument(
+ "--append_only",
+ type=str,
+ nargs="?",
+ required=False,
+ const=True,
+ default=None,
+ help="skips pronoun replace, adds a string at the end of the filename, use for 'by artist name' or 'in the style of somestyle'",
+ )
+
+ return parser
+
+def isWindows():
+ return sys.platform.startswith('win')
+
+def get_replace_list(opt):
+ if opt.find is None:
+ return ("a man", "a woman", "a person",
+ "a girl", "a boy",
+ "a young woman", "a young man",
+ "a beautiful woman", "a handsome man",
+ "a beautiful young woman", "a handsome young man",
+ )
+ else:
+ return opt.find.split(",")
+
+def get_outfile_name(infile, append):
+ new_filename = f"{os.path.splitext(infile)[0]} {append}{os.path.splitext(infile)[1]}"
+ return new_filename
+
+def rename_files(opt):
+ find_list = get_replace_list(opt)
+
+ for idx, file in enumerate(glob.iglob(f"{opt.img_dir}/*")):
+ print(file)
+
+ if os.path.splitext(file)[1] in (".jpg", ".png", ".jpeg", ".gif", ".bmp", ".webp"):
+ new_filename = file
+ if opt.append_only is not None:
+ new_filename = get_outfile_name(file, opt.append_only)
+ else:
+ for s in find_list:
+ if s in file:
+ new_filename = new_filename.replace(s, opt.replace)
+ try:
+ print(f"Renaming {file} to {new_filename}")
+ if new_filename != file and os.path.exists(new_filename):
+ new_filename = get_outfile_name(file, f"_{idx}")
+ print(f"filename already exists, appended '_{idx}' to {new_filename}")
+
+ try:
+ os.rename(file, new_filename)
+ except Exception as e:
+ print(f"Error renaming file: {file}, skipping, error: {e}")
+ except Exception as e:
+ print(f"error processing file: {file}")
+ print(f"{e}")
+ raise e
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ opt = parser.parse_args()
+
+ import time
+
+ s = time.perf_counter()
+
+ rename_files(opt)
+
+ elapsed = time.perf_counter() - s
+ print(f"{__file__} executed in {elapsed:0.2f} seconds.")
\ No newline at end of file
diff --git a/data/EveryDream/scripts/image_caption_gui.py b/data/EveryDream/scripts/image_caption_gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8e04ca70fbdc10ff73dcc7f065ffe43bd8befb8
--- /dev/null
+++ b/data/EveryDream/scripts/image_caption_gui.py
@@ -0,0 +1,161 @@
+# Python GUI tool to manually caption images for machine learning.
+# A sidecar file is created for each image with the same name and a .txt extension.
+#
+# [control/command + o] to open a folder of images.
+# [page down] and [page up] to go to next and previous images. Hold shift to skip 10 images.
+# [shift + home] and [shift + end] to go to first and last images.
+# [shift + delete] to move the current image into a '_deleted' folder.
+# [escape] to exit the app.
+
+import sys
+import tkinter as tk
+from tkinter import filedialog
+from PIL import Image, ImageTk
+from pathlib import Path
+
+IMG_EXT = ["jpg", "jpeg", "png", "webp"]
+
+class CaptionedImage():
+ def __init__(self, image_path):
+ self.base_path = image_path.parent
+ self.path = image_path
+
+ def caption_path(self):
+ return self.base_path / (self.path.stem + '.txt')
+
+ def read_caption(self):
+ caption_path = self.caption_path()
+ if caption_path.exists():
+ with open(caption_path, 'r', encoding='utf-8', newline='') as f:
+ return f.read()
+ return ''
+
+ def write_caption(self, caption):
+ caption_path = self.caption_path()
+ with open(str(caption_path), 'w', encoding='utf-8', newline='') as f:
+ f.write(caption)
+
+ # sort
+ def __lt__(self, other):
+ return str(self.path).lower() < str(other.path).lower()
+
+class ImageView(tk.Frame):
+
+ def __init__(self, root):
+ tk.Frame.__init__(self, root)
+
+ self.root = root
+ self.base_path = None
+ self.images = []
+ self.index = 0
+
+ # image
+ self.image_frame = tk.Frame(self)
+ self.image_label = tk.Label(self.image_frame)
+ self.image_label.place(relx=0.5, rely=0.5, anchor=tk.CENTER)
+ self.image_frame.pack(expand=True, fill=tk.BOTH, side=tk.LEFT)
+
+ # caption field
+ self.caption_frame = tk.Frame(self)
+ self.caption_field = tk.Text(self.caption_frame, wrap="word", width=50)
+ self.caption_field.pack(expand=True, fill=tk.BOTH)
+ self.caption_frame.pack(fill=tk.Y, side=tk.RIGHT)
+
+ def open_folder(self):
+ dir = filedialog.askdirectory()
+ if not dir:
+ return
+ self.base_path = Path(dir)
+ if self.base_path is None:
+ return
+ self.images.clear()
+ for ext in IMG_EXT:
+ for file in self.base_path.glob(f'*.{ext}'):
+ self.images.append(CaptionedImage(file))
+ self.images.sort()
+ self.update_ui()
+
+ def store_caption(self):
+ txt = self.caption_field.get(1.0, tk.END)
+ txt = txt.replace('\r', '').replace('\n', '').strip()
+ self.images[self.index].write_caption(txt)
+
+ def set_index(self, index):
+ self.index = index % len(self.images)
+
+ def go_to_image(self, index):
+ if len(self.images) == 0:
+ return
+ self.store_caption()
+ self.set_index(index)
+ self.update_ui()
+
+ def next_image(self):
+ self.go_to_image(self.index + 1)
+
+ def prev_image(self):
+ self.go_to_image(self.index - 1)
+
+ # move current image to a "_deleted" folder
+ def delete_image(self):
+ if len(self.images) == 0:
+ return
+ img = self.images[self.index]
+
+ trash_path = self.base_path / '_deleted'
+ if not trash_path.exists():
+ trash_path.mkdir()
+ img.path.rename(trash_path / img.path.name)
+ caption_path = img.caption_path()
+ if caption_path.exists():
+ caption_path.rename(trash_path / caption_path.name)
+ del self.images[self.index]
+ self.set_index(self.index)
+ self.update_ui()
+
+ def update_ui(self):
+ if (len(self.images)) == 0:
+ self.root.title('Image Captions')
+ self.caption_field.delete(1.0, tk.END)
+ self.image_label.configure(image='')
+ return
+ img = self.images[self.index]
+ # filename
+ title = self.images[self.index].path.name if len(self.images) > 0 else ''
+ self.root.title(title + f' ({self.index+1}/{len(self.images)})')
+ # caption
+ self.caption_field.delete(1.0, tk.END)
+ self.caption_field.insert(tk.END, img.read_caption())
+ # image
+ img = Image.open(self.images[self.index].path)
+
+ # scale the image to fit inside the frame
+ w = self.image_frame.winfo_width()
+ h = self.image_frame.winfo_height()
+ if img.width > w or img.height > h:
+ img.thumbnail((w, h))
+ photoImage = ImageTk.PhotoImage(img)
+ self.image_label.configure(image=photoImage)
+ self.image_label.image = photoImage
+
+if __name__=='__main__':
+ root = tk.Tk()
+ root.geometry('1200x800')
+ root.title('Image Captions')
+
+ if sys.platform == 'darwin':
+ root.bind('<Command-o>', lambda e: view.open_folder())
+ else:
+ root.bind('<Control-o>', lambda e: view.open_folder())
+ root.bind('<Escape>', lambda e: root.destroy())
+ root.bind('<Prior>', lambda e: view.prev_image())
+ root.bind('<Next>', lambda e: view.next_image())
+ root.bind('<Shift-Prior>', lambda e: view.go_to_image(view.index - 10))
+ root.bind('<Shift-Next>', lambda e: view.go_to_image(view.index + 10))
+ root.bind('<Shift-Home>', lambda e: view.go_to_image(0))
+ root.bind('<Shift-End>', lambda e: view.go_to_image(len(view.images) - 1))
+ root.bind('<Shift-Delete>', lambda e: view.delete_image())
+
+ view = ImageView(root)
+ view.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+ root.mainloop()
diff --git a/handler.py b/handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aef9a9f5e60f23d988f35b204d513f5477c5e93
--- /dev/null
+++ b/handler.py
@@ -0,0 +1,91 @@
+# this is the huggingface handler file
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
+from omegaconf import OmegaConf
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid
+from animatediff.utils.util import load_weights
+from diffusers.utils.import_utils import is_xformers_available
+from typing import Any
+import torch
+from einops import rearrange
+import torchvision
+
+import numpy as np
+
+class EndpointHandler():
+ def __init__(self, model_path: str = "models/StableDiffusion/", inference_config_path: str = "configs/inference/inference-v3.yaml", motion_module: str = "models/Motion_Module/mm_sd_v15.ckpt"):
+
+ inference_config = OmegaConf.load(inference_config_path)
+ ### >>> create validation pipeline >>> ###
+ tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
+ unet = UNet3DConditionModel.from_pretrained_2d(model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+
+ if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
+ else: raise RuntimeError("xformers is required but is not available")
+
+ self.pipeline = AnimationPipeline(
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
+ scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
+ ).to("cuda")
+
+ self.pipeline = load_weights(
+ self.pipeline,
+ # motion module
+ motion_module_path = motion_module,
+ motion_module_lora_configs = [],
+ # image layers
+ dreambooth_model_path = "",
+ lora_model_path = "",
+ lora_alpha = 0.8,
+ ).to("cuda")
+
+ def initialize(self, context):
+ """
+ Initialize model. This will be called during model loading time
+ """
+
+
+
+ def preprocess(self, data):
+ """
+ preprocess will be called once for each request.
+ """
+
+ def __call__(self, prompt, negative_prompt, steps, guidance_scale):
+ """
+ __call__ method will be called once per request. This can be used to
+ run inference.
+ """
+ vids = self.pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=steps,
+ guidance_scale=guidance_scale,
+ width= 256,
+ height= 256,
+ video_length= 5,
+ ).videos
+
+ videos = rearrange(vids, "b c t h w -> t b c h w")
+ n_rows=6
+ fps=1
+ loop = True
+ rescale=False
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(x)
+
+ # imageio.mimsave(path, outputs, fps=fps)
+
+ # return the frames as a list of uint8 arrays of shape (H, W, C); GIF encoding is not done here
+ return outputs
\ No newline at end of file
diff --git a/models/Motion_Module/2fbkcvmxtmp b/models/Motion_Module/2fbkcvmxtmp
new file mode 100644
index 0000000000000000000000000000000000000000..6b4867d1500cd54e27826de22d5c59d3c526d890
--- /dev/null
+++ b/models/Motion_Module/2fbkcvmxtmp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa7fd8a200a89031edd84487e2a757c5315460eca528fa70d4b3885c399bffd5
+size 1672133921
diff --git a/models/Motion_Module/gqdawx6utmp b/models/Motion_Module/gqdawx6utmp
new file mode 100644
index 0000000000000000000000000000000000000000..cb080b70a8a3058731e1eef275ff7ed7f72cfeae
--- /dev/null
+++ b/models/Motion_Module/gqdawx6utmp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf16ea656cb16124990c8e2c70a29c793f9841f3a2223073fac8bd89ebd9b69a
+size 1672133921
diff --git a/models/Motion_Module/impovhmrtmp b/models/Motion_Module/impovhmrtmp
new file mode 100644
index 0000000000000000000000000000000000000000..cb080b70a8a3058731e1eef275ff7ed7f72cfeae
--- /dev/null
+++ b/models/Motion_Module/impovhmrtmp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf16ea656cb16124990c8e2c70a29c793f9841f3a2223073fac8bd89ebd9b69a
+size 1672133921
diff --git a/models/Motion_Module/oz5u8b9jtmp b/models/Motion_Module/oz5u8b9jtmp
new file mode 100644
index 0000000000000000000000000000000000000000..6b4867d1500cd54e27826de22d5c59d3c526d890
--- /dev/null
+++ b/models/Motion_Module/oz5u8b9jtmp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa7fd8a200a89031edd84487e2a757c5315460eca528fa70d4b3885c399bffd5
+size 1672133921
diff --git a/models/StableDiffusion/.gitattributes b/models/StableDiffusion/.gitattributes
index 8a0edbd1e5709b501f85f4174545a01dc9497131..55d2855c5be698e0572b9f42af95f06bfd5fb002 100644
--- a/models/StableDiffusion/.gitattributes
+++ b/models/StableDiffusion/.gitattributes
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text
+v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text
diff --git a/models/StableDiffusion/README.md b/models/StableDiffusion/README.md
index dc018a6b8521cde25c0d5ddd3c627d67e6b67f7a..103bad8a83037abcdb9d24f2b21eba8792b33a5d 100644
--- a/models/StableDiffusion/README.md
+++ b/models/StableDiffusion/README.md
@@ -4,18 +4,17 @@ tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
-inference: false
+inference: true
extra_gated_prompt: |-
This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.
+ The CreativeML OpenRAIL License specifies:
+
1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content
2. CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license
3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully)
- Please read the full license here: https://huggingface.co/spaces/CompVis/stable-diffusion-license
-
- By clicking on "Access repository" below, you accept that your *contact information* (email address and username) can be shared with the model authors as well.
-
-extra_gated_fields:
- I have read the License and agree with its terms: checkbox
+ Please read the full license carefully here: https://huggingface.co/spaces/CompVis/stable-diffusion-license
+
+extra_gated_heading: Please read the LICENSE to access this model
---
# Stable Diffusion v1-5 Model Card
@@ -31,10 +30,11 @@ You can use this both with the [🧨Diffusers library](https://github.com/huggin
### Diffusers
```py
from diffusers import StableDiffusionPipeline
+import torch
model_id = "runwayml/stable-diffusion-v1-5"
-pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16")
-pipe = pipe.to(device)
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
@@ -45,7 +45,10 @@ For more detailed instructions, use-cases and examples in JAX follow the instruc
### Original GitHub Repository
-1. Download the weights [v1-5-pruned-emaonly.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt)
+1. Download the weights
+ - [v1-5-pruned-emaonly.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt) - 4.27GB, EMA-only weights. Uses less VRAM; suitable for inference.
+ - [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt) - 7.7GB, EMA and non-EMA weights. Uses more VRAM; suitable for fine-tuning.
+
2. Follow instructions [here](https://github.com/runwayml/stable-diffusion).
## Model Details
@@ -144,22 +147,23 @@ The model developers used the following dataset for training the model:
- LAION-2B (en) and subsets thereof (see next section)
**Training Procedure**
-Stable Diffusion v1-4 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
+Stable Diffusion v1-5 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
- Text prompts are encoded through a ViT-L/14 text-encoder.
- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
-We currently provide four checkpoints, which were trained as follows.
+Currently six Stable Diffusion checkpoints are provided, which were trained as follows.
- [`stable-diffusion-v1-1`](https://huggingface.co/CompVis/stable-diffusion-v1-1): 237,000 steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
194,000 steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
- [`stable-diffusion-v1-2`](https://huggingface.co/CompVis/stable-diffusion-v1-2): Resumed from `stable-diffusion-v1-1`.
515,000 steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
-- [`stable-diffusion-v1-3`](https://huggingface.co/CompVis/stable-diffusion-v1-3): Resumed from `stable-diffusion-v1-2`. 195,000 steps at resolution `512x512` on "laion-improved-aesthetics" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
-- [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) Resumed from `stable-diffusion-v1-2`.225,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
-- [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) Resumed from `stable-diffusion-v1-2` 595,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
+- [`stable-diffusion-v1-3`](https://huggingface.co/CompVis/stable-diffusion-v1-3): Resumed from `stable-diffusion-v1-2` - 195,000 steps at resolution `512x512` on "laion-improved-aesthetics" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
+- [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4): Resumed from `stable-diffusion-v1-2` - 225,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
+- [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5): Resumed from `stable-diffusion-v1-2` - 595,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
+- [`stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting): Resumed from `stable-diffusion-v1-5` - then 440,000 steps of inpainting training at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning. For inpainting, the UNet has 5 additional input channels (4 for the encoded masked image and 1 for the mask itself) whose weights were zero-initialized after restoring the non-inpainting checkpoint. During training, synthetic masks are generated, and in 25 % of cases everything is masked.
- **Hardware:** 32 x 8 x A100 GPUs
- **Optimizer:** AdamW
diff --git a/models/StableDiffusion/model_index.json b/models/StableDiffusion/model_index.json
index 6866dceb3a870b077eb970ecf702ce4e1a83b934..daf7e2e2dfc64fb437a2b44525667111b00cb9fc 100644
--- a/models/StableDiffusion/model_index.json
+++ b/models/StableDiffusion/model_index.json
@@ -3,7 +3,7 @@
"_diffusers_version": "0.6.0",
"feature_extractor": [
"transformers",
- "CLIPFeatureExtractor"
+ "CLIPImageProcessor"
],
"safety_checker": [
"stable_diffusion",
diff --git a/models/StableDiffusion/safety_checker/config.json b/models/StableDiffusion/safety_checker/config.json
index a087b121ef108ea987e84c9c81f5759a81666f29..5dbd88952e7e521aa665e5052e6db7def3641d03 100644
--- a/models/StableDiffusion/safety_checker/config.json
+++ b/models/StableDiffusion/safety_checker/config.json
@@ -1,6 +1,6 @@
{
- "_commit_hash": null,
- "_name_or_path": "/home/patrick/stable-diffusion-v1-5/safety_checker",
+ "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b",
+ "_name_or_path": "CompVis/stable-diffusion-safety-checker",
"architectures": [
"StableDiffusionSafetyChecker"
],
@@ -88,7 +88,7 @@
"num_attention_heads": 12,
"num_hidden_layers": 12
},
- "torch_dtype": "float16",
+ "torch_dtype": "float32",
"transformers_version": null,
"vision_config": {
"_name_or_path": "",
diff --git a/models/StableDiffusion/scheduler/scheduler_config.json b/models/StableDiffusion/scheduler/scheduler_config.json
index 0b38643e5b34e2a9936c8a3d423df72add6edb35..82d05b0e688d7ea94675678646c427907419346e 100644
--- a/models/StableDiffusion/scheduler/scheduler_config.json
+++ b/models/StableDiffusion/scheduler/scheduler_config.json
@@ -10,4 +10,4 @@
"steps_offset": 1,
"trained_betas": null,
"clip_sample": false
-}
\ No newline at end of file
+}
diff --git a/models/StableDiffusion/text_encoder/config.json b/models/StableDiffusion/text_encoder/config.json
index 9bdb32947168c46b843b3282fe66d46f206e074c..4d3e873ab5086ad989f407abd50fdce66db8d657 100644
--- a/models/StableDiffusion/text_encoder/config.json
+++ b/models/StableDiffusion/text_encoder/config.json
@@ -1,5 +1,5 @@
{
- "_name_or_path": "/home/patrick/stable-diffusion-v1-5/text_encoder",
+ "_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
@@ -19,7 +19,7 @@
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
- "torch_dtype": "float16",
+ "torch_dtype": "float32",
"transformers_version": "4.22.0.dev0",
"vocab_size": 49408
}
diff --git a/models/StableDiffusion/tokenizer/tokenizer_config.json b/models/StableDiffusion/tokenizer/tokenizer_config.json
index da0513344464912b8a7b4b78dbf8d24ac533ddd5..5ba7bf706515bc60487ad0e1816b4929b82542d6 100644
--- a/models/StableDiffusion/tokenizer/tokenizer_config.json
+++ b/models/StableDiffusion/tokenizer/tokenizer_config.json
@@ -19,7 +19,7 @@
},
"errors": "replace",
"model_max_length": 77,
- "name_or_path": "/home/patrick/stable-diffusion-v1-5/tokenizer",
+ "name_or_path": "openai/clip-vit-large-patch14",
"pad_token": "<|endoftext|>",
"special_tokens_map_file": "./special_tokens_map.json",
"tokenizer_class": "CLIPTokenizer",
diff --git a/models/StableDiffusion/unet/config.json b/models/StableDiffusion/unet/config.json
index 6d63242165378f518e00d09c66bd6b30142bbae4..1a02ee8abc93e840ffbcb2d68b66ccbcb74b3ab3 100644
--- a/models/StableDiffusion/unet/config.json
+++ b/models/StableDiffusion/unet/config.json
@@ -1,7 +1,6 @@
{
"_class_name": "UNet2DConditionModel",
"_diffusers_version": "0.6.0",
- "_name_or_path": "/home/patrick/stable-diffusion-v1-5/unet",
"act_fn": "silu",
"attention_head_dim": 8,
"block_out_channels": [
diff --git a/models/StableDiffusion/v1-inference.yaml b/models/StableDiffusion/v1-inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d4effe569e897369918625f9d8be5603a0e6a0d6
--- /dev/null
+++ b/models/StableDiffusion/v1-inference.yaml
@@ -0,0 +1,70 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
diff --git a/models/StableDiffusion/vae/config.json b/models/StableDiffusion/vae/config.json
index db5b7b2f538291f9c733005ef019db827c1d7a94..55d78924fee13e4220f24320127c5f16284e13b9 100644
--- a/models/StableDiffusion/vae/config.json
+++ b/models/StableDiffusion/vae/config.json
@@ -1,7 +1,6 @@
{
"_class_name": "AutoencoderKL",
"_diffusers_version": "0.6.0",
- "_name_or_path": "/home/patrick/stable-diffusion-v1-5/vae",
"act_fn": "silu",
"block_out_channels": [
128,
@@ -20,7 +19,7 @@
"layers_per_block": 2,
"norm_num_groups": 32,
"out_channels": 3,
- "sample_size": 256,
+ "sample_size": 512,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
diff --git a/pipeline.py b/pipeline.py
index cf815b1327cfb67f49bb2c8f3fb150332892ce1c..2dce4e4e6e700a15ade0ecf51dee8e7789309d47 100644
--- a/pipeline.py
+++ b/pipeline.py
@@ -27,17 +27,13 @@ from diffusers.utils import deprecate, logging, BaseOutput
from einops import rearrange
-from ..models.unet import UNet3DConditionModel
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-@dataclass
-class AnimationPipelineOutput(BaseOutput):
- videos: Union[torch.Tensor, np.ndarray]
-
-
class AnimationPipeline(DiffusionPipeline):
_optional_components = []