arnavkartikeya committed on
Commit d3281f2
1 Parent(s): 20e0290

Upload 21 files

Files changed (21)
  1. CODEOWNERS +2 -0
  2. CODE_OF_CONDUCT.md +105 -0
  3. LICENSE.txt +12 -0
  4. README.md +116 -12
  5. SECURITY.md +7 -0
  6. app.py +143 -0
  7. cog.yaml +17 -0
  8. data.txt +300 -0
  9. demo.ipynb +0 -0
  10. eval_nocaps.py +118 -0
  11. eval_retrieval_video.py +250 -0
  12. imagecaptioning.py +93 -0
  13. predict.py +98 -0
  14. pretrain.py +173 -0
  15. requirements.txt +9 -0
  16. sample.png +0 -0
  17. train_caption.py +206 -0
  18. train_nlvr.py +213 -0
  19. train_retrieval.py +345 -0
  20. train_vqa.py +202 -0
  21. utils.py +278 -0
CODEOWNERS ADDED
@@ -0,0 +1,2 @@
+ # Comment line immediately above ownership line is reserved for related GUS information. Please be careful while editing.
+ #ECCN:Open Source
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,105 @@
+ # Salesforce Open Source Community Code of Conduct
+
+ ## About the Code of Conduct
+
+ Equality is a core value at Salesforce. We believe a diverse and inclusive
+ community fosters innovation and creativity, and are committed to building a
+ culture where everyone feels included.
+
+ Salesforce open-source projects are committed to providing a friendly, safe, and
+ welcoming environment for all, regardless of gender identity and expression,
+ sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
+ race, age, religion, level of experience, education, socioeconomic status, or
+ other similar personal characteristics.
+
+ The goal of this code of conduct is to specify a baseline standard of behavior so
+ that people with different social values and communication styles can work
+ together effectively, productively, and respectfully in our open source community.
+ It also establishes a mechanism for reporting issues and resolving conflicts.
+
+ All questions and reports of abusive, harassing, or otherwise unacceptable behavior
+ in a Salesforce open-source project may be reported by contacting the Salesforce
+ Open Source Conduct Committee at ossconduct@salesforce.com.
+
+ ## Our Pledge
+
+ In the interest of fostering an open and welcoming environment, we as
+ contributors and maintainers pledge to making participation in our project and
+ our community a harassment-free experience for everyone, regardless of gender
+ identity and expression, sexual orientation, disability, physical appearance,
+ body size, ethnicity, nationality, race, age, religion, level of experience, education,
+ socioeconomic status, or other similar personal characteristics.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to creating a positive environment
+ include:
+
+ * Using welcoming and inclusive language
+ * Being respectful of differing viewpoints and experiences
+ * Gracefully accepting constructive criticism
+ * Focusing on what is best for the community
+ * Showing empathy toward other community members
+
+ Examples of unacceptable behavior by participants include:
+
+ * The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+ * Personal attacks, insulting/derogatory comments, or trolling
+ * Public or private harassment
+ * Publishing, or threatening to publish, others' private information—such as
+ a physical or electronic address—without explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+ * Advocating for or encouraging any of the above behaviors
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying the standards of acceptable
+ behavior and are expected to take appropriate and fair corrective action in
+ response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or
+ reject comments, commits, code, wiki edits, issues, and other contributions
+ that are not aligned with this Code of Conduct, or to ban temporarily or
+ permanently any contributor for other behaviors that they deem inappropriate,
+ threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies both within project spaces and in public spaces
+ when an individual is representing the project or its community. Examples of
+ representing a project or community include using an official project email
+ address, posting via an official social media account, or acting as an appointed
+ representative at an online or offline event. Representation of a project may be
+ further defined and clarified by project maintainers.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported by contacting the Salesforce Open Source Conduct Committee
+ at ossconduct@salesforce.com. All complaints will be reviewed and investigated
+ and will result in a response that is deemed necessary and appropriate to the
+ circumstances. The committee is obligated to maintain confidentiality with
+ regard to the reporter of an incident. Further details of specific enforcement
+ policies may be posted separately.
+
+ Project maintainers who do not follow or enforce the Code of Conduct in good
+ faith may face temporary or permanent repercussions as determined by other
+ members of the project's leadership and the Salesforce Open Source Conduct
+ Committee.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
+ version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
+ It includes adaptations and additions from [Go Community Code of Conduct][golang-coc],
+ [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
+
+ This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
+
+ [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
+ [golang-coc]: https://golang.org/conduct
+ [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
+ [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
+ [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
LICENSE.txt ADDED
@@ -0,0 +1,12 @@
+ Copyright (c) 2022, Salesforce.com, Inc.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,12 +1,116 @@
- ---
- title: SCRIPTure Demo
- emoji: 🚀
- colorFrom: blue
- colorTo: gray
- sdk: gradio
- sdk_version: 3.5
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
+
+ ## Announcement: BLIP is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications!
+
+ <img src="BLIP.gif" width="700">
+
+ This is the PyTorch code of the <a href="https://arxiv.org/abs/2201.12086">BLIP paper</a> [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
+ To install the dependencies, run <pre>pip install -r requirements.txt</pre>
+
+ Catalog:
+ - [x] Inference demo
+ - [x] Pre-trained and finetuned checkpoints
+ - [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
+ - [x] Pre-training code
+ - [x] Zero-shot video-text retrieval
+ - [x] Download of bootstrapped pre-training datasets
+
+
+ ### Inference demo:
+ Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
+ The demo includes code for:
+ 1. Image captioning
+ 2. Open-ended visual question answering
+ 3. Multimodal / unimodal feature extraction
+ 4. Image-text matching
+
+ Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
+
+ A Replicate web demo and Docker image are also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)
+
+ ### Pre-trained checkpoints:
+ Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
+ --- | :---: | :---: | :---:
+ 14M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth">Download</a>| - | -
+ 129M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth">Download</a>
+
+ ### Finetuned checkpoints:
+ Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
+ --- | :---: | :---: | :---:
+ Image-Text Retrieval (COCO) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth">Download</a>
+ Image-Text Retrieval (Flickr30k) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth">Download</a>
+ Image Captioning (COCO) | - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth">Download</a> |
+ VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth">Download</a> | -
+ NLVR2 | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth">Download</a>| - | -
+
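+ A checkpoint can also be fetched programmatically; a minimal sketch (the URL is taken from the table above, the local filename is arbitrary):
+ <pre>import requests
+
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'
+ # stream the checkpoint to disk in 1 MiB chunks
+ with requests.get(url, stream=True) as r:
+     r.raise_for_status()
+     with open('model_base.pth', 'wb') as f:
+         for chunk in r.iter_content(chunk_size=1 << 20):
+             f.write(chunk)</pre>
+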
+
+ ### Image-Text Retrieval:
+ 1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
+ 2. To evaluate the finetuned BLIP model on COCO, run:
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
+ --config ./configs/retrieval_coco.yaml \
+ --output_dir output/retrieval_coco \
+ --evaluate</pre>
+ 3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
+ --config ./configs/retrieval_coco.yaml \
+ --output_dir output/retrieval_coco </pre>
+
+ ### Image-Text Captioning:
+ 1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
+ 2. To evaluate the finetuned BLIP model on COCO, run:
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate</pre>
+ 3. To evaluate the finetuned BLIP model on NoCaps, generate results with the following (evaluation needs to be performed on the official server):
+ <pre>python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py </pre>
+ 4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py </pre>
+
+ ### VQA:
+ 1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
+ 2. To evaluate the finetuned BLIP model, generate results with the following (evaluation needs to be performed on the official server):
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate</pre>
+ 3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
+ <pre>python -m torch.distributed.run --nproc_per_node=16 train_vqa.py </pre>
+
+ ### NLVR2:
+ 1. Download NLVR2 dataset from the original websites, and set 'image_root' in configs/nlvr.yaml.
+ 2. To evaluate the finetuned BLIP model, run:
+ <pre>python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate</pre>
+ 3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
+ <pre>python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py </pre>
+
+ ### Finetune with ViT-L:
+ In order to finetune a model with ViT-L, simply change the config file to set 'vit' as large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). <a href="https://github.com/facebookresearch/fairscale">Gradient checkpointing</a> can also be activated in the config file to reduce GPU memory usage.
+
+ ### Pre-train:
+ 1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image} (see the sketch below).
+ 2. In configs/pretrain.yaml, set 'train_file' as the paths of the json files.
+ 3. Pre-train the model using 8 A100 GPUs:
+ <pre>python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/pretrain.yaml --output_dir output/Pretrain </pre>
+
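+ A minimal sketch of step 1 (the image paths and captions here are placeholders, not part of this repo):
+ <pre>import json
+
+ # each item pairs an image path with its caption, in the format described above
+ annotations = [
+     {'image': '/path/to/images/0001.jpg', 'caption': 'a dog running on the beach'},
+     {'image': '/path/to/images/0002.jpg', 'caption': 'two cats sleeping on a sofa'},
+ ]
+
+ with open('pretrain_ann.json', 'w') as f:
+     json.dump(annotations, f)</pre>
+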
+ ### Zero-shot video-text retrieval:
+ 1. Download the MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
+ 2. Install [decord](https://github.com/dmlc/decord) with <pre>pip install decord</pre>
+ 3. To perform zero-shot evaluation, run:
+ <pre>python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py</pre>
+
+ ### Pre-training datasets download:
+ We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image} (see the sketch below the table).
+
+ Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
+ --- | :---: | :---: | :---:
+ CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
+ LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
+
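+ A minimal sketch for inspecting one of these files after downloading it (filename from the table above):
+ <pre>import json
+
+ with open('ccs_filtered.json') as f:
+     data = json.load(f)  # a list of {'url': ..., 'caption': ...} dicts
+
+ print(len(data), 'pairs')
+ print(data[0]['url'], '->', data[0]['caption'])</pre>
+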
+ ### Citation
+ If you find this code to be useful for your research, please consider citing:
+ <pre>
+ @inproceedings{li2022blip,
+       title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
+       author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
+       year={2022},
+       booktitle={ICML},
+ }</pre>
+
+ ### Acknowledgement
+ The implementation of BLIP relies on resources from <a href="https://github.com/salesforce/ALBEF">ALBEF</a>, <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for their open-sourcing.
SECURITY.md ADDED
@@ -0,0 +1,7 @@
+ ## Security
+
+ Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
+ as soon as it is discovered. This library limits its runtime dependencies in
+ order to reduce the total cost of ownership as much as possible, but all consumers
+ should remain vigilant and have their security stakeholders review all third-party
+ products (3PP) like this one and their dependencies.
app.py ADDED
@@ -0,0 +1,143 @@
+ from PIL import Image
+ import requests
+ import torch
+ from torchvision import transforms
+ import os
+ from torchvision.transforms.functional import InterpolationMode
+ import matplotlib.pyplot as plt
+ import cohere
+ import gradio as gr
+ import base64
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ from models.blip import blip_decoder
+
+ # preprocessing for the captioning model
+ image_size = 384
+ transform = transforms.Compose([
+     transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+
+ model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+
+ model = blip_decoder(pretrained=model_url, image_size=384, vit='large')
+ model.eval()
+ model = model.to(device)
+
+
+ from models.blip_vqa import blip_vqa
+
+ # preprocessing for the VQA model
+ image_size_vq = 480
+ transform_vq = transforms.Compose([
+     transforms.Resize((image_size_vq, image_size_vq), interpolation=InterpolationMode.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+
+ model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
+
+ model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
+ model_vq.eval()
+ model_vq = model_vq.to(device)
+
+
+ def inference(raw_image, model_n, question="", strategy=""):
+     if model_n == 'Image Captioning':
+         image = transform(raw_image).unsqueeze(0).to(device)
+         with torch.no_grad():
+             if strategy == "Beam search":
+                 caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
+             else:
+                 caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
+         return 'caption: ' + caption[0]
+     else:
+         image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
+         with torch.no_grad():
+             answer = model_vq(image_vq, question, train=False, inference='generate')
+         return 'answer: ' + answer[0]
+
+
+ # get the caption for a single image
+ def get_caption(image_path):
+     img = Image.open(image_path)
+     return inference(img, "Image Captioning")[9:]  # strip the leading 'caption: '
+
+
+ def display(image_path):
+     img = Image.open(image_path)
+     plt.imshow(img)
+     print("Caption: " + get_caption(image_path))
+
+
+ # returns a dictionary with key -> img_path and value -> caption
+ def get_captions(img_directory, print_status=True):
+     captions = {}
+     length = len(os.listdir(img_directory))
+     count = 0
+     for file in os.listdir(img_directory):
+         f = os.path.join(img_directory, file)
+         captions[f] = inference(Image.open(f), "Image Captioning")
+         count += 1
+         if print_status:
+             print("Images complete:", str(count) + "/" + str(length))
+             print("Caption:", captions[f])
+     return captions
+
+
+ # writes the dictionary to a file, key and value separated by ':'
+ def write_to_file(filename, caption_dict):
+     with open(filename, "w") as file:
+         for i in caption_dict:
+             file.write(i + ":" + caption_dict[i] + "\n")
+
+
+ # Text to Image API
+ # TODO: add max tokens as a slider
+
+ def make_image_and_story(prompt):
+     if prompt is None or prompt == "":
+         prompt = 'Random monster'
+
+     host = 'https://dev.paint.cohere.ai/txt2img'
+     response = requests.post(host, json={'prompt': prompt, 'n_samples': 1, 'n_iter': 1})
+
+     # decode the returned image and save it to disk
+     imageBytes = base64.b64decode(response.json()['image'])
+     with open("sample.png", "wb") as f:
+         f.write(imageBytes)
+
+     caption = get_caption("sample.png")
+
+     co = cohere.Client('yRfs5ozta7DQtTF0duztE9bV7CNulvcxwuqJizhB')
+     response = co.generate(prompt=caption, model='c0381280-2035-4042-a5a0-01f5800bd9c0-ft', max_tokens=80)
+
+     return Image.open("sample.png"), response.generations[0].text
+
+
+ gr.Interface(fn=make_image_and_story, inputs="text", outputs=["image", "text"], title='Fantasy Creature Generator').launch()
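+
+ # For reference, a minimal usage sketch for the functions above (not part of the
+ # original upload; 'test.jpg' is a hypothetical local file). Note that any model_n
+ # value other than 'Image Captioning' routes inference() to the VQA branch.
+ #
+ # img = Image.open('test.jpg')
+ # print(inference(img, 'Image Captioning', strategy='Beam search'))
+ # print(inference(img, 'VQA', question='What is in the picture?'))
+ # print(get_caption('test.jpg'))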
cog.yaml ADDED
@@ -0,0 +1,17 @@
+ build:
+   gpu: true
+   cuda: "11.1"
+   python_version: "3.8"
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+   python_packages:
+     - "ipython==7.30.1"
+     - "torchvision==0.11.1"
+     - "torch==1.10.0"
+     - "timm==0.4.12"
+     - "transformers==4.15.0"
+     - "fairscale==0.4.4"
+     - "pycocoevalcap==1.2"
+
+ predict: "predict.py:Predictor"
data.txt ADDED
@@ -0,0 +1,300 @@
+ The Adaro is half human, half fish, much like a merfolk. They are said to live in the Sun and travel to earth by rainbows. Adaros are said to travel around the oceans on waterspouts. Unlike merfolk, Adaros are not friendly; they attack humans by shooting them with flying fish.
+ --SEPARATOR--
+ Brownies are invisible brown elves or household goblins who live in farmhouses and other country buildings within Scotland. While the members of the household are asleep they go about doing labours for the house owners. Brownies are protective creatures and become attached to families; if the family move, the Brownie will move with them. If a Brownie is treated badly by the family or is offered payment, the Brownie vanishes without trace. Only children, because of their innocent nature, can see Brownies, though this does not prevent the Brownies from helping adults.
+ --SEPARATOR--
+ The Calygreyhound is a creature from medieval heraldry; it has the body of a deer, the claws of an eagle on its forelegs and the hooves of an ox on its hind legs. The Calygreyhound is meant to symbolise swiftness.
+ --SEPARATOR--
+ The Chinese Fox looks like a typical fox but has a life span of between 800 and 1000 years. The Chinese fox has special powers: when it strikes its tail on the ground it can ignite fires. It is also said that the Chinese fox can see into the future. This animal is a polymorph (it can change its shape at will). The Chinese fox often appears as an old man, a young girl or an academic. The Chinese fox is a trickster and a sly animal; it is thought to be a bad omen to see such an animal. Chinese foxes can be found around graveyards because the souls of the dead can relocate into the body of the fox.
+ --SEPARATOR--
+ The Crocotta is the offspring of a wolf and a dog and can be found in Ethiopia. The Crocotta was said to be able to break anything with its teeth and can eat anything. It is now thought that this creature is in fact the Hyena, as they are scavenging canine creatures that live in this part of the world.
+ --SEPARATOR--
+ A Dire wolf is a large wolf, usually portrayed as being much more vicious than a normal wolf and also more intelligent. Dire Wolves are usually black, though dark greys are popular as well. There is a second description of a dire wolf that says a dire wolf is a giant wolf that has been brought back from beyond the grave to hunt. Dire wolves tend to hunt in packs like their wolf relatives; a dire wolf can often be found leading packs of normal wolves.
+ --SEPARATOR--
+ Doppelganger means "double walker", which in normal terms is a duplicate of another person or creature. A doppelganger does not necessarily have a blood relationship with its double. A doppelganger may therefore be either an astral projection or a ghost. Examples of doppelgangers can be found in Robert Louis Stevenson's "Strange Case of Dr Jekyll and Mr Hyde".
+ --SEPARATOR--
+ A Gorgon is a terrifying creature; they look similar to a human except that their legs are replaced by a snake's tail, and their hair is made up of snakes. The gaze of a gorgon can turn its victim to stone, but only if the victim looks the gorgon in the eyes.
+
+ The Gorgon comes from Greek mythology; there were three. The first was Medusa, a mortal gorgon, who was slain by Perseus, who killed her by using the reflection that she cast onto his shield. Medusa had two sisters, Stheno and Euryale, who were immortal.
+ --SEPARATOR--
+ Gremlins are spirits of tools and machinery. They are thought to be responsible for mishaps and breakdowns with tools and equipment. Each house has a Gremlin, which entered the house as an occupant of a household appliance. Originally Gremlins were friendly towards mankind and helped engineers and inventors build things, but when these people took all the credit the Gremlins were insulted, and from that point on they have worked against us.
+ --SEPARATOR--
+ The Gulon is a mythical creature from Scandinavia; it has the front half of a lion, the rear half of a hyena and the bushy tail of a fox. The creature is said to be vicious and has the sharp claws of the lion to attack with. It has been used as a symbol of gluttony.
+ --SEPARATOR--
+ The Haunt are dog-like creatures, usually black. Males have two deep-set red eyes and females two dark blue eyes. They live in caves deep in mountains and can live with little food and water. Every full moon the males come down into nearby cities and towns to hunt. They have poisonous saliva and claws that can drip with poison.
+
+ The Haunt
+ These creatures of the darkness come down from the mountain,
+ Hunting for prey,
+ For they need to feed I hear them growling,
+ Scratching at the door,
+ With their powerful legs they chase you till death,
+ And a powerful jaw crushes your bones,
+ Fangs and claws drip with a poison touch,
+ They hide in the darkness,
+ Ebony fur,
+ As black as night,
+ Eyes as red as fire,
+ They lurk the streets,
+ And stalk the lost,
+ But then dawn breaks and the sun rises,
+ They bound to the distance, back to the caves.
+
+ Poem by Ben
+ --SEPARATOR--
+ Hellhounds look like a well-built canine with rusty red fur. Hellhounds have glowing red eyes which are used to terrify prey; their eyes also enable them to see in the dark. Hellhounds stand about four and a half feet in height. Hellhounds have the ability to breathe fire. They are very aggressive beasts and are expert hunters. They have a pack structure like wolves.
+ --SEPARATOR--
+ Howlers are four-legged dog-like creatures. They are furless but have tough skin. They are equipped with long claws, and are covered in hard spikes along their backs. They live in dark places and are often associated with evil. They are intelligent creatures that hunt in packs like wolves. Their tactic for taking down prey is to charge it as a pack, then back away, then charge it again; they do this until the prey has succumbed.
+ --SEPARATOR--
+ Jack Frost is an elfish creature who personifies crisp, cold weather. Jack is said to leave patterns in the autumn leaves and the patterns in the frost that are left on windows. It is thought that Jack Frost comes from Norse mythology as Jokul, who was the cause of icicles.
+ --SEPARATOR--
+ Lamias are a cross between an attractive human and a lion. The body is that of a lion, with a human torso rising from where the lion's head would be; they look not unlike centaurs, but lion instead of horse. Lamias are evil and cruel creatures that take pleasure in causing pain and suffering. Lamias have the magical ability to drain knowledge from those that they attack.
+ --SEPARATOR--
+ A Leprechaun is a small sprite that lives in farmhouses or wine cellars. Like Brownies they aid humans and accomplish small labours for them; they ask humans for supplies and furniture, and in return they give objects that bring luck and fortune. Leprechauns are described as merry little fellows who dress in old-fashioned green clothes with buckled shoes and wear a red cap. They are known as fairy cobblers, as they make shoes for other elves. They never make a pair of shoes, they only make one. Popular belief is that a leprechaun possesses a treasure, which a human can obtain if they succeed in capturing one, which is very difficult. Leprechauns are mainly found in Irish folklore but do appear in other countries.
+ --SEPARATOR--
+ Pixies or Piskies are small people who live on the downs and moors of Cornwall and Devon in the South of England. According to the myths, Pixies were originally druids who resisted Christianity; the more they resisted its influence, the smaller they grew. Another myth tells us they were a race of people who were not good enough to go to heaven but not bad enough to go to Hell, and were doomed to walk the earth forever. Pixies are known to steal horses and make nocturnal trips on them over the moors. Pixies like to trick humans, such as by throwing objects around the house. Pixies are hardworking; they work in the fields the entire night to earn some food.
+ --SEPARATOR--
+ Snotlings are the smallest of the green-skinned races. Orcs & Goblins use them as slaves for simple tasks, as they are not intelligent creatures and are about the same size as a Gnome. They are not a threat on their own to other creatures. They realise this, so they form gangs that can attack a target with numbers.
+ --SEPARATOR--
+ The Vegetable lamb, otherwise known as the Tartary, Barbary Lamb or Barmotez, comes from Hebrew legends. The Vegetable lamb is a lamb-like creature that is grown from a tree; it remains attached to the tree and eats the foliage in the area around the tree, and once there is no food left it dies of starvation. The Vegetable lamb is a delicacy: the meat from it is meant to taste of fish, and its blood is like honey. Its bones were used in rituals to give humans the power of foresight (predicting the future).
+ --SEPARATOR--
+ The Ahuizotl originates from Central America; it is a creature that is half human, half monkey, with a hand at the end of its tail. The Ahuizotl was actually a creature that lived in the water; it snatched people who approached close to the water's edge, or sometimes attacked fishermen on their boats. The Ahuizotl was a much-feared creature due to its fondness for eating human flesh.
+ --SEPARATOR--
+ The Basilisk is the mythical king of the serpents. The basilisk is born from a spherical yolkless egg, which was laid by a seven-year-old rooster and hatched by a toad. According to the legends there are two species of Basilisk. The first is a creature that burns everything it approaches; the second kind can kill every living thing with a mere glance. Both species are so evil that their breath wilts vegetation and can crack stones.
+
+ A Basilisk is a highly poisonous creature. It was thought the only way to kill a Basilisk is by holding a mirror in front of its eyes; the moment the creature sees itself in the mirror it dies of sheer fright. The basilisk had natural enemies: the weasel was immune to the Basilisk's glance, and if the Basilisk should hear a cock crow it would be killed instantly.
+ --SEPARATOR--
+ The Bayard is a talking horse; it was given to the four sons of Aymon by Charlemagne. The horse had the ability to elongate its back to accommodate all four of the sons; the Bayard also had great speed.
+ --SEPARATOR--
+ The Bicorne is a mythical creature, part panther part cow, that is fat from overeating. The Bicorne feeds on virtuous husbands. From the creature's description, the implication is that there are lots of virtuous men, because this creature is well fed. The Bicorne's counterpart the Chichevache is thin; it feeds on virtuous wives, suggesting that the women were not very faithful to their husbands.
+ --SEPARATOR--
+ The Bunyip is a mythical beast from Australia. It originates from the Aboriginal culture. It is a vicious amphibious creature, said to have the appearance of a large seal. It is greatly feared as it preys on humans, particularly the more tender flesh of women and children.
+ --SEPARATOR--
+ The Camelopard is the spotted offspring of a camel mother and a leopard father. It is now thought that this creature, the Camelopard, is in fact the Giraffe.
+ --SEPARATOR--
+ The Catoblepas or Catoblepe originates from Ethiopia; it is said to live near the spring that is the source of the River Nile. Its name comes from Greek and translates to "that which looks downwards". The Catoblepas is a four-legged beast with a bull-like appearance; it has a long mane that falls across its head, and the beast's body is covered in hard scales. The Catoblepas seems a lethargic creature, standing around grazing, but any living creature that meets its eyes falls dead on the spot. It is said to have a poisonous breath that it can expel, which causes loss of sight and voice and leads to convulsions and eventually death.
+ --SEPARATOR--
+ The Centicore is a four-legged beast with two long straight horns that are extremely sharp. It uses these horns much like spears. The Centicore is horse-like in appearance but with the chest of a lion and ears that grow in its mouth. The Centicore's horns can move; it usually has only one facing forward while the other is laid across its back.
+ --SEPARATOR--
+ The Chichevache is a cow with a human face; this mythical beast is thin, and it feeds on obedient and faithful wives. The Chichevache being thin suggests that the women were not very faithful to their husbands. Its counterpart the Bicorne is fat, implying there are lots of virtuous men, because that creature is well fed.
+ --SEPARATOR--
+ A Girallon is a close relation of the gorilla; they are large creatures that are highly aggressive and territorial. The main difference between a Girallon and a gorilla is that the Girallon has four arms to the gorilla's two. Girallons are white or pale grey in colour, and about 8ft in height. Girallons attack anything that enters their territory; they tend to live in small packs, with a dominant male being the pack's leader.
+ --SEPARATOR--
+ The Humbata comes from the Sumerian Epic, where Gilgamesh encounters this creature. The Humbata has the head and horns of a bull, the paws and body of a lion, and the talons of a vulture.
+ --SEPARATOR--
+ The Icthyocentaur has the torso, arms and head of a human, the tail of a dolphin and the forelegs of a horse or lion. This is the aquatic version of a centaur.
+ --SEPARATOR--
+ The Karakadon is a bull-like creature that has a single curved horn growing from its forehead, much like that of a unicorn. The creature strikes terror into other creatures with its breath and thunderous bellow.
+ --SEPARATOR--
+ Minotaurs are powerful beasts that are half bull, half human. They are savage creatures which can use simple weapons such as clubs and large axes. They stand around 8 foot in height; their fur is often dark brown in colour but can be black or lighter brown. They are often said to have loud bellows. The origins of the Minotaur lie within Greek mythology, where a Minotaur was kept in a labyrinth at Knossos until Theseus killed it. They have been featured in many fantasy novels, such as the Chronicles of Narnia, and also appear in fantasy role-play games.
+ --SEPARATOR--
+ Ogres stand twice as tall as a human and are strongly built with large muscles. Ogres are not highly intelligent creatures, but they are not stupid and nor are they evil creatures. They tend to rely on brawn rather than brains. Ogres are good fighters, and other races try to recruit them as mercenaries for armies. Ogres don't really care whom they fight for as long as the money's good and they get a good fight.
+ --SEPARATOR--
+ The Pegasus comes from Greek mythology; it is a winged horse. The Pegasus classically has a white hide. Poseidon with Medusa brought Pegasus forth: when Perseus cut off Medusa's head, the Pegasus flew out. Pegasi are known to be stubborn but intelligent creatures.
+ --SEPARATOR--
+ The Sasquatch, better known as Bigfoot, comes from North America. The name comes from a Native American word that means "Hairy Man". The Sasquatch is a large ape-like creature that is thought to measure up to about 7 foot in height. It has long ape-like arms with a flattened nose and is covered in a thick matt of hair. It is thought that these creatures live in caves in the vast forests of North America. The first reported sighting of the beast was in 1811; since then there have been many hundreds of sightings. The Native Americans had seen this beast before that date. Many expeditions have set out to search for the Sasquatch, but none has ever found evidence for its existence.
+ --SEPARATOR--
+ Trolls are monsters that come from Scandinavian myths and Nordic fantasy. Trolls are called Trows in myths from the Shetlands. Trolls are not especially intelligent and are often associated with an element. The most common description of Trolls is a tough, ugly creature with hard rock-like skin. Trolls have the power to regenerate even if hacked apart. The only way to stop a Troll regenerating is burning the monster. A Troll's diet is unusual, as they will eat anything, including metal, bone, wood and rocks. The stomachs of trolls contain very powerful digestive acids. This has led Trolls to an unpleasant form of attack: vomiting over their target, which is extremely painful. Sometimes Trolls use basic hand weapons such as clubs or large stones.
+ --SEPARATOR--
+ A Unicorn is not unlike a war-horse; it has heavy cloven hooves, and the main difference in appearance is that it has a single central horn on its forehead. Unicorns are intelligent creatures. Unicorns like to use their horn as a lance when charging the enemy. Unicorns are magical creatures, which means they can dispel magic that is targeted at them. The horn of a Unicorn is said to have great healing powers. They are often seen to be pure white or grey in colour, though they have been known to show the same variation as normal horses. In some mythology unicorns can only be tamed by females that are pure of heart.
+ --SEPARATOR--
+ The Yale is a four-legged beast from Ethiopia and India; its colour is a tawny brown or black. It is about the size of a horse; it looks much like a deer but has the lower jaws of a boar with its tusks. The Yale has movable horns that it can control.
+
+ The Yale can move a single horn forward to use in a lance-like fashion; the other horn moves out of the way to protect it. If one horn is damaged in a fight, it moves the other horn into place to resume the attack.
+
+ The Yale is seen in British heraldry; it is one of the Queen's beasts and is featured on the arms of Christ's College, Cambridge University.
+ --SEPARATOR--
+ Anima is a term from Jungian psychology referring to the part of a male's collective unconscious which represents the feminine aspect. Incarnations of this include the male's own mother. An example of this is seen in Final Fantasy X, where an Aeon (summoned creature) called Anima represents the spirit of Seymour's mother, who had become a Fayth. The counterpart representing a female's masculine aspect is called the Animus.
+ --SEPARATOR--
+ The Chimera originates from Greek mythology, and is characterised as a creature with the head of a lion, the body of a she-goat and the tail of a dragon. The Chimera can also be depicted, although not in the Greek mythology, as having three heads: that of a lion, a ram and a dragon. The Chimera is occasionally depicted as a creature with the wings of a great eagle. In Greek mythology the Chimera is the child of Typhon and Echidna.
+ --SEPARATOR--
+ The Cockatrice is a snake-like creature with a pair of great wings, seen either as coming from a great eagle or as leathery wings like a dragon's. Characteristic of a Cockatrice are its glowing red eyes with black pupils. The Cockatrice has a magical gaze with which it can petrify an attacker to stone.
+ --SEPARATOR--
+ The Cyclops comes from Greek mythology; they were originally the storm gods Brontes (Thunder), Sterops (Lightning), and Arges (Thunderbolt). They are noted for having only one eye in the centre of their foreheads and for being giants many times the height of man. They were great smiths and helped the Greeks. It was said that they fashioned Zeus's lightning bolt. Eventually the Cyclops came to be seen as cannibalistic and brutish, and they were feared and shunned.
+ --SEPARATOR--
+ Giants are oversized humanoids that are featured in much folklore, and appear in such mythologies as the Norse and Greek. Giants are often seen as earth's elder race. A Giant's characteristics traditionally make them brutish and hostile. They are seen to feed on cattle, sheep and humans, or anything living that is smaller in size. Giants are normally seen to be solitary, although the Greek titans formed communities. A well-known giant is Blunderbore, who lived in a huge castle in a cloud; he was defeated by a boy called Jack (from the fairy tale Jack & the Beanstalk).
+ --SEPARATOR--
+ Stories of giant ants, ranging in size from that of a fox to that of an elephant, were reported to have come from India. The Giant Ants guard the earth and the gold they dig out of it. Giant ants can detect the approach of inbound creatures, and they swarm out of their burrows to catch and devour any intruders. It is thought today, though, that these beasts were nothing but Marmots, whose name translates from Persian to "mountain ant".
+ --SEPARATOR--
+ Giant crabs are seen as large monsters in far-off countries; they are in fact oversized crabs. The origins of these giant crabs may not be so mythological or made up, as crabs have a rare ability: they don't really die of old age, and if protected and in the right environment they can keep growing. Large crabs have been found that are over a metre in width (which in itself is huge compared to the ones you see on the beach, about 2-8cm in width). In times when we humans didn't pollute the sea as much, these crabs had all the time to grow. You never know.
+ --SEPARATOR--
+ Giant scorpions are oversized scorpions in varying proportions. The largest scorpions today are about 25cm in length, which is big for an arachnid. These oversized creatures are terrifying, with their two big pincers, a deadly stinging tail and their ability to scurry up walls. These creatures could never have existed, because insects and arachnids are limited in size by their biology and how they respire.
+ --SEPARATOR--
+ Giant spiders are an oversized variant of the common spiders that we all know about. Spiders have always had that terror value, as people have seen them to be creepy, and they produce large webs to catch their prey. Anything that gets caught in the sticky web, the spider comes and grabs and drains of its fluid. Spider characteristics are that they have eight legs and can walk on most surfaces, including upside down on the ceiling. They produce webs and sometimes have a poisonous bite. Oh, and spiders can jump.
+ --SEPARATOR--
+ The Griffin or Griffon is a legendary creature whose head, beak and wings come from a great eagle, whose body is that of a lion or a tiger, and which on occasion has the tail of a scorpion; the front feet often end in talons. A Griffin is a fierce and deadly creature, tearing its prey apart with its huge beak while pinning its prey to the ground with one of its talons or paws. The origin of the Griffin is unknown, but it is thought to be somewhere in the Middle East, as this is where early paintings and sculptures were made by the ancient Assyrians, Babylonians and Persians. In more recent times the Griffin is often seen on churches as gargoyles and on shield heraldry. The Griffon is the symbol shown on Vauxhall cars in Great Britain.
+ --SEPARATOR--
+ The Hippocerf is a mythical creature which is half horse and half deer. This creature represents indecision.
+ --SEPARATOR--
+ The Hippogrif is a legendary animal that is half Griffin and half horse. The forequarters and head are those of the griffin, which would be the father's half, and the hindquarters are those of its mother, a filly. The Hippogrif is a ferocious creature. The Hippogrif is occasionally found in ancient Greek paintings, but it was more widely seen in medieval times.
+ --SEPARATOR--
+ A Lammasu has the body of a lion, the wings of a giant eagle and the face of a human. Lammasus are said to be noble creatures that look after those that are good. They prey on those creatures that are evil. They are strong creatures and can easily take down larger creatures with their lion's claws.
+ --SEPARATOR--
+ The Manticore is a vicious creature which comes from Asia. The Manticore is a lion-like creature whose head has some human likeness; it has a tail which is sometimes depicted as the tail of a scorpion, or a tail that fires poisonous darts. The Manticore is sometimes depicted as having leathery wings. In Asia the creature stalks through the forests in search of humans; it attacks first by using poison to render the victim immobile. After being immobilised, the victim can then be devoured: bones, clothes, possessions, the lot, so the Manticore leaves no trace. There are other descriptions of Manticores that don't fit the above pattern; Manticores can move around in prides, and they kill for the sake of killing and leave their prey behind to be found.
+ --SEPARATOR--
+ The Sphinx is seen in two cultures. The Egyptians used the Sphinx as a statue, which had the body of a lion and the head of a human, and sometimes the sphinx had wings. Sphinxes were built to honour their kings, of whom the head was normally a portrait. The Greeks who visited Egypt much later gave the name Sphinx to these structures. The Greek sphinx was a creature of death, destruction and bad luck. The Sphinx was portrayed as a female creature: a winged lion with a feminine head, and sometimes a snake tail. The Sphinx was the offspring of Typhon and Echidna. The Sphinx is sometimes seen as the guardian of temple entrances.
+ --SEPARATOR--
+ A Taurus is a ferocious creature that is made up of a bull that has wings; they tend to be brightly coloured. A Taurus appears in the zodiac star signs, though there it does not feature wings.
+ --SEPARATOR--
+ Titans are huge human-like creatures that are about twenty-five foot in height. They are powerful creatures that are intelligent. They dress in flowing clothing. Titans have their own cultures and gods. Titans originate from Greek mythology.
+ --SEPARATOR--
138
+ Black Orcs are physically biggest and strongest of the related green skins species their skin is much darker green in skin tones than that of Orcs, Goblins and Hobgoblins. Black Orcs are more intelligent as well are often seen to be leading the Orc tribes, they tend to be better armed than other Orcs so they wear heavier armour and carry around larger weapons.
139
+
140
+ Black Orcs tend to be much better organised and infighting within a tribe does not succumb to petty squabbles. Other Orcs tend not to squabble when black Orcs are around either because they don't want to be picked on by their bigger cousins.
141
+ --SEPARATOR--
142
+ Centaurs come from Greek mythology and were said to have inhabited the region of Magnesia and Mount Pelion in Thessaly, the Foloi oak forest in Elis, and the Malean peninsula in southern Laconia. These creatures are part human part horse. They have the torso and head of a human with the body of a horse. Centaurs followed Dionysus the wine god that is why centaurs are known for drunkenness and carrying off young maidens. Though female centaurs, called centaurides or centauresses, are not mentioned in early Greek literature and art, they do appear occasionally in later antiquity.
143
+ --SEPARATOR--
144
+ The Dark Elves or Drows are the same as there high elf cousins a tall noble looking race with pale skin and are slender but unlike High Elves which are masters of light magic's, dark elves practice dark sorcery, and have become very adept at it. Dark Elves live in great cities made of dark stone. Dark Elves are constantly at war with their High Elf kin. They use masterly crafted weapons but they carry repeating crossbows, which seem to be unique to them. Dark Elf females that follow the path of sorcery are called Witch Elves and are a warrior sisterhood. Dark Elves always seem to where dark colours such as blacks, deep purples and dark blues.
145
+ --SEPARATOR--
146
+ Dryads come from Greek Mythology and are female spirits of nature, which preside over the forests. A dryad is born with a certain tree species and a particular tree, which she watches over. If the tree is destroyed then the Dryad perishes alongside the tree. Dryads punish mortals that somehow damage the trees.
147
+ --SEPARATOR--
148
+ Dwarves are a short, Muscular race, that the males and females look alike as both have beards, and in dwarfs beards mean a lot. Dwarves live deep beneath the mountains and have mined themselves vast strongholds beneath theses mountain peaks. Dwarves are immensely strong and resilient they have broad hands and feet. Dwarves are known to be stubborn and un-forgetting. Dwarves respect the following things; Age, Wealth and skill.
149
+
150
+ Dwarves favour axes as a weapon but compared with other races dwarfs have embraced technology and use a variety of weapons that are more advanced than others of the ancient races. Dwarves will use pistols, cannons, flame cannons. Dwarfs like the Elves take pride in their work and their axes show immense craftsmanship. Dwarves have a hate of Goblins & Orcs who have raided there strong holds many a time and took away their wealth. Dwarves also distrust Elves due to an ancient war that was fought between them.
151
+ --SEPARATOR--
152
+ Fairies come in all sorts of guises some look like Goblins, a typical fairy is a tiny human like creature that is 12 inches high, and has delicate wings on its back. These creatures tend to be friendly with humans but humans rarely ever see them. Fairies tend to have specific purposes like the tooth fairy that goes round collecting teeth that are placed under pillows and swaps them for some money. Sometimes fairies are guides to lost people.
153
+ --SEPARATOR--
154
+ The Genie comes from Arabian folklore and is a supernatural fiery creature. They can be both good and evil creatures; evil ones are said lead humans astray. In popular western culture Genies are often seen as been concealed with old lamps, which when rubbed genie appears out of them, the reason given is that they have been trapped inside the lamp by an evil sorcerer. This description comes from the western translation of "The Book of One Thousand and One Nights". Genies often grant wishes as well to the person that frees them from the lamp, commonly this is three wishes that the Genie can grant.
155
+ --SEPARATOR--
156
+ Gnomes are a very small humanoid race, which are not physically different to humans other than they are only about 1 to 2 feet high. Gnomes are classically seen as protectors of the natural world, often they are portrayed as little men with beards who tend to the garden. Gnomes are also seen as adventures wanting to discover the world they often join parties of adventuring heroes. Gnomes can prove useful due to their height and sometimes turn to a life of crime.
157
+ --SEPARATOR--
158
+ Goblins vary in size but are smaller than Orcs or Hobgoblins. They are a green skinned race, which have pointy ears. Goblins have sharp pointy teeth. They look thin and scrawny. Goblins sometimes are seen as more intelligent than their cousin Orcs, but are often bullied by their bigger and tougher relations. Goblins form tribes, in which the biggest and hardest goblin is the leader. Goblins are often seen in company of Orc tribes. Goblins use a varying array of weapons well anything they can lay their hands on. Goblins are seen to ride wolves into battle, in which case these tend to be the vanguard of an Orc & Goblin army.
159
+ --SEPARATOR--
160
+ Elves are a tall, slim race of regal build. They are seen as a noble race and tend to be beautiful or handsome in appearance, with pale skin. Elves are strong and agile in comparison with humans and are often seen as more intelligent and wiser; they also have a longer life span than humans do. Elves build refined weapons that are regarded as master craftsmanship by other races. They mainly use swords, bows and lances, and they dislike crossbows and gunpowder weapons.
161
+ --SEPARATOR--
162
+ Hobgoblins are close relations of Orcs and Goblins. They are distinctly taller than Goblins but not as muscular as Orcs. Hobgoblins are renowned for being cowardly and sneaky. They have green skin like their cousins, large canine teeth, hooked noses and pointy ears. Hobgoblins use a varying array of weapons and, like goblins, they ride large wolves into battle.
163
+ --SEPARATOR--
164
+ Merfolk are a mythical race which has grown out of sailors' tales told across the world. Merfolk are basically human but have a fish's tail, and they are able to breathe underwater. Merfolk are said to be beautiful, and many Mermaids (female Merfolk) have fallen in love with sailors they have rescued. Merfolk tend to carry tridents (three-pronged spears) as weapons. They are said to be able to command the creatures of the sea, and to ride upon seahorses.
165
+ --SEPARATOR--
166
+ Nymphs are creatures that inhabit the most secluded and peaceful areas of the world, close to a pure water source such as a spring. Nymphs are a peaceful race; they hate evil and ugliness, try to avoid conflict, and flee if challenged. When the situation is desperate, a nymph will defend herself by using her magical powers to blind or confuse her attacker. Nymphs look much like humans apart from their pointy, elf-like ears; they are nature's embodiment of physical beauty. Nymphs are ever young, charming, graceful and intelligent.
167
+ --SEPARATOR--
168
+ Orcs are green-skinned creatures with large, sharp canine teeth and pointy ears. They are taller than humans, and also broader and more muscular. Orcs have larger heads than humans, with thicker skulls and smaller brains. Another biological feature of Orcs is their ability to survive large wounds that would almost certainly kill humans or other races; they seem to possess a remarkable regenerative and immune system. Orcs are an aggressive race that forms tribes or war bands, and sometimes these war bands rally together to form mighty armies. Orcs are savage fighters whose preferred tactic is to attack in large hordes. Orcs fight with varying weapons depending on the intelligence within a particular tribe.
169
+
170
+ Savage Orcs, a tribe that inhabits forests, are less developed; they tend to attack with clubs and simple weapons and do not use armour. A more advanced tribe would be equipped with forged weapons and armour. Orcs use an unusual mount to get around quickly: giant, ferocious boars, which make good attacking cavalry.
171
+ --SEPARATOR--
172
+ The Ratmen are a mutated race that originated from rats which fed upon some source of mighty magical power. They are approximately the same height as goblins. Ratmen are sneaky and rely on stealth, attacking from underground tunnels. Ratmen are quick, and they can carry all sorts of diseases to which the ratmen themselves seem to be immune.
173
+ --SEPARATOR--
174
+ Satyrs, or Fauns, are intelligent creatures that are said to be found in the wild places of the world. They indulge in food, drink and romance. A satyr is a horned man with the legs and feet of a goat. The hair that grows on these creatures is usually chestnut brown, and their hooves tend to be jet black in colour. Satyrs are mischievous by nature and like to play tricks on others. They have a natural talent for music, and they often carry pan pipes on which to play their magical tunes.
175
+ --SEPARATOR--
176
+ Troglodytes are humanoid reptilian creatures of basic intelligence. They stand about 5 ft in height, have a basic humanoid shape, and have three fingers and a thumb. Their skin is tough and leathery, they have a lizard-like tail, and they have talons on their feet. Their skin pigment can change colour to match their surroundings, much like that of a chameleon.
177
+
178
+ Troglodytes are a feudal race of warring tribes. They have very basic technology; they construct crude weapons such as flint axes and javelins. They are ferocious creatures that hunt down other creatures for meat, and they were said to raid human settlements to capture the young for food.
179
+ --SEPARATOR--
180
+ Wood Elves are similar to High Elves and Drow: a slender, noble race with pale skin and often pointed ears. Wood elves live deep within the forests and are guardians of the natural world. They work with the forest and try not to damage it as other races do. Wood elves are extremely agile; their main method of attack is the bow, but they will fight hand-to-hand. Wood elves like to use stealth, as there are fewer of them compared with the High Elves and Dark Elves. Wood Elves prefer not to wear armour, as this allows them to move more rapidly through the forests, and they tend to wear fabrics in greens and browns to blend in with their surroundings, which helps them move around unnoticed.
181
+ --SEPARATOR--
182
+ The Alan are half-human, half-bird creatures that originate from the Philippines. They live within the forests and spend much of their time hanging upside down from trees; this is because they have fingers on their feet and toes on the tips of their hands. They are mischievous in nature but friendly to humans, and some legends tell of them bringing up children who have lost their parents.
183
+ --SEPARATOR--
184
+ The Anka is a gigantic Arabian bird that can live to over 1,700 years of age. The wingspan of the Anka is said to be the breadth of five elephants.
185
+ --SEPARATOR--
186
+ The Bar Juchne is another enormous bird with a huge wingspan, similar to the Roc. It is mentioned in Hebrew texts; one written account tells of an egg that fell from its mountain nest, felling 300 trees and causing flooding in 600 villages.
187
+ --SEPARATOR--
188
+ The Benu comes from ancient Egyptian mythology. It is a bird that was born at creation and was worshipped at the city of Heliopolis. It is a heron-like bird with a crest of two long feathers that come from the rear of its head. The Benu is reborn each day after its journey through the night. The Benu is associated with the gods of the sun; it accompanies the souls of the dead on the boat of Ra on its journey through the underworld.
189
+
190
+ The Benu was the first creature to come from the sea at creation, and on the first piece of land it stood upon a temple was built; this is said to be Heliopolis.
191
+
192
+ It is thought that the Benu bird is much like the Phoenix. The Greek historian Herodotus visited Egypt and described the bird that the priests showed him as the western Phoenix.
193
+ --SEPARATOR--
194
+ The Feng-Huang is the Chinese equivalent of the Phoenix, and it lives with the sun. The bird has the head and comb of a pheasant and the feathers of a peacock, and it has three legs rather than two. Its plumage blends the five colours, and its song is a harmony of five notes. The male of this bird is the Feng; the female is the Huang.
195
+ --SEPARATOR--
196
+ The firebird (Zshar-ptitsa) comes from Russian folklore. It is a miraculous bird: its feathers shine silver and gold, and its eyes sparkle like crystals. The firebird is nocturnal, and at night it illuminates the land that it flies over; it is said that a single feather from its tail can light up a dark room. The firebird eats golden apples, and when it sings, pearls are said to fall from its beak. The firebird was also able to heal the sick and cure the blind with its chants.
197
+ --SEPARATOR--
198
+ The Garuda, or Garida, is an ancient bird similar to the Roc; it could block out the sun with its wings and could pick up elephants in its talons. The Garuda has the beak and wings of a bird of prey; its head was white in colour, its wings scarlet, and its body a golden yellow. It was sometimes called the bird of life. The Garuda was the feared enemy of the Nagas.
199
+ --SEPARATOR--
200
+ Great eagles are fantasy creatures normally based on the golden eagle, but far greater in size, with wingspans of six or seven metres across. These birds are often seen as allies of good and are associated with the High Elves and Wood Elves. They act as carriers of information, and they are sometimes portrayed as sentient creatures with high intelligence.
201
+ --SEPARATOR--
202
+ Harpies have two descriptions, both of which originate in Greek mythology. The first description is that Harpies are beautiful winged maidens; the second is that they are winged monsters with the head of an ugly old crone and sharp talons. In both descriptions they were seen to do the same thing: carry people off to the underworld, inflicting punishment upon them and ultimately death.
203
+ --SEPARATOR--
204
+ The Hsigo come from Chinese mythology and are in fact monkeys with bird-like wings. They are much like the flying monkeys seen in the film The Wizard of Oz.
205
+ --SEPARATOR--
206
+ The Peryton lived in Atlantis; it is a large flying creature with the antlered head and legs of a deer and the wings and body of a bird. The Peryton casts the shadow of a man until it kills one. When an earthquake destroyed Atlantis, the Perytons took to the air and were seen flying above the Pillars of Hercules.
207
+ --SEPARATOR--
208
+ The Phoenix is a mythical bird that is associated with the Egyptian sun god Ra and the Greek sun god Apollo. The Phoenix lives in Arabia, near cool oases; each morning it bathes in the pool and sings beautiful birdsong. There can only be one Phoenix alive at any one time. When the Phoenix feels it is old and it is time to die, it builds a nest, settles on it, and sets it on fire; the Phoenix burns to death in the flames, and from these flames a new Phoenix emerges. The Phoenix symbolises immortality, resurrection and life after death, and because of this it is often seen on sarcophagi and tombs.
209
+ --SEPARATOR--
210
+ The Roc is a legendary, gigantic bird from Arabian legends. These birds were so big that they could carry off elephants for food. The Roc features in various stories of the "Thousand and One Nights", and it also appears in the historical accounts of Marco Polo's travels.
211
+ --SEPARATOR--
212
+ The Simurgh lives on the tip of Mount Alburz in Persia. The Simurgh is an all-knowing bird with the powers of reason and speech; it is said to have had a discussion with King Solomon about philosophy. The Simurgh is a very large bird whose feathers are said to have curative powers. The Simurgh was said to have nested in the tree of knowledge.
213
+ --SEPARATOR--
214
+ Stymphalids are similar in appearance and size to cranes. They are man-eating birds whose beaks and claws are made of brass and can pierce armour. Their feathers are also made of brass, and they can fire them at men like arrows.
215
+
216
+ They are said to have come from the marshes of Arcadia. As one of his twelve labours, Hercules had to rid the marshes of Stymphalids; he did this by using brass rattles that scared the Stymphalids into the air, where Hercules shot them down with arrows. Some of the Stymphalids escaped to the island of the war god Ares, where Jason and his crew encountered them on their way to collect the Golden Fleece. Jason's crew scared them off by banging their swords against their brass shields.
217
+ --SEPARATOR--
218
+ The Thunderbird lives among the clouds in the sky. The Thunderbird is a bird of monstrous size, so big that it can take whales in its talons. When the Thunderbird beats its wings, they crackle with lightning and rumble with thunder, and arrows of lightning flash from its eyes. The Thunderbird comes from Native American mythology.
219
+ --SEPARATOR--
220
+ The Ziz is a mythical bird mentioned in the book of Psalms. It is of enormous size, much like the Roc, and its wings are so vast that they can block out the sunlight. The Ziz is the protector of small birds, and it is said that without the Ziz the small birds would have died out many years ago. The Ziz is said to be one of the meats served at humanity's last meal.
221
+ --SEPARATOR--
222
+ An Assassin Vine is a semi-mobile plant that collects its own nutrients by capturing other creatures and depositing their remains near its roots. The main vine of a fully-grown plant is about 20 ft in length and 5 inches in thickness, with smaller vines shooting off from the main stem. The assassin vine's fruit resembles wild grapes, but has tougher skin and is poisonous to humans and larger mammals.
223
+
224
+ The Assassin Vine entangles any prey that happens to wander into the plant, automatically reacting to constrict around the ensnared animal. The vine will constrict its victim to death.
225
+ --SEPARATOR--
226
+ Treemen are ancient fantasy creatures that protect the forests of old. Treemen are gigantic, tree-like creatures with trunk-like legs and great branches for arms, and they have tough, woody skin. Treemen fear fire, as it can do much damage to them and their forest homes. Treemen tend to hate those who go around chopping down the biggest and oldest of trees.
227
+ --SEPARATOR--
228
+ Black dragons can be found in marshes and underground cave networks. Black dragons are cunning and evil-tempered, and they usually side with evil. Black dragons are sometimes known as skull dragons due to their deeply set eye sockets. These dragons have black or very dark grey scales that are glossy when young and become duller with age. Black dragons smell of rotting vegetation and stagnant water.
229
+
230
+ A black dragon's preferred diet is fish and shellfish; they do eat other animals, but they like to let these, once killed, rest in stagnant ponds before eating them some days later. Black dragons like to collect coins as their treasure. A black dragon's preferred method of attack is to ambush its target, making use of its surroundings and vegetation as cover. Black dragons have two different breath attacks: the first is an acid breath, used to dissolve and blind an attacker or prey; the second is an oily black smoke that can be used to choke an attacker or prey.
231
+
232
+ Dragons are magical creatures with the ability to cast spells; black dragons are able to perform the following spells.
233
+
234
+ Darkness - this causes the local area to become dark.
235
+ Insect Plague - causes a plague of insects to attack an area.
236
+ Corrupt Water - this spell turns water stagnant.
237
+ Charm Reptiles - this allows the black dragon to control weaker reptiles to do its bidding.
238
+ --SEPARATOR--
239
+ Blue dragons live in temperate and warm desert environments, and they can also be found underground. Blue dragons are very territorial. They are well adapted for digging into sand and soft soils. Blue dragons have frilled ears and a single large horn on the snout.
240
+
241
+ Blue dragons vary in colour from dark blue to light blue, and their scales are polished. A blue dragon's scales crackle as they build up static energy. Blue dragons like to soar high above the deserts, and they like to collect gems as treasure; they are particularly attracted to sapphires.
242
+
243
+ They eat red meat, which they usually cook before eating. Their methods of attack are to strike from the sky, diving on a target, or to burrow into the sand and wait for prey to come to them, attacking quickly and catching the target unawares. The blue dragon's breath attack is a lightning breath.
244
+
245
+ As dragons are magical creatures, the blue dragon is able to cast the following spells.
246
+
247
+ Create/Destroy Water
248
+ Sound Imitation - blue dragons are able to cast this to imitate other creatures' sounds or speech.
249
+ Illusory Terrain - the dragon is able to create images of terrain that appears to be there but is not.
250
+ --SEPARATOR--
251
+ Bronze dragons live in temperate and warm aquatic climates, as well as underground. Bronze dragons are inquisitive creatures; they like to observe other creatures going about their daily routines and business. Bronze dragons do this by using a skill called polymorphing, which allows them to assume the appearance of other living creatures.
252
+
253
+ A young bronze dragon's scales and skin tones are yellow tinged with green; as the dragon grows older, its colours deepen and develop into a bronze tone. Bronze dragons are well adapted to swimming and can breathe underwater unhindered. This swimming adaptation fits their diet, as bronze dragons mainly eat marine or freshwater creatures and aquatic plants.
254
+
255
+ Bronze dragons dislike killing attackers and would rather bribe them or force them away magically. Bronze dragons are armed with two types of breath attack: the first is a lightning breath with which they can electrocute an attacker; the second is a gas breath that repulses the attacker. Bronze dragons are magical creatures and are able to cast the following spells.
256
+
257
+ Fog Cloud - the dragon is able to bring down a cloud of fog to confuse and disorientate its attackers.
258
+ Control Weather - the dragon is able to control the weather in the local area.
259
+
260
+ --SEPARATOR--
261
+ Gold dragons will live in any climate, and they are associated with the side of good. Gold dragons are seen as graceful and wise creatures; they dislike injustice and foul play, and often take it upon themselves to put such things right.
262
+
263
+ Young gold dragons start out with dark yellow scales and skin tones; as the dragon ages, the scales become more golden. Gold dragons are often accompanied by loyal guards, which can be made up of many types of creature, but are often giants.
264
+
265
+ Gold dragons prefer to use spells in combat rather than physical fighting. Gold dragons have two forms of breath attack: the first is a fire breath that burns attackers or prey; the second is a gas that weakens opponents. Gold dragons are highly magical creatures and are able to perform the following spells.
266
+
267
+ Cloud Kill
268
+ Fireball - a ball of magical fire
269
+ Fire Shield - a wall of fire used to protect the dragon
270
+ Shield - an invisible wall of energy to protect the dragon
271
+ Sleep - the dragon is able to put its enemies to sleep
272
+ Stinking Cloud - a cloud of foul gas
273
+ Slow - reduces the speed of the dragon's attackers
274
+ --SEPARATOR--
275
+ Chinese dragons are different from those that originate in western culture. The dragon in China is a powerful but benevolent ruler of the earth, sky and sea. It is one of the four spiritual animals, and Chinese dragons are said to bring good fortune.
276
+
277
+ Chinese dragons differ in description from classic western dragons: they have a horse-like head with two horns rising from behind the ears, and long whiskers that come away from the muzzle. They have thick, scaly hides like those of western dragons. They have four taloned legs, but Chinese dragons do not have wings, though they are still able to fly. Their bodies are long and snake-like.
278
+
279
+ Chinese dragons play with a ball of light known as the sacred pearl; this is thought to be the source of the dragon's power. There is a place in the sky called the jade tablet, which contains the number of the dragons. Chinese dragons develop over a number of years: they begin as small water snakes and grow throughout their lives. Chinese dragons have the power to polymorph (change their shape).
280
+
281
+ Lung is a named Chinese dragon, and he is the ruler of the sky. It is said that Lung has the head of a camel, the horns of a deer, the ears of an ox, the eyes of a devil, the neck of a snake, the belly of a clam, the scales of a fish, the talons of an eagle and the paws of a tiger. Lung's breath is said to form the clouds.
282
+
283
+ There are other Chinese dragons that rule the weather, rain, thunder, lightning and the wind, and others that rule the oceans and the earth. A dragon named Chiao is the supreme dragon of the earth.
284
+ --SEPARATOR--
285
+ The dragons of Joppa are from around the southern and eastern Mediterranean. Two of these dragons are among the most famous monsters of legend.
286
+
287
+ The first of these beasts was of the sea; the second was a creature of foul, disease-ridden swamps. The first, called Cetus, was a dragon that the Greek god Poseidon called from the sea to ravage the country of King Cepheus. Waters flooded his lands and Cetus ate his people. King Cepheus consulted the great Oracle and was told that the only way to stop the dragon was to sacrifice his daughter Andromeda to Cetus. As the legend goes, the hero Perseus was returning home after slaying the Gorgon Medusa when he saw from afar Andromeda bound to a rock. As the dragon Cetus approached Andromeda, Perseus landed on the dragon's back, drove his sword repeatedly into Cetus and slew the dragon, freeing the kingdom and winning himself a wife.
288
+
289
+ The second dragon at Joppa comes from the accounts of English knights on the crusades, who told of St George's battle with a dragon. St George was returning home when he came across a maiden in distress: the princess of the kingdom, chained to a post in the marshes. When the dragon emerged from the swamp to eat the princess, George attacked, driving his lance through the dragon's open jaws. Through this act George converted the local villages to Christianity, and he became the patron saint of England.
290
+ --SEPARATOR--
291
+ The dragon called the Tararque was from the Rhone Valley in southern France. This dragon does not display standard dragon features: the monster had six legs, the head of a lion, the paws of a bear, and a scaly body with a barbed tail, and its body was covered in sharp spikes. The dragon was said to sink ships that navigated the Rhone and to kill those who travelled too close to the river banks.
292
+ --SEPARATOR--
293
+ In the west the dragon is otherwise known as the Drakon in Greek and the Draco in Roman; the British refer to it as the Drake. Western dragons have the characteristics of the typical fantasy dragon. They have four taloned feet and a pair of wings like those of a bat. Their heads have a crest, and a beard underneath the chin; some have horns or antlers. They have tough scales, and their stomachs are like those of crocodiles. Western dragons come in varying sizes, from as small as a fox to huge creatures many times bigger than elephants. They come in a variety of colours and have some kind of breath attack, most usually fire. Western dragons have barbed tongues.
294
+
295
+ It is said that the blood of these western dragons has powerful healing properties, and that the blood also allows the understanding of other languages. Western dragons have a gem in their head known as the Draconce, or Dragon-stone; it is a brilliant red and is said to have curative powers.
296
+ --SEPARATOR--
297
+ In northern European countries the Worm (Wyrm or Wurm) is found; it is one of the most ferocious dragons. Worms live deep beneath the earth, where they guard gems and treasure such as gold. These dragons are huge and usually wingless, and they are covered in thick, hard scales said to have the strength of steel. They are often fire breathers. Worms take vengeance on those who steal from their hoards.
298
+
299
+ One of the best-known Worms was the nemesis of the warrior king Beowulf. Beowulf slew the dragon, but at the cost of his own life. The legend says that Beowulf was buried with the Worm's treasure.
300
+ --SEPARATOR--
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
eval_nocaps.py ADDED
@@ -0,0 +1,118 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ from torch.utils.data import DataLoader
24
+
25
+ from models.blip import blip_decoder
26
+ import utils
27
+ from data import create_dataset, create_sampler, create_loader
28
+ from data.utils import save_result
29
+
30
+ @torch.no_grad()
31
+ def evaluate(model, data_loader, device, config):
32
+ # evaluate
33
+ model.eval()
34
+
35
+ metric_logger = utils.MetricLogger(delimiter=" ")
36
+ header = 'Evaluation:'
37
+ print_freq = 10
38
+
39
+ result = []
40
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
41
+
42
+ image = image.to(device)
43
+
44
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
45
+ min_length=config['min_length'], repetition_penalty=1.1)
46
+
47
+ for caption, img_id in zip(captions, image_id):
48
+ result.append({"image_id": img_id.item(), "caption": caption})
49
+
50
+ return result
51
+
52
+
53
+ def main(args, config):
54
+ utils.init_distributed_mode(args)
55
+
56
+ device = torch.device(args.device)
57
+
58
+ # fix the seed for reproducibility
59
+ seed = args.seed + utils.get_rank()
60
+ torch.manual_seed(seed)
61
+ np.random.seed(seed)
62
+ random.seed(seed)
63
+ cudnn.benchmark = True
64
+
65
+ #### Dataset ####
66
+ print("Creating captioning dataset")
67
+ val_dataset, test_dataset = create_dataset('nocaps', config)
68
+
69
+ if args.distributed:
70
+ num_tasks = utils.get_world_size()
71
+ global_rank = utils.get_rank()
72
+ samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
73
+ else:
74
+ samplers = [None,None]
75
+
76
+ val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
77
+ batch_size=[config['batch_size']]*2,num_workers=[4,4],
78
+ is_trains=[False, False], collate_fns=[None,None])
79
+
80
+ #### Model ####
81
+ print("Creating model")
82
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
83
+ prompt=config['prompt'])
84
+
85
+ model = model.to(device)
86
+
87
+ model_without_ddp = model
88
+ if args.distributed:
89
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
90
+ model_without_ddp = model.module
91
+
92
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
93
+ val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
94
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
95
+ test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
96
+
97
+
98
+ if __name__ == '__main__':
99
+ parser = argparse.ArgumentParser()
100
+ parser.add_argument('--config', default='./configs/nocaps.yaml')
101
+ parser.add_argument('--output_dir', default='output/NoCaps')
102
+ parser.add_argument('--device', default='cuda')
103
+ parser.add_argument('--seed', default=42, type=int)
104
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
105
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
106
+ parser.add_argument('--distributed', default=True, type=bool)
107
+ args = parser.parse_args()
108
+
109
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
110
+
111
+ args.result_dir = os.path.join(args.output_dir, 'result')
112
+
113
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
114
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
115
+
116
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
117
+
118
+ main(args, config)
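
A minimal launch sketch for this evaluation script (the torch.distributed.run launcher and the 8-GPU count are assumptions; the config and output paths are the script defaults above):

    python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py \
        --config ./configs/nocaps.yaml --output_dir output/NoCaps

With dist_url left at env://, utils.init_distributed_mode can pick up the rank and world size from the environment variables the launcher sets.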
eval_retrieval_video.py ADDED
@@ -0,0 +1,250 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ from torch.utils.data import DataLoader
24
+
25
+ from models.blip_retrieval import blip_retrieval
26
+ import utils
27
+ from data.video_dataset import VideoDataset
28
+
29
+
30
+ @torch.no_grad()
31
+ def evaluation(model, data_loader, tokenizer, device, config):
32
+ # test
33
+ model.eval()
34
+
35
+ metric_logger = utils.MetricLogger(delimiter=" ")
36
+ header = 'Evaluation:'
37
+
38
+ print('Computing features for evaluation...')
39
+ start_time = time.time()
40
+
41
+ texts = data_loader.dataset.text
42
+ num_text = len(texts)
43
+ text_bs = 256
44
+ text_ids = []
45
+ text_embeds = []
46
+ text_atts = []
47
+ for i in range(0, num_text, text_bs):
48
+ text = texts[i: min(num_text, i+text_bs)]
49
+ text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
50
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
51
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
52
+ text_embeds.append(text_embed)
53
+ text_ids.append(text_input.input_ids)
54
+ text_atts.append(text_input.attention_mask)
55
+
56
+ text_embeds = torch.cat(text_embeds,dim=0)
57
+ text_ids = torch.cat(text_ids,dim=0)
58
+ text_atts = torch.cat(text_atts,dim=0)
59
+ text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
60
+
61
+ video_feats = []
62
+ video_embeds = []
63
+ for video, video_id in data_loader:
64
+
65
+ B,N,C,W,H = video.size()
66
+ video = video.view(-1,C,W,H)
67
+ video = video.to(device,non_blocking=True)
68
+ video_feat = model.visual_encoder(video)
69
+ video_embed = model.vision_proj(video_feat[:,0,:])
70
+ video_embed = video_embed.view(B,N,-1).mean(dim=1)
71
+ video_embed = F.normalize(video_embed,dim=-1)
72
+
73
+ video_feat = video_feat.view(B,-1,video_feat.shape[-1])
74
+ video_feats.append(video_feat.cpu())
75
+ video_embeds.append(video_embed)
76
+
77
+ video_feats = torch.cat(video_feats,dim=0)
78
+ video_embeds = torch.cat(video_embeds,dim=0)
79
+
80
+ sims_matrix = video_embeds @ text_embeds.t()
81
+ score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
82
+
83
+ num_tasks = utils.get_world_size()
84
+ rank = utils.get_rank()
85
+ step = sims_matrix.size(0)//num_tasks + 1
86
+ start = rank*step
87
+ end = min(sims_matrix.size(0),start+step)
88
+
89
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
90
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
91
+
92
+ encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
93
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
94
+ output = model.text_encoder(text_ids[topk_idx],
95
+ attention_mask = text_atts[topk_idx],
96
+ encoder_hidden_states = encoder_output,
97
+ encoder_attention_mask = encoder_att,
98
+ return_dict = True,
99
+ )
100
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
101
+ score_matrix_v2t[start+i,topk_idx] = score + topk_sim
102
+
103
+ sims_matrix = sims_matrix.t()
104
+ score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
105
+
106
+ step = sims_matrix.size(0)//num_tasks + 1
107
+ start = rank*step
108
+ end = min(sims_matrix.size(0),start+step)
109
+
110
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
111
+
112
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
113
+ encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
114
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
115
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
116
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
117
+ encoder_hidden_states = encoder_output,
118
+ encoder_attention_mask = encoder_att,
119
+ return_dict = True,
120
+ )
121
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
122
+ score_matrix_t2v[start+i,topk_idx] = score + topk_sim
123
+
124
+ if args.distributed:
125
+ dist.barrier()
126
+ torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
127
+ torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
128
+
129
+ total_time = time.time() - start_time
130
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
131
+ print('Evaluation time {}'.format(total_time_str))
132
+
133
+ return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
134
+
135
+
136
+
137
+ @torch.no_grad()
138
+ def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
139
+
140
+ #Video->Text
141
+ ranks = np.zeros(scores_v2t.shape[0])
142
+ for index,score in enumerate(scores_v2t):
143
+ inds = np.argsort(score)[::-1]
144
+ ranks[index] = np.where(inds == vid2txt[index])[0][0]
145
+
146
+ # Compute metrics
147
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
148
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
149
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
150
+
151
+ #Text->Video
152
+ ranks = np.zeros(scores_t2v.shape[0])
153
+
154
+ for index,score in enumerate(scores_t2v):
155
+ inds = np.argsort(score)[::-1]
156
+ ranks[index] = np.where(inds == txt2vmg[index])[0][0]
157
+
158
+ mdR = np.median(ranks+1)
159
+
160
+ # Compute metrics
161
+ vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
162
+ vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
163
+ vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
164
+
165
+ tr_mean = (tr1 + tr5 + tr10) / 3
166
+ vr_mean = (vr1 + vr5 + vr10) / 3
167
+ r_mean = (tr_mean + vr_mean) / 2
168
+
169
+ eval_result = {'txt_r1': tr1,
170
+ 'txt_r5': tr5,
171
+ 'txt_r10': tr10,
172
+ 'txt_r_mean': tr_mean,
173
+ 'vid_r1': vr1,
174
+ 'vid_r5': vr5,
175
+ 'vid_r10': vr10,
176
+ 'vid_r_mean': vr_mean,
177
+ 'vid_mdR': mdR,
178
+ 'r_mean': r_mean}
179
+ return eval_result
180
+
181
+
182
+
183
+
184
+ def main(args, config):
185
+ utils.init_distributed_mode(args)
186
+
187
+ device = torch.device(args.device)
188
+
189
+ # fix the seed for reproducibility
190
+ seed = args.seed + utils.get_rank()
191
+ torch.manual_seed(seed)
192
+ np.random.seed(seed)
193
+ random.seed(seed)
194
+ cudnn.benchmark = True
195
+
196
+ #### Dataset ####
197
+ print("Creating retrieval dataset")
198
+ test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
199
+ max_img_size=config['image_size'], frm_sampling_strategy='uniform')
200
+
201
+ test_loader = DataLoader(
202
+ test_dataset,
203
+ batch_size=config['batch_size'],
204
+ num_workers=4,
205
+ pin_memory=True,
206
+ drop_last=False,
207
+ shuffle=False,
208
+ )
209
+
210
+ #### Model ####
211
+ print("Creating model")
212
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
213
+
214
+ model = model.to(device)
215
+
216
+ model_without_ddp = model
217
+ if args.distributed:
218
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
219
+ model_without_ddp = model.module
220
+
221
+ score_v2t, score_t2v = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
222
+
223
+ if utils.is_main_process():
224
+
225
+ test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
226
+ print(test_result)
227
+
228
+ log_stats = {**{f'{k}': v for k, v in test_result.items()},}
229
+ with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
230
+ f.write(json.dumps(log_stats) + "\n")
231
+
232
+
233
+ if __name__ == '__main__':
234
+ parser = argparse.ArgumentParser()
235
+ parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
236
+ parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
237
+ parser.add_argument('--device', default='cuda')
238
+ parser.add_argument('--seed', default=42, type=int)
239
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
240
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
241
+ parser.add_argument('--distributed', default=True, type=bool)
242
+ args = parser.parse_args()
243
+
244
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
245
+
246
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
247
+
248
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
249
+
250
+ main(args, config)
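
A comparable launch sketch for the MSRVTT video retrieval evaluation (GPU count and launcher are assumptions; paths are the script defaults above):

    python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py \
        --config ./configs/retrieval_msrvtt.yaml --output_dir output/Retrieval_msrvtt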
imagecaptioning.py ADDED
@@ -0,0 +1,93 @@
1
+ from PIL import Image
2
+ import requests
3
+ import torch
4
+ from torchvision import transforms
5
+ import os
6
+ from torchvision.transforms.functional import InterpolationMode
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.image as mpimg
9
+
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+ from models.blip import blip_decoder
13
+
14
+ image_size = 384
15
+ transform = transforms.Compose([
16
+ transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
19
+ ])
20
+
21
+ model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
22
+
23
+ model = blip_decoder(pretrained=model_url, image_size=384, vit='large')
24
+ model.eval()
25
+ model = model.to(device)
26
+
27
+
28
+ from models.blip_vqa import blip_vqa
29
+
30
+ image_size_vq = 480
31
+ transform_vq = transforms.Compose([
32
+ transforms.Resize((image_size_vq,image_size_vq),interpolation=InterpolationMode.BICUBIC),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
35
+ ])
36
+
37
+ model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
38
+
39
+ model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
40
+ model_vq.eval()
41
+ model_vq = model_vq.to(device)
42
+
43
+
44
+
45
+ def inference(raw_image, model_n, question="", strategy=""):
46
+ if model_n == 'Image Captioning':
47
+ image = transform(raw_image).unsqueeze(0).to(device)
48
+ with torch.no_grad():
49
+ if strategy == "Beam search":
50
+ caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
51
+ else:
52
+ caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
53
+ return 'caption: '+caption[0]
54
+
55
+ else:
56
+ image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
57
+ with torch.no_grad():
58
+ answer = model_vq(image_vq, question, train=False, inference='generate')
59
+ return 'answer: '+answer[0]
60
+
61
+ # get the caption for a single image; [9:] strips the leading 'caption: ' prefix
62
+ def get_caption(image_path):
63
+ img = Image.open(image_path)
64
+ return inference(img, "Image Captioning")[9:]
65
+
66
+ def display(image_path):
67
+ img = Image.open(image_path)
68
+ plt.imshow(img)
69
+ print("Caption: " + get_caption(image_path))
70
+ plt.show()
71
+
72
+ #returns a dictionary with key -> img_path and value -> caption
73
+ def get_captions(img_directory, print_status=True):
74
+ #key is img path, value is the caption
75
+ captions = {}
76
+ length = len(os.listdir(img_directory))
77
+ count = 0
78
+ for file in os.listdir(img_directory):
79
+ f = os.path.join(img_directory, file)
80
+ captions[f] = inference(Image.open(f), "Image Captioning")
81
+ count += 1
82
+ if print_status:
83
+ print("Images complete:", str(count) + "/" + str(length))
84
+ print("Caption:", captions[f])
85
+
86
+ return captions
87
+ # writes the dictionary to a file, one "path:caption" entry per line
88
+ def write_to_file(filename, caption_dict):
89
+ with open(filename, "w") as file:
90
+ for i in caption_dict:
91
+ file.write(i + ":" + caption_dict[i] + "\n")
92
+ # the with-statement closes the file automatically
93
+
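
A short usage sketch for the helpers above, assuming a hypothetical local folder ./images and an output file captions.txt (both names are placeholders):

    from imagecaptioning import get_captions, write_to_file

    # caption every image in the folder, printing progress as it goes
    captions = get_captions("./images", print_status=True)
    write_to_file("captions.txt", captions)

Note that importing the module triggers the checkpoint downloads at the top of the file, so the first run needs network access; inference runs on GPU when available and otherwise falls back to CPU.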
predict.py ADDED
@@ -0,0 +1,98 @@
1
+ """
2
+ Download the weights in ./checkpoints beforehand for fast inference
3
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
4
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth
5
+ wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
6
+ """
7
+
8
+ from pathlib import Path
9
+
10
+ from PIL import Image
11
+ import torch
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import InterpolationMode
14
+ import cog
15
+
16
+ from models.blip import blip_decoder
17
+ from models.blip_vqa import blip_vqa
18
+ from models.blip_itm import blip_itm
19
+
20
+
21
+ class Predictor(cog.Predictor):
22
+ def setup(self):
23
+ self.device = "cuda:0"
24
+
25
+ self.models = {
26
+ 'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth',
27
+ image_size=384, vit='base'),
28
+ 'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth',
29
+ image_size=480, vit='base'),
30
+ 'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
31
+ image_size=384, vit='base')
32
+ }
33
+
34
+ @cog.input(
35
+ "image",
36
+ type=Path,
37
+ help="input image",
38
+ )
39
+ @cog.input(
40
+ "task",
41
+ type=str,
42
+ default='image_captioning',
43
+ options=['image_captioning', 'visual_question_answering', 'image_text_matching'],
44
+ help="Choose a task.",
45
+ )
46
+ @cog.input(
47
+ "question",
48
+ type=str,
49
+ default=None,
50
+ help="Type question for the input image for visual question answering task.",
51
+ )
52
+ @cog.input(
53
+ "caption",
54
+ type=str,
55
+ default=None,
56
+ help="Type caption for the input image for image text matching task.",
57
+ )
58
+ def predict(self, image, task, question, caption):
59
+ if task == 'visual_question_answering':
60
+ assert question is not None, 'Please type a question for visual question answering task.'
61
+ if task == 'image_text_matching':
62
+ assert caption is not None, 'Please type a caption for mage text matching task.'
63
+
64
+ im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device)
65
+ model = self.models[task]
66
+ model.eval()
67
+ model = model.to(self.device)
68
+
69
+ if task == 'image_captioning':
70
+ with torch.no_grad():
71
+ caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5)
72
+ return 'Caption: ' + caption[0]
73
+
74
+ if task == 'visual_question_answering':
75
+ with torch.no_grad():
76
+ answer = model(im, question, train=False, inference='generate')
77
+ return 'Answer: ' + answer[0]
78
+
79
+ # image_text_matching
80
+ itm_output = model(im, caption, match_head='itm')
81
+ itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
82
+ itc_score = model(im, caption, match_head='itc')
83
+ return f'The image and text is matched with a probability of {itm_score.item():.4f}.\n' \
84
+ f'The image feature and text feature has a cosine similarity of {itc_score.item():.4f}.'
85
+
86
+
87
+ def load_image(image, image_size, device):
88
+ raw_image = Image.open(str(image)).convert('RGB')
89
+
90
+ w, h = raw_image.size
91
+
92
+ transform = transforms.Compose([
93
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
94
+ transforms.ToTensor(),
95
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
96
+ ])
97
+ image = transform(raw_image).unsqueeze(0).to(device)
98
+ return image
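
With the weights downloaded into ./checkpoints as described in the header docstring, this predictor is driven through the cog CLI; a captioning call might look like the following sketch (the image path is a placeholder, and the exact -i syntax depends on the cog version):

    cog predict -i image=@demo.jpg -i task=image_captioning

Note that the cog.Predictor class and @cog.input decorators used here belong to the older cog interface; recent cog releases expect a BasePredictor subclass with typed inputs instead.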
pretrain.py ADDED
@@ -0,0 +1,173 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ from torch.utils.data import DataLoader
24
+
25
+ from models.blip_pretrain import blip_pretrain
26
+ import utils
27
+ from utils import warmup_lr_schedule, step_lr_schedule
28
+ from data import create_dataset, create_sampler, create_loader
29
+
30
+ def train(model, data_loader, optimizer, epoch, device, config):
31
+ # train
32
+ model.train()
33
+
34
+ metric_logger = utils.MetricLogger(delimiter=" ")
35
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
36
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
37
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
38
+ metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
39
+
40
+ header = 'Train Epoch: [{}]'.format(epoch)
41
+ print_freq = 50
42
+
43
+ if config['laion_path']:
44
+ data_loader.dataset.reload_laion(epoch)
45
+
46
+ data_loader.sampler.set_epoch(epoch)
47
+
48
+ for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
49
+
50
+ if epoch==0:
51
+ warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
52
+
53
+ optimizer.zero_grad()
54
+
55
+ image = image.to(device,non_blocking=True)
56
+
57
+ # ramp up alpha in the first 2 epochs
58
+ alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader)))
59
+
60
+ loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha)
61
+ loss = loss_ita + loss_itm + loss_lm
62
+
63
+ loss.backward()
64
+ optimizer.step()
65
+
66
+ metric_logger.update(loss_ita=loss_ita.item())
67
+ metric_logger.update(loss_itm=loss_itm.item())
68
+ metric_logger.update(loss_lm=loss_lm.item())
69
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
70
+
71
+
72
+ # gather the stats from all processes
73
+ metric_logger.synchronize_between_processes()
74
+ print("Averaged stats:", metric_logger.global_avg())
75
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
76
+
77
+
78
+ def main(args, config):
79
+ utils.init_distributed_mode(args)
80
+
81
+ device = torch.device(args.device)
82
+
83
+ # fix the seed for reproducibility
84
+ seed = args.seed + utils.get_rank()
85
+ torch.manual_seed(seed)
86
+ np.random.seed(seed)
87
+ random.seed(seed)
88
+ cudnn.benchmark = True
89
+
90
+ #### Dataset ####
91
+ print("Creating dataset")
92
+ datasets = [create_dataset('pretrain', config, min_scale=0.2)]
93
+ print('number of training samples: %d'%len(datasets[0]))
94
+
95
+ num_tasks = utils.get_world_size()
96
+ global_rank = utils.get_rank()
97
+ samplers = create_sampler(datasets, [True], num_tasks, global_rank)
98
+
99
+ data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
100
+
101
+ #### Model ####
102
+ print("Creating model")
103
+ model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
104
+ vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'])
105
+
106
+ model = model.to(device)
107
+
108
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
109
+
110
+ start_epoch = 0
111
+ if args.checkpoint:
112
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
113
+ state_dict = checkpoint['model']
114
+ model.load_state_dict(state_dict)
115
+
116
+ optimizer.load_state_dict(checkpoint['optimizer'])
117
+ start_epoch = checkpoint['epoch']+1
118
+ print('resume checkpoint from %s'%args.checkpoint)
119
+
120
+ model_without_ddp = model
121
+ if args.distributed:
122
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
123
+ model_without_ddp = model.module
124
+
125
+ print("Start training")
126
+ start_time = time.time()
127
+ for epoch in range(start_epoch, config['max_epoch']):
128
+
129
+ step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate'])
130
+
131
+ train_stats = train(model, data_loader, optimizer, epoch, device, config)
132
+ if utils.is_main_process():
133
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
134
+ 'epoch': epoch,
135
+ }
136
+ save_obj = {
137
+ 'model': model_without_ddp.state_dict(),
138
+ 'optimizer': optimizer.state_dict(),
139
+ 'config': config,
140
+ 'epoch': epoch,
141
+ }
142
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
143
+
144
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
145
+ f.write(json.dumps(log_stats) + "\n")
146
+
147
+ dist.barrier()
148
+
149
+ total_time = time.time() - start_time
150
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
151
+ print('Training time {}'.format(total_time_str))
152
+
153
+
154
+ if __name__ == '__main__':
155
+ parser = argparse.ArgumentParser()
156
+ parser.add_argument('--config', default='./configs/pretrain.yaml')
157
+ parser.add_argument('--output_dir', default='output/Pretrain')
158
+ parser.add_argument('--checkpoint', default='')
159
+ parser.add_argument('--evaluate', action='store_true')
160
+ parser.add_argument('--device', default='cuda')
161
+ parser.add_argument('--seed', default=42, type=int)
162
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
163
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
164
+ parser.add_argument('--distributed', default=True, type=bool)
165
+ args = parser.parse_args()
166
+
167
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
168
+
169
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
170
+
171
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
172
+
173
+ main(args, config)
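
A minimal multi-GPU launch sketch for pretraining (GPU count and launcher are assumptions; paths are the script defaults above):

    python -m torch.distributed.run --nproc_per_node=8 pretrain.py \
        --config ./configs/pretrain.yaml --output_dir output/Pretrain

Passing --checkpoint with a previously saved checkpoint_XX.pth resumes training: main() restores the model weights, the optimizer state and the starting epoch from it.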
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ timm==0.4.12
2
+ transformers==4.15.0
3
+ fairscale==0.4.4
4
+ pycocoevalcap
5
+ Pillow
6
+ torch
7
+ torchvision
8
+ cohere
9
+ gradio
sample.png ADDED
train_caption.py ADDED
@@ -0,0 +1,206 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ from torch.utils.data import DataLoader
24
+
25
+ from models.blip import blip_decoder
26
+ import utils
27
+ from utils import cosine_lr_schedule
28
+ from data import create_dataset, create_sampler, create_loader
29
+ from data.utils import save_result, coco_caption_eval
30
+
31
+ def train(model, data_loader, optimizer, epoch, device):
32
+ # train
33
+ model.train()
34
+
35
+ metric_logger = utils.MetricLogger(delimiter=" ")
36
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
37
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
38
+ header = 'Train Caption Epoch: [{}]'.format(epoch)
39
+ print_freq = 50
40
+
41
+ for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
42
+ image = image.to(device)
43
+
44
+ loss = model(image, caption)
45
+
46
+ optimizer.zero_grad()
47
+ loss.backward()
48
+ optimizer.step()
49
+
50
+ metric_logger.update(loss=loss.item())
51
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
52
+
53
+ # gather the stats from all processes
54
+ metric_logger.synchronize_between_processes()
55
+ print("Averaged stats:", metric_logger.global_avg())
56
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
57
+
58
+
59
+ @torch.no_grad()
60
+ def evaluate(model, data_loader, device, config):
61
+ # evaluate
62
+ model.eval()
63
+
64
+ metric_logger = utils.MetricLogger(delimiter=" ")
65
+ header = 'Caption generation:'
66
+ print_freq = 10
67
+
68
+ result = []
69
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
70
+
71
+ image = image.to(device)
72
+
73
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
74
+ min_length=config['min_length'])
75
+
76
+ for caption, img_id in zip(captions, image_id):
77
+ result.append({"image_id": img_id.item(), "caption": caption})
78
+
79
+ return result
80
+
81
+
82
+ def main(args, config):
83
+ utils.init_distributed_mode(args)
84
+
85
+ device = torch.device(args.device)
86
+
87
+ # fix the seed for reproducibility
88
+ seed = args.seed + utils.get_rank()
89
+ torch.manual_seed(seed)
90
+ np.random.seed(seed)
91
+ random.seed(seed)
92
+ cudnn.benchmark = True
93
+
94
+ #### Dataset ####
95
+ print("Creating captioning dataset")
96
+ train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)
97
+
98
+ if args.distributed:
99
+ num_tasks = utils.get_world_size()
100
+ global_rank = utils.get_rank()
101
+ samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)
102
+ else:
103
+ samplers = [None, None, None]
104
+
105
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
106
+ batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
107
+ is_trains=[True, False, False], collate_fns=[None,None,None])
108
+
109
+ #### Model ####
110
+ print("Creating model")
111
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
112
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
113
+ prompt=config['prompt'])
114
+
115
+ model = model.to(device)
116
+
117
+ model_without_ddp = model
118
+ if args.distributed:
119
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
120
+ model_without_ddp = model.module
121
+
122
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
123
+
124
+ best = 0
125
+ best_epoch = 0
126
+
127
+ print("Start training")
128
+ start_time = time.time()
129
+ for epoch in range(0, config['max_epoch']):
130
+ if not args.evaluate:
131
+ if args.distributed:
132
+ train_loader.sampler.set_epoch(epoch)
133
+
134
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
135
+
136
+ train_stats = train(model, train_loader, optimizer, epoch, device)
137
+
138
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
139
+ val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')
140
+
141
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
142
+ test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')
143
+
144
+ if utils.is_main_process():
145
+ coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
146
+ coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
147
+
148
+ if args.evaluate:
149
+ log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
150
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
151
+ }
152
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
153
+ f.write(json.dumps(log_stats) + "\n")
154
+ else:
155
+ save_obj = {
156
+ 'model': model_without_ddp.state_dict(),
157
+ 'optimizer': optimizer.state_dict(),
158
+ 'config': config,
159
+ 'epoch': epoch,
160
+ }
161
+
162
+ if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
163
+ best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
164
+ best_epoch = epoch
165
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
166
+
167
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
168
+ **{f'val_{k}': v for k, v in coco_val.eval.items()},
169
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
170
+ 'epoch': epoch,
171
+ 'best_epoch': best_epoch,
172
+ }
173
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
174
+ f.write(json.dumps(log_stats) + "\n")
175
+
176
+ if args.evaluate:
177
+ break
178
+ dist.barrier()
179
+
180
+ total_time = time.time() - start_time
181
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
182
+ print('Training time {}'.format(total_time_str))
183
+
184
+
185
+ if __name__ == '__main__':
186
+ parser = argparse.ArgumentParser()
187
+ parser.add_argument('--config', default='./configs/caption_coco.yaml')
188
+ parser.add_argument('--output_dir', default='output/Caption_coco')
189
+ parser.add_argument('--evaluate', action='store_true')
190
+ parser.add_argument('--device', default='cuda')
191
+ parser.add_argument('--seed', default=42, type=int)
192
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
193
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
194
+ parser.add_argument('--distributed', default=True, type=bool)  # note: argparse's type=bool treats any non-empty string as True; keep the default rather than passing a value
195
+ args = parser.parse_args()
196
+
197
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
198
+
199
+ args.result_dir = os.path.join(args.output_dir, 'result')
200
+
201
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
202
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
203
+
204
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
205
+
206
+ main(args, config)
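
Note on reuse: the script above picks its best checkpoint by the sum of CIDEr and BLEU-4 on the COCO validation split, and decoding is always beam search (sample=False). Below is a minimal sketch of calling the same blip_decoder API outside the training loop; the checkpoint path, image size, and decoding parameters are illustrative assumptions (and the remaining blip_decoder arguments are assumed to have defaults), not values fixed by this commit.

    import torch
    from models.blip import blip_decoder

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hypothetical path: train_caption.py writes checkpoint_best.pth into args.output_dir.
    model = blip_decoder(pretrained='output/Caption_coco/checkpoint_best.pth',
                         image_size=384, vit='base')
    model = model.eval().to(device)

    with torch.no_grad():
        image = torch.randn(1, 3, 384, 384, device=device)  # stand-in for a preprocessed image batch
        captions = model.generate(image, sample=False, num_beams=3,
                                  max_length=20, min_length=5)
    print(captions[0])
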
train_nlvr.py ADDED
@@ -0,0 +1,213 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import pickle
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.utils.data import DataLoader
24
+ import torch.backends.cudnn as cudnn
25
+ import torch.distributed as dist
26
+
27
+ from models.blip_nlvr import blip_nlvr
28
+
29
+ import utils
30
+ from utils import cosine_lr_schedule, warmup_lr_schedule
31
+ from data import create_dataset, create_sampler, create_loader
32
+
33
+ def train(model, data_loader, optimizer, epoch, device, config):
34
+ # train
35
+ model.train()
36
+
37
+ metric_logger = utils.MetricLogger(delimiter=" ")
38
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
39
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
40
+
41
+ header = 'Train Epoch: [{}]'.format(epoch)
42
+ print_freq = 50
43
+ step_size = 10
44
+
45
+ for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
46
+
47
+ images = torch.cat([image0, image1], dim=0)
48
+ images, targets = images.to(device), targets.to(device)
49
+
50
+ loss = model(images, text, targets=targets, train=True)
51
+
52
+ optimizer.zero_grad()
53
+ loss.backward()
54
+ optimizer.step()
55
+
56
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
57
+ metric_logger.update(loss=loss.item())
58
+
59
+ # gather the stats from all processes
60
+ metric_logger.synchronize_between_processes()
61
+ print("Averaged stats:", metric_logger.global_avg())
62
+ return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
63
+
64
+
65
+ @torch.no_grad()
66
+ def evaluate(model, data_loader, device, config):
67
+ # test
68
+ model.eval()
69
+
70
+ metric_logger = utils.MetricLogger(delimiter=" ")
71
+
72
+ header = 'Evaluation:'
73
+ print_freq = 50
74
+
75
+ for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
76
+ images = torch.cat([image0, image1], dim=0)
77
+ images, targets = images.to(device), targets.to(device)
78
+
79
+ prediction = model(images, text, targets=targets, train=False)
80
+
81
+ _, pred_class = prediction.max(1)
82
+ accuracy = (targets==pred_class).sum() / targets.size(0)
83
+
84
+ metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))
85
+
86
+ # gather the stats from all processes
87
+ metric_logger.synchronize_between_processes()
88
+
89
+ print("Averaged stats:", metric_logger.global_avg())
90
+ return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
91
+
92
+
93
+
94
+ def main(args, config):
95
+ utils.init_distributed_mode(args)
96
+
97
+ device = torch.device(args.device)
98
+
99
+ # fix the seed for reproducibility
100
+ seed = args.seed + utils.get_rank()
101
+ torch.manual_seed(seed)
102
+ np.random.seed(seed)
103
+ random.seed(seed)
104
+ cudnn.benchmark = True
105
+
106
+ #### Dataset ####
107
+ print("Creating dataset")
108
+ datasets = create_dataset('nlvr', config)
109
+
110
+ if args.distributed:
111
+ num_tasks = utils.get_world_size()
112
+ global_rank = utils.get_rank()
113
+ samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank)
114
+ else:
115
+ samplers = [None, None, None]
116
+
117
+ batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']]
118
+ train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size,
119
+ num_workers=[4,4,4],is_trains=[True,False,False],
120
+ collate_fns=[None,None,None])
121
+
122
+ #### Model ####
123
+ print("Creating model")
124
+ model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
125
+ vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
126
+
127
+ model = model.to(device)
128
+
129
+ model_without_ddp = model
130
+ if args.distributed:
131
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
132
+ model_without_ddp = model.module
133
+
134
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
135
+
136
+ print("Start training")
137
+ start_time = time.time()
138
+ best = 0
139
+ best_epoch = 0
140
+
141
+ for epoch in range(0, config['max_epoch']):
142
+ if not args.evaluate:
143
+ if args.distributed:
144
+ train_loader.sampler.set_epoch(epoch)
145
+
146
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
147
+
148
+ train_stats = train(model, train_loader, optimizer, epoch, device, config)
149
+
150
+ val_stats = evaluate(model, val_loader, device, config)
151
+ test_stats = evaluate(model, test_loader, device, config)
152
+
153
+ if utils.is_main_process():
154
+ if args.evaluate:
155
+ log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
156
+ **{f'test_{k}': v for k, v in test_stats.items()},
157
+ }
158
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
159
+ f.write(json.dumps(log_stats) + "\n")
160
+
161
+ else:
162
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
163
+ **{f'val_{k}': v for k, v in val_stats.items()},
164
+ **{f'test_{k}': v for k, v in test_stats.items()},
165
+ 'epoch': epoch,
166
+ }
167
+
168
+ if float(val_stats['acc'])>best:
169
+ save_obj = {
170
+ 'model': model_without_ddp.state_dict(),
171
+ 'optimizer': optimizer.state_dict(),
172
+ 'config': config,
173
+ 'epoch': epoch,
174
+ }
175
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
176
+ best = float(val_stats['acc'])
177
+ best_epoch = epoch
178
+
179
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
180
+ f.write(json.dumps(log_stats) + "\n")
181
+ if args.evaluate:
182
+ break
183
+
184
+ dist.barrier()
185
+
186
+ if utils.is_main_process():
187
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
188
+ f.write("best epoch: %d"%best_epoch)
189
+
190
+ total_time = time.time() - start_time
191
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
192
+ print('Training time {}'.format(total_time_str))
193
+
194
+
195
+ if __name__ == '__main__':
196
+ parser = argparse.ArgumentParser()
197
+ parser.add_argument('--config', default='./configs/nlvr.yaml')
198
+ parser.add_argument('--output_dir', default='output/NLVR')
199
+ parser.add_argument('--evaluate', action='store_true')
200
+ parser.add_argument('--device', default='cuda')
201
+ parser.add_argument('--seed', default=42, type=int)
202
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
203
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
204
+ parser.add_argument('--distributed', default=True, type=bool)
205
+ args = parser.parse_args()
206
+
207
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
208
+
209
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
210
+
211
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
212
+
213
+ main(args, config)
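
Note on the image pairing above: each NLVR2 example carries two images, and both train() and evaluate() fold the pair into the batch dimension with torch.cat before calling the model. A toy illustration of that reshaping (sizes chosen only for the example):

    import torch

    B = 4                                    # batch of 4 image pairs
    image0 = torch.randn(B, 3, 384, 384)
    image1 = torch.randn(B, 3, 384, 384)
    images = torch.cat([image0, image1], dim=0)
    print(images.shape)                      # torch.Size([8, 3, 384, 384])
    # blip_nlvr is then expected to split the 2B images back into pairs
    # (e.g. images[:B] and images[B:]) before reasoning over them -- an
    # assumption about the model internals, consistent with how this
    # script feeds it.
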
train_retrieval.py ADDED
@@ -0,0 +1,345 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ from torch.utils.data import DataLoader
24
+
25
+ from models.blip_retrieval import blip_retrieval
26
+ import utils
27
+ from utils import cosine_lr_schedule
28
+ from data import create_dataset, create_sampler, create_loader
29
+
30
+
31
+ def train(model, data_loader, optimizer, epoch, device, config):
32
+ # train
33
+ model.train()
34
+
35
+ metric_logger = utils.MetricLogger(delimiter=" ")
36
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
37
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
38
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
39
+ header = 'Train Epoch: [{}]'.format(epoch)
40
+ print_freq = 50
41
+
42
+ for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
43
+ image = image.to(device,non_blocking=True)
44
+ idx = idx.to(device,non_blocking=True)
45
+
46
+ if epoch>0:
47
+ alpha = config['alpha']
48
+ else:
49
+ alpha = config['alpha']*min(1,i/len(data_loader))
50
+
51
+ loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
52
+ loss = loss_ita + loss_itm
53
+
54
+ optimizer.zero_grad()
55
+ loss.backward()
56
+ optimizer.step()
57
+
58
+ metric_logger.update(loss_itm=loss_itm.item())
59
+ metric_logger.update(loss_ita=loss_ita.item())
60
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
61
+
62
+ # gather the stats from all processes
63
+ metric_logger.synchronize_between_processes()
64
+ print("Averaged stats:", metric_logger.global_avg())
65
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
66
+
67
+
68
+ @torch.no_grad()
69
+ def evaluation(model, data_loader, device, config):
70
+ # test
71
+ model.eval()
72
+
73
+ metric_logger = utils.MetricLogger(delimiter=" ")
74
+ header = 'Evaluation:'
75
+
76
+ print('Computing features for evaluation...')
77
+ start_time = time.time()
78
+
79
+ texts = data_loader.dataset.text
80
+ num_text = len(texts)
81
+ text_bs = 256
82
+ text_ids = []
83
+ text_embeds = []
84
+ text_atts = []
85
+ for i in range(0, num_text, text_bs):
86
+ text = texts[i: min(num_text, i+text_bs)]
87
+ text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
88
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
89
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
90
+ text_embeds.append(text_embed)
91
+ text_ids.append(text_input.input_ids)
92
+ text_atts.append(text_input.attention_mask)
93
+
94
+ text_embeds = torch.cat(text_embeds,dim=0)
95
+ text_ids = torch.cat(text_ids,dim=0)
96
+ text_atts = torch.cat(text_atts,dim=0)
97
+ text_ids[:,0] = model.tokenizer.enc_token_id
98
+
99
+ image_feats = []
100
+ image_embeds = []
101
+ for image, img_id in data_loader:
102
+ image = image.to(device)
103
+ image_feat = model.visual_encoder(image)
104
+ image_embed = model.vision_proj(image_feat[:,0,:])
105
+ image_embed = F.normalize(image_embed,dim=-1)
106
+
107
+ image_feats.append(image_feat.cpu())
108
+ image_embeds.append(image_embed)
109
+
110
+ image_feats = torch.cat(image_feats,dim=0)
111
+ image_embeds = torch.cat(image_embeds,dim=0)
112
+
113
+ sims_matrix = image_embeds @ text_embeds.t()
114
+ score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)
115
+
116
+ num_tasks = utils.get_world_size()
117
+ rank = utils.get_rank()
118
+ step = sims_matrix.size(0)//num_tasks + 1
119
+ start = rank*step
120
+ end = min(sims_matrix.size(0),start+step)
121
+
122
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
123
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
124
+
125
+ encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device)
126
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
127
+ output = model.text_encoder(text_ids[topk_idx],
128
+ attention_mask = text_atts[topk_idx],
129
+ encoder_hidden_states = encoder_output,
130
+ encoder_attention_mask = encoder_att,
131
+ return_dict = True,
132
+ )
133
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
134
+ score_matrix_i2t[start+i,topk_idx] = score + topk_sim
135
+
136
+ sims_matrix = sims_matrix.t()
137
+ score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device)
138
+
139
+ step = sims_matrix.size(0)//num_tasks + 1
140
+ start = rank*step
141
+ end = min(sims_matrix.size(0),start+step)
142
+
143
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
144
+
145
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
146
+ encoder_output = image_feats[topk_idx].to(device)
147
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
148
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
149
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
150
+ encoder_hidden_states = encoder_output,
151
+ encoder_attention_mask = encoder_att,
152
+ return_dict = True,
153
+ )
154
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
155
+ score_matrix_t2i[start+i,topk_idx] = score + topk_sim
156
+
157
+ if args.distributed:  # note: 'args' is the module-level global parsed in __main__; evaluation() takes no args parameter
158
+ dist.barrier()
159
+ torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
160
+ torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
161
+
162
+ total_time = time.time() - start_time
163
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
164
+ print('Evaluation time {}'.format(total_time_str))
165
+
166
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
167
+
168
+
169
+
170
+ @torch.no_grad()
171
+ def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
172
+
173
+ #Images->Text
174
+ ranks = np.zeros(scores_i2t.shape[0])
175
+ for index,score in enumerate(scores_i2t):
176
+ inds = np.argsort(score)[::-1]
177
+ # Score
178
+ rank = 1e20
179
+ for i in img2txt[index]:
180
+ tmp = np.where(inds == i)[0][0]
181
+ if tmp < rank:
182
+ rank = tmp
183
+ ranks[index] = rank
184
+
185
+ # Compute metrics
186
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
187
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
188
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
189
+
190
+ #Text->Images
191
+ ranks = np.zeros(scores_t2i.shape[0])
192
+
193
+ for index,score in enumerate(scores_t2i):
194
+ inds = np.argsort(score)[::-1]
195
+ ranks[index] = np.where(inds == txt2img[index])[0][0]
196
+
197
+ # Compute metrics
198
+ ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
199
+ ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
200
+ ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
201
+
202
+ tr_mean = (tr1 + tr5 + tr10) / 3
203
+ ir_mean = (ir1 + ir5 + ir10) / 3
204
+ r_mean = (tr_mean + ir_mean) / 2
205
+
206
+ eval_result = {'txt_r1': tr1,
207
+ 'txt_r5': tr5,
208
+ 'txt_r10': tr10,
209
+ 'txt_r_mean': tr_mean,
210
+ 'img_r1': ir1,
211
+ 'img_r5': ir5,
212
+ 'img_r10': ir10,
213
+ 'img_r_mean': ir_mean,
214
+ 'r_mean': r_mean}
215
+ return eval_result
216
+
217
+
218
+ def main(args, config):
219
+ utils.init_distributed_mode(args)
220
+
221
+ device = torch.device(args.device)
222
+
223
+ # fix the seed for reproducibility
224
+ seed = args.seed + utils.get_rank()
225
+ torch.manual_seed(seed)
226
+ np.random.seed(seed)
227
+ random.seed(seed)
228
+ cudnn.benchmark = True
229
+
230
+ #### Dataset ####
231
+ print("Creating retrieval dataset")
232
+ train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)
233
+
234
+ if args.distributed:
235
+ num_tasks = utils.get_world_size()
236
+ global_rank = utils.get_rank()
237
+ samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
238
+ else:
239
+ samplers = [None, None, None]
240
+
241
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
242
+ batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
243
+ num_workers=[4,4,4],
244
+ is_trains=[True, False, False],
245
+ collate_fns=[None,None,None])
246
+
247
+
248
+ #### Model ####
249
+ print("Creating model")
250
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
251
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
252
+ queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])
253
+
254
+ model = model.to(device)
255
+
256
+ model_without_ddp = model
257
+ if args.distributed:
258
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
259
+ model_without_ddp = model.module
260
+
261
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
262
+
263
+ best = 0
264
+ best_epoch = 0
265
+
266
+ print("Start training")
267
+ start_time = time.time()
268
+
269
+ for epoch in range(0, config['max_epoch']):
270
+ if not args.evaluate:
271
+ if args.distributed:
272
+ train_loader.sampler.set_epoch(epoch)
273
+
274
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
275
+
276
+ train_stats = train(model, train_loader, optimizer, epoch, device, config)
277
+
278
+ score_val_i2t, score_val_t2i = evaluation(model_without_ddp, val_loader, device, config)
279
+ score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
280
+
281
+ if utils.is_main_process():
282
+
283
+ val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
284
+ print(val_result)
285
+
286
+ if val_result['r_mean']>best:
287
+ save_obj = {
288
+ 'model': model_without_ddp.state_dict(),
289
+ 'optimizer': optimizer.state_dict(),
290
+ 'config': config,
291
+ 'epoch': epoch,
292
+ }
293
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
294
+ best = val_result['r_mean']
295
+ best_epoch = epoch
296
+
297
+ test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
298
+ print(test_result)
299
+
300
+ if args.evaluate:
301
+ log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
302
+ **{f'test_{k}': v for k, v in test_result.items()},
303
+ }
304
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
305
+ f.write(json.dumps(log_stats) + "\n")
306
+ else:
307
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
308
+ **{f'val_{k}': v for k, v in val_result.items()},
309
+ **{f'test_{k}': v for k, v in test_result.items()},
310
+ 'epoch': epoch,
311
+ 'best_epoch': best_epoch,
312
+ }
313
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
314
+ f.write(json.dumps(log_stats) + "\n")
315
+
316
+ if args.evaluate:
317
+ break
318
+
319
+ dist.barrier()
320
+ torch.cuda.empty_cache()
321
+
322
+ total_time = time.time() - start_time
323
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
324
+ print('Training time {}'.format(total_time_str))
325
+
326
+
327
+ if __name__ == '__main__':
328
+ parser = argparse.ArgumentParser()
329
+ parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
330
+ parser.add_argument('--output_dir', default='output/Retrieval_flickr')
331
+ parser.add_argument('--evaluate', action='store_true')
332
+ parser.add_argument('--device', default='cuda')
333
+ parser.add_argument('--seed', default=42, type=int)
334
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
335
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
336
+ parser.add_argument('--distributed', default=True, type=bool)
337
+ args = parser.parse_args()
338
+
339
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
340
+
341
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
342
+
343
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
344
+
345
+ main(args, config)
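
Note on the evaluation above: evaluation() ranks in two stages -- a cheap dot product between the projected image and text embeddings shortlists k_test candidates per query, the ITM head rescores only that shortlist, and each rank fills its slice of the score matrix before an all_reduce merges them. itm_eval() then reduces everything to the rank of the ground-truth item in the sorted scores. A self-contained numeric check of its text-to-image recall logic:

    import numpy as np

    scores_t2i = np.array([[0.9, 0.1, 0.3],    # text 0 scored against 3 images
                           [0.2, 0.7, 0.6]])   # text 1 scored against 3 images
    txt2img = {0: 0, 1: 2}                     # ground-truth image index per text

    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]                        # best-first ordering
        ranks[index] = np.where(inds == txt2img[index])[0][0]

    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    print(ranks, ir1)   # [0. 1.] 50.0 -> only text 0 retrieves its image at rank 0
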
train_vqa.py ADDED
@@ -0,0 +1,202 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import argparse
9
+ import os
10
+ import ruamel_yaml as yaml
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ import datetime
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.utils.data import DataLoader
22
+ import torch.backends.cudnn as cudnn
23
+ import torch.distributed as dist
24
+
25
+ from models.blip_vqa import blip_vqa
26
+ import utils
27
+ from utils import cosine_lr_schedule
28
+ from data import create_dataset, create_sampler, create_loader
29
+ from data.vqa_dataset import vqa_collate_fn
30
+ from data.utils import save_result
31
+
32
+
33
+ def train(model, data_loader, optimizer, epoch, device):
34
+ # train
35
+ model.train()
36
+
37
+ metric_logger = utils.MetricLogger(delimiter=" ")
38
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
39
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
40
+
41
+ header = 'Train Epoch: [{}]'.format(epoch)
42
+ print_freq = 50
43
+
44
+ for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
45
+ image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True)
46
+
47
+ loss = model(image, question, answer, train=True, n=n, weights=weights)
48
+
49
+ optimizer.zero_grad()
50
+ loss.backward()
51
+ optimizer.step()
52
+
53
+ metric_logger.update(loss=loss.item())
54
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
55
+
56
+ # gather the stats from all processes
57
+ metric_logger.synchronize_between_processes()
58
+ print("Averaged stats:", metric_logger.global_avg())
59
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
60
+
61
+
62
+ @torch.no_grad()
63
+ def evaluation(model, data_loader, device, config):
64
+ # test
65
+ model.eval()
66
+
67
+ metric_logger = utils.MetricLogger(delimiter=" ")
68
+ header = 'Generate VQA test result:'
69
+ print_freq = 50
70
+
71
+ result = []
72
+
73
+ if config['inference']=='rank':
74
+ answer_list = data_loader.dataset.answer_list
75
+ answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
76
+ answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id
77
+
78
+ for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
79
+ image = image.to(device,non_blocking=True)
80
+
81
+ if config['inference']=='generate':
82
+ answers = model(image, question, train=False, inference='generate')
83
+
84
+ for answer, ques_id in zip(answers, question_id):
85
+ ques_id = int(ques_id.item())
86
+ result.append({"question_id":ques_id, "answer":answer})
87
+
88
+ elif config['inference']=='rank':
89
+ answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])
90
+
91
+ for ques_id, answer_id in zip(question_id, answer_ids):
92
+ result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]})
93
+
94
+ return result
95
+
96
+
97
+ def main(args, config):
98
+ utils.init_distributed_mode(args)
99
+
100
+ device = torch.device(args.device)
101
+
102
+ # fix the seed for reproducibility
103
+ seed = args.seed + utils.get_rank()
104
+ torch.manual_seed(seed)
105
+ np.random.seed(seed)
106
+ random.seed(seed)
107
+ cudnn.benchmark = True
108
+
109
+ #### Dataset ####
110
+ print("Creating vqa datasets")
111
+ datasets = create_dataset('vqa', config)
112
+
113
+ if args.distributed:
114
+ num_tasks = utils.get_world_size()
115
+ global_rank = utils.get_rank()
116
+ samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
117
+ else:
118
+ samplers = [None, None]
119
+
120
+ train_loader, test_loader = create_loader(datasets,samplers,
121
+ batch_size=[config['batch_size_train'],config['batch_size_test']],
122
+ num_workers=[4,4],is_trains=[True, False],
123
+ collate_fns=[vqa_collate_fn,None])
124
+ #### Model ####
125
+ print("Creating model")
126
+ model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
127
+ vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
128
+
129
+ model = model.to(device)
130
+
131
+ model_without_ddp = model
132
+ if args.distributed:
133
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
134
+ model_without_ddp = model.module
135
+
136
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
137
+
138
+ best = 0
139
+ best_epoch = 0
140
+
141
+ print("Start training")
142
+ start_time = time.time()
143
+ for epoch in range(0, config['max_epoch']):
144
+ if not args.evaluate:
145
+ if args.distributed:
146
+ train_loader.sampler.set_epoch(epoch)
147
+
148
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
149
+
150
+ train_stats = train(model, train_loader, optimizer, epoch, device)
151
+
152
+ else:
153
+ break
154
+
155
+ if utils.is_main_process():
156
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
157
+ 'epoch': epoch,
158
+ }
159
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
160
+ f.write(json.dumps(log_stats) + "\n")
161
+
162
+ save_obj = {
163
+ 'model': model_without_ddp.state_dict(),
164
+ 'optimizer': optimizer.state_dict(),
165
+ 'config': config,
166
+ 'epoch': epoch,
167
+ }
168
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
169
+
170
+ dist.barrier()
171
+
172
+ vqa_result = evaluation(model_without_ddp, test_loader, device, config)
173
+ result_file = save_result(vqa_result, args.result_dir, 'vqa_result')
174
+
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print('Training time {}'.format(total_time_str))
178
+
179
+
180
+
181
+ if __name__ == '__main__':
182
+ parser = argparse.ArgumentParser()
183
+ parser.add_argument('--config', default='./configs/vqa.yaml')
184
+ parser.add_argument('--output_dir', default='output/VQA')
185
+ parser.add_argument('--evaluate', action='store_true')
186
+ parser.add_argument('--device', default='cuda')
187
+ parser.add_argument('--seed', default=42, type=int)
188
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
189
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
190
+ parser.add_argument('--distributed', default=True, type=bool)
191
+ args = parser.parse_args()
192
+
193
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
194
+
195
+ args.result_dir = os.path.join(args.output_dir, 'result')
196
+
197
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
198
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
199
+
200
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
201
+
202
+ main(args, config)
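
Note on the batch layout above: the training loop unpacks (image, question, answer, weights, n) from each batch, so vqa_collate_fn has to flatten the variable number of weighted ground-truth answers per question into one list while recording how many belong to each question. The toy collate below reproduces that contract; the real vqa_collate_fn in data/vqa_dataset.py may differ in detail, so treat this as a sketch of its expected output shape rather than its implementation.

    import torch

    def toy_vqa_collate(batch):
        images, questions, answers, weights, n = [], [], [], [], []
        for image, question, answer_list, weight_list in batch:
            images.append(image)
            questions.append(question)
            answers += answer_list           # flatten all answers in batch order
            weights += weight_list           # one weight per flattened answer
            n.append(len(answer_list))       # answers belonging to this question
        return torch.stack(images), questions, answers, torch.tensor(weights), n

    batch = [(torch.zeros(3, 480, 480), 'what color is the sky?', ['blue', 'navy'], [0.6, 0.4]),
             (torch.zeros(3, 480, 480), 'how many dogs?', ['2'], [1.0])]
    images, questions, answers, weights, n = toy_vqa_collate(batch)
    print(answers, weights.tolist(), n)   # ['blue', 'navy', '2'] [0.6, 0.4, 1.0] [2, 1]
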
utils.py ADDED
@@ -0,0 +1,278 @@
1
+ import math
2
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
3
+ """Decay the learning rate"""
4
+ lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
5
+ for param_group in optimizer.param_groups:
6
+ param_group['lr'] = lr
7
+
8
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
9
+ """Warmup the learning rate"""
10
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
11
+ for param_group in optimizer.param_groups:
12
+ param_group['lr'] = lr
13
+
14
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
15
+ """Decay the learning rate"""
16
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
17
+ for param_group in optimizer.param_groups:
18
+ param_group['lr'] = lr
19
+
20
+ import numpy as np
21
+ import io
22
+ import os
23
+ import time
24
+ from collections import defaultdict, deque
25
+ import datetime
26
+
27
+ import torch
28
+ import torch.distributed as dist
29
+
30
+ class SmoothedValue(object):
31
+ """Track a series of values and provide access to smoothed values over a
32
+ window or the global series average.
33
+ """
34
+
35
+ def __init__(self, window_size=20, fmt=None):
36
+ if fmt is None:
37
+ fmt = "{median:.4f} ({global_avg:.4f})"
38
+ self.deque = deque(maxlen=window_size)
39
+ self.total = 0.0
40
+ self.count = 0
41
+ self.fmt = fmt
42
+
43
+ def update(self, value, n=1):
44
+ self.deque.append(value)
45
+ self.count += n
46
+ self.total += value * n
47
+
48
+ def synchronize_between_processes(self):
49
+ """
50
+ Warning: does not synchronize the deque!
51
+ """
52
+ if not is_dist_avail_and_initialized():
53
+ return
54
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
55
+ dist.barrier()
56
+ dist.all_reduce(t)
57
+ t = t.tolist()
58
+ self.count = int(t[0])
59
+ self.total = t[1]
60
+
61
+ @property
62
+ def median(self):
63
+ d = torch.tensor(list(self.deque))
64
+ return d.median().item()
65
+
66
+ @property
67
+ def avg(self):
68
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
69
+ return d.mean().item()
70
+
71
+ @property
72
+ def global_avg(self):
73
+ return self.total / self.count
74
+
75
+ @property
76
+ def max(self):
77
+ return max(self.deque)
78
+
79
+ @property
80
+ def value(self):
81
+ return self.deque[-1]
82
+
83
+ def __str__(self):
84
+ return self.fmt.format(
85
+ median=self.median,
86
+ avg=self.avg,
87
+ global_avg=self.global_avg,
88
+ max=self.max,
89
+ value=self.value)
90
+
91
+
92
+ class MetricLogger(object):
93
+ def __init__(self, delimiter="\t"):
94
+ self.meters = defaultdict(SmoothedValue)
95
+ self.delimiter = delimiter
96
+
97
+ def update(self, **kwargs):
98
+ for k, v in kwargs.items():
99
+ if isinstance(v, torch.Tensor):
100
+ v = v.item()
101
+ assert isinstance(v, (float, int))
102
+ self.meters[k].update(v)
103
+
104
+ def __getattr__(self, attr):
105
+ if attr in self.meters:
106
+ return self.meters[attr]
107
+ if attr in self.__dict__:
108
+ return self.__dict__[attr]
109
+ raise AttributeError("'{}' object has no attribute '{}'".format(
110
+ type(self).__name__, attr))
111
+
112
+ def __str__(self):
113
+ loss_str = []
114
+ for name, meter in self.meters.items():
115
+ loss_str.append(
116
+ "{}: {}".format(name, str(meter))
117
+ )
118
+ return self.delimiter.join(loss_str)
119
+
120
+ def global_avg(self):
121
+ loss_str = []
122
+ for name, meter in self.meters.items():
123
+ loss_str.append(
124
+ "{}: {:.4f}".format(name, meter.global_avg)
125
+ )
126
+ return self.delimiter.join(loss_str)
127
+
128
+ def synchronize_between_processes(self):
129
+ for meter in self.meters.values():
130
+ meter.synchronize_between_processes()
131
+
132
+ def add_meter(self, name, meter):
133
+ self.meters[name] = meter
134
+
135
+ def log_every(self, iterable, print_freq, header=None):
136
+ i = 0
137
+ if not header:
138
+ header = ''
139
+ start_time = time.time()
140
+ end = time.time()
141
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
142
+ data_time = SmoothedValue(fmt='{avg:.4f}')
143
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
144
+ log_msg = [
145
+ header,
146
+ '[{0' + space_fmt + '}/{1}]',
147
+ 'eta: {eta}',
148
+ '{meters}',
149
+ 'time: {time}',
150
+ 'data: {data}'
151
+ ]
152
+ if torch.cuda.is_available():
153
+ log_msg.append('max mem: {memory:.0f}')
154
+ log_msg = self.delimiter.join(log_msg)
155
+ MB = 1024.0 * 1024.0
156
+ for obj in iterable:
157
+ data_time.update(time.time() - end)
158
+ yield obj
159
+ iter_time.update(time.time() - end)
160
+ if i % print_freq == 0 or i == len(iterable) - 1:
161
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
162
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
163
+ if torch.cuda.is_available():
164
+ print(log_msg.format(
165
+ i, len(iterable), eta=eta_string,
166
+ meters=str(self),
167
+ time=str(iter_time), data=str(data_time),
168
+ memory=torch.cuda.max_memory_allocated() / MB))
169
+ else:
170
+ print(log_msg.format(
171
+ i, len(iterable), eta=eta_string,
172
+ meters=str(self),
173
+ time=str(iter_time), data=str(data_time)))
174
+ i += 1
175
+ end = time.time()
176
+ total_time = time.time() - start_time
177
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
178
+ print('{} Total time: {} ({:.4f} s / it)'.format(
179
+ header, total_time_str, total_time / len(iterable)))
180
+
181
+
182
+ class AttrDict(dict):
183
+ def __init__(self, *args, **kwargs):
184
+ super(AttrDict, self).__init__(*args, **kwargs)
185
+ self.__dict__ = self
186
+
187
+
188
+ def compute_acc(logits, label, reduction='mean'):
189
+ ret = (torch.argmax(logits, dim=1) == label).float()
190
+ if reduction == 'none':
191
+ return ret.detach()
192
+ elif reduction == 'mean':
193
+ return ret.mean().item()
194
+
195
+ def compute_n_params(model, return_str=True):
196
+ tot = 0
197
+ for p in model.parameters():
198
+ w = 1
199
+ for x in p.shape:
200
+ w *= x
201
+ tot += w
202
+ if return_str:
203
+ if tot >= 1e6:
204
+ return '{:.1f}M'.format(tot / 1e6)
205
+ else:
206
+ return '{:.1f}K'.format(tot / 1e3)
207
+ else:
208
+ return tot
209
+
210
+ def setup_for_distributed(is_master):
211
+ """
212
+ This function disables printing when not in master process
213
+ """
214
+ import builtins as __builtin__
215
+ builtin_print = __builtin__.print
216
+
217
+ def print(*args, **kwargs):
218
+ force = kwargs.pop('force', False)
219
+ if is_master or force:
220
+ builtin_print(*args, **kwargs)
221
+
222
+ __builtin__.print = print
223
+
224
+
225
+ def is_dist_avail_and_initialized():
226
+ if not dist.is_available():
227
+ return False
228
+ if not dist.is_initialized():
229
+ return False
230
+ return True
231
+
232
+
233
+ def get_world_size():
234
+ if not is_dist_avail_and_initialized():
235
+ return 1
236
+ return dist.get_world_size()
237
+
238
+
239
+ def get_rank():
240
+ if not is_dist_avail_and_initialized():
241
+ return 0
242
+ return dist.get_rank()
243
+
244
+
245
+ def is_main_process():
246
+ return get_rank() == 0
247
+
248
+
249
+ def save_on_master(*args, **kwargs):
250
+ if is_main_process():
251
+ torch.save(*args, **kwargs)
252
+
253
+
254
+ def init_distributed_mode(args):
255
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
256
+ args.rank = int(os.environ["RANK"])
257
+ args.world_size = int(os.environ['WORLD_SIZE'])
258
+ args.gpu = int(os.environ['LOCAL_RANK'])
259
+ elif 'SLURM_PROCID' in os.environ:
260
+ args.rank = int(os.environ['SLURM_PROCID'])
261
+ args.gpu = args.rank % torch.cuda.device_count()
262
+ else:
263
+ print('Not using distributed mode')
264
+ args.distributed = False
265
+ return
266
+
267
+ args.distributed = True
268
+
269
+ torch.cuda.set_device(args.gpu)
270
+ args.dist_backend = 'nccl'
271
+ print('| distributed init (rank {}, world {}): {}'.format(
272
+ args.rank, args.world_size, args.dist_url), flush=True)
273
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
274
+ world_size=args.world_size, rank=args.rank)
275
+ torch.distributed.barrier()
276
+ setup_for_distributed(args.rank == 0)
277
+
278
+
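
Note on the schedules at the top of this file: cosine_lr_schedule is what every fine-tuning script above calls once per epoch. A quick numeric check of the curve it produces -- the formula is copied verbatim, only returned instead of written into the optimizer, and the init_lr/min_lr values are illustrative:

    import math

    def cosine_lr(epoch, max_epoch, init_lr, min_lr):
        # same formula as cosine_lr_schedule, returned instead of applied
        return (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr

    for epoch in range(5):
        print(epoch, cosine_lr(epoch, max_epoch=5, init_lr=1e-5, min_lr=0.0))
    # epoch 0 -> 1.000e-05 (full init_lr)
    # epoch 2 -> 6.545e-06
    # epoch 4 -> 9.549e-07 (approaching min_lr)
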