diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea --- /dev/null +++ b/.gitattributes @@ -0,0 +1,34 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d1e07fc30dd1f42bfad1e17dcdb4dafd207cb174 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: Generate Human Motion +emoji: 🏃 +colorFrom: green +colorTo: yellow +sdk: gradio +sdk_version: 3.16.2 +app_file: app.py +pinned: false +license: apache-2.0 +duplicated_from: vumichien/generate_human_motion +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/VQ-Trans/.gitignore b/VQ-Trans/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..570180825f02f1a7e7bd61d9a035efd4bb65dba2 --- /dev/null +++ b/VQ-Trans/.gitignore @@ -0,0 +1,70 @@ +# C extensions +*.so + + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +.vscode \ No newline at end of file diff --git a/VQ-Trans/GPT_eval_multi.py b/VQ-Trans/GPT_eval_multi.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e3ebcb1199e42cf16748e60863b554a0046f00 --- /dev/null +++ b/VQ-Trans/GPT_eval_multi.py @@ -0,0 +1,121 @@ +import os +import torch +import numpy as np +from torch.utils.tensorboard 
import SummaryWriter +import json +import clip + +import options.option_transformer as option_trans +import models.vqvae as vqvae +import utils.utils_model as utils_model +import utils.eval_trans as eval_trans +from dataset import dataset_TM_eval +import models.t2m_trans as trans +from options.get_eval_option import get_opt +from models.evaluator_wrapper import EvaluatorModelWrapper +import warnings +warnings.filterwarnings('ignore') + +##### ---- Exp dirs ---- ##### +args = option_trans.get_args_parser() +torch.manual_seed(args.seed) + +args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}') +os.makedirs(args.out_dir, exist_ok = True) + +##### ---- Logger ---- ##### +logger = utils_model.get_logger(args.out_dir) +writer = SummaryWriter(args.out_dir) +logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) + +from utils.word_vectorizer import WordVectorizer +w_vectorizer = WordVectorizer('./glove', 'our_vab') +val_loader = dataset_TM_eval.DATALoader(args.dataname, True, 32, w_vectorizer) + +dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataname == 'kit' else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' + +wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) +eval_wrapper = EvaluatorModelWrapper(wrapper_opt) + +##### ---- Network ---- ##### + +## load clip model and datasets +clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False, download_root='/apdcephfs_cq2/share_1290939/maelyszhang/.cache/clip') # Must set jit=False for training +clip.model.convert_weights(clip_model) # Actually this line is unnecessary since clip by default already on float16 +clip_model.eval() +for p in clip_model.parameters(): + p.requires_grad = False + +net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers + args.nb_code, + args.code_dim, + args.output_emb_width, + args.down_t, + args.stride_t, + args.width, + args.depth, + args.dilation_growth_rate) + + +trans_encoder = trans.Text2Motion_Transformer(num_vq=args.nb_code, + embed_dim=args.embed_dim_gpt, + clip_dim=args.clip_dim, + block_size=args.block_size, + num_layers=args.num_layers, + n_head=args.n_head_gpt, + drop_out_rate=args.drop_out_rate, + fc_rate=args.ff_rate) + + +print ('loading checkpoint from {}'.format(args.resume_pth)) +ckpt = torch.load(args.resume_pth, map_location='cpu') +net.load_state_dict(ckpt['net'], strict=True) +net.eval() +net.cuda() + +if args.resume_trans is not None: + print ('loading transformer checkpoint from {}'.format(args.resume_trans)) + ckpt = torch.load(args.resume_trans, map_location='cpu') + trans_encoder.load_state_dict(ckpt['trans'], strict=True) +trans_encoder.train() +trans_encoder.cuda() + + +fid = [] +div = [] +top1 = [] +top2 = [] +top3 = [] +matching = [] +multi = [] +repeat_time = 20 + + +for i in range(repeat_time): + best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, writer, logger = eval_trans.evaluation_transformer_test(args.out_dir, val_loader, net, trans_encoder, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, best_multi=0, clip_model=clip_model, eval_wrapper=eval_wrapper, draw=False, savegif=False, save=False, savenpy=(i==0)) + fid.append(best_fid) + div.append(best_div) + top1.append(best_top1) + top2.append(best_top2) + top3.append(best_top3) + matching.append(best_matching) + multi.append(best_multi) + +print('final result:') +print('fid: ', sum(fid)/repeat_time) +print('div: ', 
sum(div)/repeat_time)
+print('top1: ', sum(top1)/repeat_time)
+print('top2: ', sum(top2)/repeat_time)
+print('top3: ', sum(top3)/repeat_time)
+print('matching: ', sum(matching)/repeat_time)
+print('multi: ', sum(multi)/repeat_time)
+
+fid = np.array(fid)
+div = np.array(div)
+top1 = np.array(top1)
+top2 = np.array(top2)
+top3 = np.array(top3)
+matching = np.array(matching)
+multi = np.array(multi)
+msg_final = f"FID. {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}, Diversity. {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}, TOP1. {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}, Matching. {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}, Multi. {np.mean(multi):.3f}, conf. {np.std(multi)*1.96/np.sqrt(repeat_time):.3f}"
+logger.info(msg_final)
\ No newline at end of file
diff --git a/VQ-Trans/README.md b/VQ-Trans/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..547a1d4b52a5c76f0f86c641557f99d0688c0ffd
--- /dev/null
+++ b/VQ-Trans/README.md
@@ -0,0 +1,400 @@
+# Motion VQ-Trans
+PyTorch implementation of the paper "Generating Human Motion from Textual Descriptions with High Quality Discrete Representation"
+
+[[Notebook Demo]](https://colab.research.google.com/drive/1tAHlmcpKcjg_zZrqKku7AfpqdVAIFrF8?usp=sharing)
+
+![teaser](img/Teaser.png)
+
+If our project is helpful for your research, please consider citing: (todo)
+```
+@inproceedings{shen2020ransac,
+          title={RANSAC-Flow: generic two-stage image alignment},
+          author={Shen, Xi and Darmon, Fran{\c{c}}ois and Efros, Alexei A and Aubry, Mathieu},
+          booktitle={16th European Conference on Computer Vision},
+          year={2020}
+        }
+```
+
+## Table of Contents
+* [1. Visual Results](#1-visual-results)
+* [2. Installation](#2-installation)
+* [3. Quick Start](#3-quick-start)
+* [4. Train](#4-train)
+* [5. Evaluation](#5-evaluation)
+* [6. Motion Render](#6-motion-render)
+* [7. Acknowledgement](#7-acknowledgement)
+* [8. ChangeLog](#8-changelog)
+
+## 1. Visual Results (more results can be found on our project page (todo))
+
+![visualization](img/ALLvis.png)
+
+## 2. Installation
+
+### 2.1. Environment
+
+Our model can be trained on a **single V100-32G GPU**:
+
+```bash
+conda env create -f environment.yml
+conda activate VQTrans
+```
+
+The code was tested on Python 3.8 and PyTorch 1.8.1.
+
+### 2.2. Dependencies
+
+```bash
+bash dataset/prepare/download_glove.sh
+```
+
+### 2.3. Datasets
+
+We use two 3D human motion-language datasets: HumanML3D and KIT-ML. For both datasets, you can find the details as well as the download link [[here]](https://github.com/EricGuo5513/HumanML3D).
+
+Take HumanML3D as an example: the file directory should look like this:
+```
+./dataset/HumanML3D/
+├── new_joint_vecs/
+├── texts/
+├── Mean.npy # same as in [HumanML3D](https://github.com/EricGuo5513/HumanML3D)
+├── Std.npy # same as in [HumanML3D](https://github.com/EricGuo5513/HumanML3D)
+├── train.txt
+├── val.txt
+├── test.txt
+├── train_val.txt
+└── all.txt
+```
+
+### 2.4. Motion & text feature extractors
+
+We use the same extractors provided by [t2m](https://github.com/EricGuo5513/text-to-motion) to evaluate our generated motions. Please download the extractors:
+
+```bash
+bash dataset/prepare/download_extractor.sh
+```
+
+### 2.5. Pre-trained models
+
+The pretrained model files will be stored in the 'pretrained' folder:
+```bash
+bash dataset/prepare/download_model.sh
+```
+
+### 2.6. Render motion (optional)
+
+If you want to render the generated motion, you need to install:
+
+```bash
+sudo sh dataset/prepare/download_smpl.sh
+conda install -c menpo osmesa
+conda install h5py
+conda install -c conda-forge shapely pyrender trimesh mapbox_earcut
+```
+
+## 3. Quick Start
+
+A quick-start guide on how to use our code is available in [demo.ipynb](https://colab.research.google.com/drive/1tAHlmcpKcjg_zZrqKku7AfpqdVAIFrF8?usp=sharing).
+
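+If you prefer a plain script to the notebook, here is a minimal sketch that strings together the same pieces loaded by `GPT_eval_multi.py` earlier in this diff: CLIP encodes the caption, the transformer samples motion tokens, and the VQ-VAE decoder maps them back to motion features. The `sample` and `forward_decoder` calls mirror this repo's model classes but should be treated as assumptions; pass in the pretrained `HumanVQVAE` and `Text2Motion_Transformer` exactly as that script builds them.
+
+```python
+import clip
+import torch
+
+def text_to_motion(caption, net, trans_encoder, device="cuda"):
+    """Sketch: caption -> CLIP feature -> VQ token indices -> motion features.
+
+    net / trans_encoder are the loaded HumanVQVAE and Text2Motion_Transformer
+    from GPT_eval_multi.py; treat the method names below as assumptions.
+    """
+    clip_model, _ = clip.load("ViT-B/32", device=device, jit=False)
+    text = clip.tokenize([caption], truncate=True).to(device)
+    with torch.no_grad():
+        feat_clip_text = clip_model.encode_text(text).float()
+        index_motion = trans_encoder.sample(feat_clip_text[0:1], False)  # autoregressive token sampling
+        return net.forward_decoder(index_motion)                         # token indices -> motion features
+```
+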
+*(demo animation)*
+
+
+## 4. Train
+
+Note that for the KIT dataset you only need to set `--dataname kit`.
+
+### 4.1. VQ-VAE
+
+With the command below, the results are saved in the folder `output/VQVAE` (`--out-dir` joined with `--exp-name`, as `train_vq.py` does).
+
+ +VQ training + + +```bash +python3 train_vq.py \ +--batch-size 256 \ +--lr 2e-4 \ +--total-iter 300000 \ +--lr-scheduler 200000 \ +--nb-code 512 \ +--down-t 2 \ +--depth 3 \ +--dilation-growth-rate 3 \ +--out-dir output \ +--dataname t2m \ +--vq-act relu \ +--quantizer ema_reset \ +--loss-vel 0.5 \ +--recons-loss l1_smooth \ +--exp-name VQVAE +``` + +
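+
+One way to read the temporal flags above: the encoder halves the sequence length `--down-t` times, so a single codebook index covers `2**down_t` motion frames; this is exactly the `unit_length = 2**args.down_t` that the dataloaders in this repo use. A quick sanity check with the values from the command:
+
+```python
+down_t = 2                 # --down-t above
+unit_length = 2 ** down_t  # frames represented by one code
+fps = 20                   # HumanML3D frame rate (KIT-ML uses 12.5)
+
+print(unit_length)         # 4 frames per token
+print(unit_length / fps)   # 0.2 s of motion per token
+```
+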
+
+### 4.2. Motion-Transformer
+
+With the command below, the results are saved in the folder `output/VQTransformer` (`--out-dir` joined with `--exp-name`).
+
+ +MoTrans training + + +```bash +python3 train_t2m_trans.py \ +--exp-name VQTransformer \ +--batch-size 128 \ +--num-layers 9 \ +--embed-dim-gpt 1024 \ +--nb-code 512 \ +--n-head-gpt 16 \ +--block-size 51 \ +--ff-rate 4 \ +--drop-out-rate 0.1 \ +--resume-pth output/VQVAE/net_last.pth \ +--vq-name VQVAE \ +--out-dir output \ +--total-iter 300000 \ +--lr-scheduler 150000 \ +--lr 0.0001 \ +--dataname t2m \ +--down-t 2 \ +--depth 3 \ +--quantizer ema_reset \ +--eval-iter 10000 \ +--pkeep 0.5 \ +--dilation-growth-rate 3 \ +--vq-act relu +``` + +
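+
+Among the flags above, `--pkeep 0.5` controls a token-corruption trick used when training GPT-style priors over VQ indices: each ground-truth motion token is kept with probability `pkeep` and otherwise replaced by a random codebook index, so the autoregressive model learns to recover from its own sampling mistakes. `--block-size 51` matches the dataloader's maximum token length (`26 if unit_length == 8 else 51`, see `dataset_TM_train.py` later in this diff). The exact corruption code lives in `train_t2m_trans.py`, which is not part of this diff, so the sketch below only illustrates the usual pattern:
+
+```python
+import torch
+
+def corrupt_tokens(idx, pkeep, nb_code):
+    """Keep each token with probability pkeep, else substitute a random code index."""
+    keep = torch.bernoulli(pkeep * torch.ones_like(idx, dtype=torch.float)).long()
+    rand_idx = torch.randint_like(idx, nb_code)
+    return keep * idx + (1 - keep) * rand_idx
+
+tokens = torch.randint(0, 512, (2, 51))  # (batch, block_size)
+noisy = corrupt_tokens(tokens, pkeep=0.5, nb_code=512)
+```
+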
+ +## 5. Evaluation + +### 5.1. VQ-VAE +
+ +VQ eval + + +```bash +python3 VQ_eval.py \ +--batch-size 256 \ +--lr 2e-4 \ +--total-iter 300000 \ +--lr-scheduler 200000 \ +--nb-code 512 \ +--down-t 2 \ +--depth 3 \ +--dilation-growth-rate 3 \ +--out-dir output \ +--dataname t2m \ +--vq-act relu \ +--quantizer ema_reset \ +--loss-vel 0.5 \ +--recons-loss l1_smooth \ +--exp-name TEST_VQVAE \ +--resume-pth output/VQVAE/net_last.pth +``` + +
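+
+Each metric printed by `VQ_eval.py` is a mean over `repeat_time = 20` evaluation runs, and the `conf.` values in its final log line are 95% normal-approximation confidence half-widths, `1.96 * std / sqrt(repeat_time)`. The same computation as a standalone helper:
+
+```python
+import numpy as np
+
+def mean_and_conf(values, z=1.96):
+    """Mean and 95% confidence half-width over repeated evaluation runs."""
+    values = np.asarray(values, dtype=np.float64)
+    return values.mean(), z * values.std() / np.sqrt(len(values))
+
+m, c = mean_and_conf([0.081, 0.074, 0.090])  # toy numbers, not real results
+print(f"FID. {m:.3f}, conf. {c:.3f}")
+```
+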
+ +### 5.2. Motion-Transformer + +
+ +MoTrans eval + + +```bash +python3 GPT_eval_multi.py \ +--exp-name TEST_VQTransformer \ +--batch-size 128 \ +--num-layers 9 \ +--embed-dim-gpt 1024 \ +--nb-code 512 \ +--n-head-gpt 16 \ +--block-size 51 \ +--ff-rate 4 \ +--drop-out-rate 0.1 \ +--resume-pth output/VQVAE/net_last.pth \ +--vq-name VQVAE \ +--out-dir output \ +--total-iter 300000 \ +--lr-scheduler 150000 \ +--lr 0.0001 \ +--dataname t2m \ +--down-t 2 \ +--depth 3 \ +--quantizer ema_reset \ +--eval-iter 10000 \ +--pkeep 0.5 \ +--dilation-growth-rate 3 \ +--vq-act relu \ +--resume-gpt output/VQTransformer/net_best_fid.pth +``` + +
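+
+Compared with the VQ-VAE evaluation, this adds the `multi` (multimodality) column: each caption is generated several times and the average distance between the motion embeddings of those repeats is reported, so a higher value means more diverse generations per text. The repo computes it inside `utils/eval_trans.py`, which is not shown in this diff; the sketch below is an assumed minimal version of the usual definition:
+
+```python
+import torch
+
+def multimodality(motion_embs):
+    """motion_embs: (num_repeats, dim) embeddings generated from one caption."""
+    first, second = torch.chunk(motion_embs, 2, dim=0)  # disjoint halves of the repeats
+    return torch.linalg.norm(first - second, dim=1).mean()
+
+print(multimodality(torch.randn(10, 512)))
+```
+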
+ + +## 6. Motion Render + +
+ +Motion Render + + +You should input the npy folder address and the motion names. Here is an example: + +```bash +python3 render_final.py --filedir output/TEST_VQTransformer/ --motion-list 000019 005485 +``` + +
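+
+The `.npy` files consumed by `render_final.py` are written by the evaluation scripts above: both pass `savenpy=(i==0)`, so only the first of the 20 evaluation repetitions dumps motions into the experiment's output folder. A quick way to inspect one before rendering (the file name below is purely illustrative; use whatever names appear in your `--filedir`):
+
+```python
+import numpy as np
+
+motion = np.load('output/TEST_VQTransformer/000019_pred.npy')  # hypothetical file name
+print(motion.shape, motion.dtype)
+```
+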
+
+## 7. Acknowledgement
+
+We appreciate help from:
+
+* Public code such as [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [TM2T](https://github.com/EricGuo5513/TM2T), etc.
+
+## 8. ChangeLog
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/VQ-Trans/VQ_eval.py b/VQ-Trans/VQ_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b7f269e344f730797eba13a45c9672f323b9f5
--- /dev/null
+++ b/VQ-Trans/VQ_eval.py
@@ -0,0 +1,95 @@
+import os
+import json
+
+import torch
+from torch.utils.tensorboard import SummaryWriter
+import numpy as np
+import models.vqvae as vqvae
+import options.option_vq as option_vq
+import utils.utils_model as utils_model
+from dataset import dataset_TM_eval
+import utils.eval_trans as eval_trans
+from options.get_eval_option import get_opt
+from models.evaluator_wrapper import EvaluatorModelWrapper
+import warnings
+warnings.filterwarnings('ignore')
+
+##### ---- Exp dirs ---- #####
+args = option_vq.get_args_parser()
+torch.manual_seed(args.seed)
+
+args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
+os.makedirs(args.out_dir, exist_ok = True)
+
+##### ---- Logger ---- #####
+logger = utils_model.get_logger(args.out_dir)
+writer = SummaryWriter(args.out_dir)
+logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
+
+
+from utils.word_vectorizer import WordVectorizer
+w_vectorizer = WordVectorizer('./glove', 'our_vab')
+
+
+dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataname == 'kit' else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
+
+wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
+eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
+
+
+##### ---- Dataloader ---- #####
+args.nb_joints = 21 if args.dataname == 'kit' else 22
+
+val_loader = dataset_TM_eval.DATALoader(args.dataname, True, 32, w_vectorizer, unit_length=2**args.down_t)
+
+##### ---- Network ---- #####
+net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
+                       args.nb_code,
+                       args.code_dim,
+                       args.output_emb_width,
+                       args.down_t,
+                       args.stride_t,
+                       args.width,
+                       args.depth,
+                       args.dilation_growth_rate,
+                       args.vq_act,
+                       args.vq_norm)
+
+if args.resume_pth :
+    logger.info('loading checkpoint from {}'.format(args.resume_pth))
+    ckpt = torch.load(args.resume_pth, map_location='cpu')
+    net.load_state_dict(ckpt['net'], strict=True)
+net.train()
+net.cuda()
+
+fid = []
+div = []
+top1 = []
+top2 = []
+top3 = []
+matching = []
+repeat_time = 20
+for i in range(repeat_time):
+    best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper, draw=False, save=False, savenpy=(i==0))
+    fid.append(best_fid)
+    div.append(best_div)
+    top1.append(best_top1)
+    top2.append(best_top2)
+    top3.append(best_top3)
+    matching.append(best_matching)
+print('final result:')
+print('fid: ', sum(fid)/repeat_time)
+print('div: ', sum(div)/repeat_time)
+print('top1: ', sum(top1)/repeat_time)
+print('top2: ', sum(top2)/repeat_time)
+print('top3: ', sum(top3)/repeat_time)
+print('matching: ', sum(matching)/repeat_time)
+
+fid = np.array(fid)
+div = np.array(div)
+top1 = np.array(top1)
+top2 = np.array(top2)
+top3 = np.array(top3)
+matching = np.array(matching)
+msg_final = f"FID. {np.mean(fid):.3f}, conf. 
{np.std(fid)*1.96/np.sqrt(repeat_time):.3f}, Diversity. {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}, TOP1. {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}, Matching. {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}" +logger.info(msg_final) \ No newline at end of file diff --git a/VQ-Trans/ViT-B-32.pt b/VQ-Trans/ViT-B-32.pt new file mode 100644 index 0000000000000000000000000000000000000000..06a4dea8587eb4948a3221b1e1b2e2475e0e203b --- /dev/null +++ b/VQ-Trans/ViT-B-32.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af +size 353976522 diff --git a/VQ-Trans/body_models/smpl/J_regressor_extra.npy b/VQ-Trans/body_models/smpl/J_regressor_extra.npy new file mode 100755 index 0000000000000000000000000000000000000000..d6cf8c0f6747d3c623a0d300c5176843ae99031d --- /dev/null +++ b/VQ-Trans/body_models/smpl/J_regressor_extra.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc968ea4f9855571e82f90203280836b01f13ee42a8e1b89d8d580b801242a89 +size 496160 diff --git a/VQ-Trans/body_models/smpl/SMPL_NEUTRAL.pkl b/VQ-Trans/body_models/smpl/SMPL_NEUTRAL.pkl new file mode 100644 index 0000000000000000000000000000000000000000..26574fd104c4b69467f3c7c3516a8508d8a1a36e --- /dev/null +++ b/VQ-Trans/body_models/smpl/SMPL_NEUTRAL.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98e65c74ad9b998783132f00880d1025a8d64b158e040e6ef13a557e5098bc42 +size 39001280 diff --git a/VQ-Trans/body_models/smpl/kintree_table.pkl b/VQ-Trans/body_models/smpl/kintree_table.pkl new file mode 100644 index 0000000000000000000000000000000000000000..3f72aca47e9257f017ab09470ee977a33f41a49e --- /dev/null +++ b/VQ-Trans/body_models/smpl/kintree_table.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62116ec76c6192ae912557122ea935267ba7188144efb9306ea1366f0e50d4d2 +size 349 diff --git a/VQ-Trans/body_models/smpl/smplfaces.npy b/VQ-Trans/body_models/smpl/smplfaces.npy new file mode 100644 index 0000000000000000000000000000000000000000..6cadb74c9df2b6deebcdc90ee4f8cf9efbffb11d --- /dev/null +++ b/VQ-Trans/body_models/smpl/smplfaces.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ee8e99db736acf178a6078ab5710ca942edc3738d34c72f41a35c40b370e045 +size 165440 diff --git a/VQ-Trans/checkpoints/kit.zip b/VQ-Trans/checkpoints/kit.zip new file mode 100644 index 0000000000000000000000000000000000000000..f0ed05d6a11c4a4e8337072287c2ac787793c8a0 --- /dev/null +++ b/VQ-Trans/checkpoints/kit.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e9d54e1c68bacad61277f89c7d05f9c88a68fd92ff79f79644128bb9b2508cb +size 704518254 diff --git a/VQ-Trans/checkpoints/t2m.zip b/VQ-Trans/checkpoints/t2m.zip new file mode 100644 index 0000000000000000000000000000000000000000..43d6be4f2fb100d8696d2ab9aef463cc2aab7bb5 --- /dev/null +++ b/VQ-Trans/checkpoints/t2m.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09e0628dbc585416217617c0583415c8f654ff855703d72fdb713f7061c0863e +size 1222422692 diff --git a/VQ-Trans/checkpoints/train_vq.py b/VQ-Trans/checkpoints/train_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..d89b9930ba1262747542df3d5b2f03f8fab1b04a --- /dev/null +++ 
b/VQ-Trans/checkpoints/train_vq.py @@ -0,0 +1,171 @@ +import os +import json + +import torch +import torch.optim as optim +from torch.utils.tensorboard import SummaryWriter + +import models.vqvae as vqvae +import utils.losses as losses +import options.option_vq as option_vq +import utils.utils_model as utils_model +from dataset import dataset_VQ, dataset_TM_eval +import utils.eval_trans as eval_trans +from options.get_eval_option import get_opt +from models.evaluator_wrapper import EvaluatorModelWrapper +import warnings +warnings.filterwarnings('ignore') +from utils.word_vectorizer import WordVectorizer + +def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr): + + current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1) + for param_group in optimizer.param_groups: + param_group["lr"] = current_lr + + return optimizer, current_lr + +##### ---- Exp dirs ---- ##### +args = option_vq.get_args_parser() +torch.manual_seed(args.seed) + +args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}') +os.makedirs(args.out_dir, exist_ok = True) + +##### ---- Logger ---- ##### +logger = utils_model.get_logger(args.out_dir) +writer = SummaryWriter(args.out_dir) +logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) + + + +w_vectorizer = WordVectorizer('./glove', 'our_vab') + +if args.dataname == 'kit' : + dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' + args.nb_joints = 21 + +else : + dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' + args.nb_joints = 22 + +logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints') + +wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) +eval_wrapper = EvaluatorModelWrapper(wrapper_opt) + + +##### ---- Dataloader ---- ##### +train_loader = dataset_VQ.DATALoader(args.dataname, + args.batch_size, + window_size=args.window_size, + unit_length=2**args.down_t) + +train_loader_iter = dataset_VQ.cycle(train_loader) + +val_loader = dataset_TM_eval.DATALoader(args.dataname, False, + 32, + w_vectorizer, + unit_length=2**args.down_t) + +##### ---- Network ---- ##### +net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers + args.nb_code, + args.code_dim, + args.output_emb_width, + args.down_t, + args.stride_t, + args.width, + args.depth, + args.dilation_growth_rate, + args.vq_act, + args.vq_norm) + + +if args.resume_pth : + logger.info('loading checkpoint from {}'.format(args.resume_pth)) + ckpt = torch.load(args.resume_pth, map_location='cpu') + net.load_state_dict(ckpt['net'], strict=True) +net.train() +net.cuda() + +##### ---- Optimizer & Scheduler ---- ##### +optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma) + + +Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints) + +##### ------ warm-up ------- ##### +avg_recons, avg_perplexity, avg_commit = 0., 0., 0. 
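+# Linear learning-rate warm-up: update_lr_warm_up (defined above) scales the base
+# lr by (nb_iter + 1) / (warm_up_iter + 1), so the loop below starts near zero and
+# approaches the full args.lr by the last warm-up iteration.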
+ +for nb_iter in range(1, args.warm_up_iter): + + optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr) + + gt_motion = next(train_loader_iter) + gt_motion = gt_motion.cuda().float() # (bs, 64, dim) + + pred_motion, loss_commit, perplexity = net(gt_motion) + loss_motion = Loss(pred_motion, gt_motion) + loss_vel = Loss.forward_vel(pred_motion, gt_motion) + + loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + avg_recons += loss_motion.item() + avg_perplexity += perplexity.item() + avg_commit += loss_commit.item() + + if nb_iter % args.print_iter == 0 : + avg_recons /= args.print_iter + avg_perplexity /= args.print_iter + avg_commit /= args.print_iter + + logger.info(f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}") + + avg_recons, avg_perplexity, avg_commit = 0., 0., 0. + +##### ---- Training ---- ##### +avg_recons, avg_perplexity, avg_commit = 0., 0., 0. +best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper) + +for nb_iter in range(1, args.total_iter + 1): + + gt_motion = next(train_loader_iter) + gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len + + pred_motion, loss_commit, perplexity = net(gt_motion) + loss_motion = Loss(pred_motion, gt_motion) + loss_vel = Loss.forward_vel(pred_motion, gt_motion) + + loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel + + optimizer.zero_grad() + loss.backward() + optimizer.step() + scheduler.step() + + avg_recons += loss_motion.item() + avg_perplexity += perplexity.item() + avg_commit += loss_commit.item() + + if nb_iter % args.print_iter == 0 : + avg_recons /= args.print_iter + avg_perplexity /= args.print_iter + avg_commit /= args.print_iter + + writer.add_scalar('./Train/L1', avg_recons, nb_iter) + writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter) + writer.add_scalar('./Train/Commit', avg_commit, nb_iter) + + logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. 
{avg_recons:.5f}") + + avg_recons, avg_perplexity, avg_commit = 0., 0., 0., + + if nb_iter % args.eval_iter==0 : + best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper) + \ No newline at end of file diff --git a/VQ-Trans/dataset/dataset_TM_eval.py b/VQ-Trans/dataset/dataset_TM_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..576a53b7dabd8095bed59dcc86199e30f2798838 --- /dev/null +++ b/VQ-Trans/dataset/dataset_TM_eval.py @@ -0,0 +1,217 @@ +import torch +from torch.utils import data +import numpy as np +from os.path import join as pjoin +import random +import codecs as cs +from tqdm import tqdm + +import utils.paramUtil as paramUtil +from torch.utils.data._utils.collate import default_collate + + +def collate_fn(batch): + batch.sort(key=lambda x: x[3], reverse=True) + return default_collate(batch) + + +'''For use of training text-2-motion generative model''' +class Text2MotionDataset(data.Dataset): + def __init__(self, dataset_name, is_test, w_vectorizer, feat_bias = 5, max_text_len = 20, unit_length = 4): + + self.max_length = 20 + self.pointer = 0 + self.dataset_name = dataset_name + self.is_test = is_test + self.max_text_len = max_text_len + self.unit_length = unit_length + self.w_vectorizer = w_vectorizer + if dataset_name == 't2m': + self.data_root = './dataset/HumanML3D' + self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') + self.text_dir = pjoin(self.data_root, 'texts') + self.joints_num = 22 + radius = 4 + fps = 20 + self.max_motion_length = 196 + dim_pose = 263 + kinematic_chain = paramUtil.t2m_kinematic_chain + self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' + elif dataset_name == 'kit': + self.data_root = './dataset/KIT-ML' + self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') + self.text_dir = pjoin(self.data_root, 'texts') + self.joints_num = 21 + radius = 240 * 8 + fps = 12.5 + dim_pose = 251 + self.max_motion_length = 196 + kinematic_chain = paramUtil.kit_kinematic_chain + self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' + + mean = np.load(pjoin(self.meta_dir, 'mean.npy')) + std = np.load(pjoin(self.meta_dir, 'std.npy')) + + if is_test: + split_file = pjoin(self.data_root, 'test.txt') + else: + split_file = pjoin(self.data_root, 'val.txt') + + min_motion_len = 40 if self.dataset_name =='t2m' else 24 + # min_motion_len = 64 + + joints_num = self.joints_num + + data_dict = {} + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + + new_name_list = [] + length_list = [] + for name in tqdm(id_list): + try: + motion = np.load(pjoin(self.motion_dir, name + '.npy')) + if (len(motion)) < min_motion_len or (len(motion) >= 200): + continue + text_data = [] + flag = False + with cs.open(pjoin(self.text_dir, name + '.txt')) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + n_motion = motion[int(f_tag*fps) : 
int(to_tag*fps)] + if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200): + continue + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + while new_name in data_dict: + new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name + data_dict[new_name] = {'motion': n_motion, + 'length': len(n_motion), + 'text':[text_dict]} + new_name_list.append(new_name) + length_list.append(len(n_motion)) + except: + print(line_split) + print(line_split[2], line_split[3], f_tag, to_tag, name) + # break + + if flag: + data_dict[name] = {'motion': motion, + 'length': len(motion), + 'text': text_data} + new_name_list.append(name) + length_list.append(len(motion)) + except Exception as e: + # print(e) + pass + + name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1])) + self.mean = mean + self.std = std + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = name_list + self.reset_max_len(self.max_length) + + def reset_max_len(self, length): + assert length <= self.max_motion_length + self.pointer = np.searchsorted(self.length_arr, length) + print("Pointer Pointing at %d"%self.pointer) + self.max_length = length + + def inv_transform(self, data): + return data * self.std + self.mean + + def forward_transform(self, data): + return (data - self.mean) / self.std + + def __len__(self): + return len(self.data_dict) - self.pointer + + def __getitem__(self, item): + idx = self.pointer + item + name = self.name_list[idx] + data = self.data_dict[name] + # data = self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data['motion'], data['length'], data['text'] + # Randomly select a caption + text_data = random.choice(text_list) + caption, tokens = text_data['caption'], text_data['tokens'] + + if len(tokens) < self.max_text_len: + # pad with "unk" + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[:self.max_text_len] + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + if self.unit_length < 10: + coin2 = np.random.choice(['single', 'single', 'double']) + else: + coin2 = 'single' + + if coin2 == 'double': + m_length = (m_length // self.unit_length - 1) * self.unit_length + elif coin2 == 'single': + m_length = (m_length // self.unit_length) * self.unit_length + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx+m_length] + + "Z Normalization" + motion = (motion - self.mean) / self.std + + if m_length < self.max_motion_length: + motion = np.concatenate([motion, + np.zeros((self.max_motion_length - m_length, motion.shape[1])) + ], axis=0) + + return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens), name + + + + +def DATALoader(dataset_name, is_test, + batch_size, w_vectorizer, + num_workers = 8, unit_length = 4) : + + val_loader = torch.utils.data.DataLoader(Text2MotionDataset(dataset_name, is_test, w_vectorizer, unit_length=unit_length), + batch_size, + shuffle = True, + num_workers=num_workers, + collate_fn=collate_fn, + drop_last = True) + return val_loader + + +def cycle(iterable): + while True: + 
for x in iterable: + yield x diff --git a/VQ-Trans/dataset/dataset_TM_train.py b/VQ-Trans/dataset/dataset_TM_train.py new file mode 100644 index 0000000000000000000000000000000000000000..0b0223effb01c1cf57fa6b2b6fb8d9d01b83f84a --- /dev/null +++ b/VQ-Trans/dataset/dataset_TM_train.py @@ -0,0 +1,161 @@ +import torch +from torch.utils import data +import numpy as np +from os.path import join as pjoin +import random +import codecs as cs +from tqdm import tqdm +import utils.paramUtil as paramUtil +from torch.utils.data._utils.collate import default_collate + + +def collate_fn(batch): + batch.sort(key=lambda x: x[3], reverse=True) + return default_collate(batch) + + +'''For use of training text-2-motion generative model''' +class Text2MotionDataset(data.Dataset): + def __init__(self, dataset_name, feat_bias = 5, unit_length = 4, codebook_size = 1024, tokenizer_name=None): + + self.max_length = 64 + self.pointer = 0 + self.dataset_name = dataset_name + + self.unit_length = unit_length + # self.mot_start_idx = codebook_size + self.mot_end_idx = codebook_size + self.mot_pad_idx = codebook_size + 1 + if dataset_name == 't2m': + self.data_root = './dataset/HumanML3D' + self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') + self.text_dir = pjoin(self.data_root, 'texts') + self.joints_num = 22 + radius = 4 + fps = 20 + self.max_motion_length = 26 if unit_length == 8 else 51 + dim_pose = 263 + kinematic_chain = paramUtil.t2m_kinematic_chain + elif dataset_name == 'kit': + self.data_root = './dataset/KIT-ML' + self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') + self.text_dir = pjoin(self.data_root, 'texts') + self.joints_num = 21 + radius = 240 * 8 + fps = 12.5 + dim_pose = 251 + self.max_motion_length = 26 if unit_length == 8 else 51 + kinematic_chain = paramUtil.kit_kinematic_chain + + split_file = pjoin(self.data_root, 'train.txt') + + + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + + new_name_list = [] + data_dict = {} + for name in tqdm(id_list): + try: + m_token_list = np.load(pjoin(self.data_root, tokenizer_name, '%s.npy'%name)) + + # Read text + with cs.open(pjoin(self.text_dir, name + '.txt')) as f: + text_data = [] + flag = False + lines = f.readlines() + + for line in lines: + try: + text_dict = {} + line_split = line.strip().split('#') + caption = line_split[0] + t_tokens = line_split[1].split(' ') + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict['caption'] = caption + text_dict['tokens'] = t_tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + m_token_list_new = [tokens[int(f_tag*fps/unit_length) : int(to_tag*fps/unit_length)] for tokens in m_token_list if int(f_tag*fps/unit_length) < int(to_tag*fps/unit_length)] + + if len(m_token_list_new) == 0: + continue + new_name = '%s_%f_%f'%(name, f_tag, to_tag) + + data_dict[new_name] = {'m_token_list': m_token_list_new, + 'text':[text_dict]} + new_name_list.append(new_name) + except: + pass + + if flag: + data_dict[name] = {'m_token_list': m_token_list, + 'text':text_data} + new_name_list.append(name) + except: + pass + self.data_dict = data_dict + self.name_list = new_name_list + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, item): + data = self.data_dict[self.name_list[item]] + m_token_list, text_list = data['m_token_list'], data['text'] + m_tokens = 
random.choice(m_token_list) + + text_data = random.choice(text_list) + caption= text_data['caption'] + + + coin = np.random.choice([False, False, True]) + # print(len(m_tokens)) + if coin: + # drop one token at the head or tail + coin2 = np.random.choice([True, False]) + if coin2: + m_tokens = m_tokens[:-1] + else: + m_tokens = m_tokens[1:] + m_tokens_len = m_tokens.shape[0] + + if m_tokens_len+1 < self.max_motion_length: + m_tokens = np.concatenate([m_tokens, np.ones((1), dtype=int) * self.mot_end_idx, np.ones((self.max_motion_length-1-m_tokens_len), dtype=int) * self.mot_pad_idx], axis=0) + else: + m_tokens = np.concatenate([m_tokens, np.ones((1), dtype=int) * self.mot_end_idx], axis=0) + + return caption, m_tokens.reshape(-1), m_tokens_len + + + + +def DATALoader(dataset_name, + batch_size, codebook_size, tokenizer_name, unit_length=4, + num_workers = 8) : + + train_loader = torch.utils.data.DataLoader(Text2MotionDataset(dataset_name, codebook_size = codebook_size, tokenizer_name = tokenizer_name, unit_length=unit_length), + batch_size, + shuffle=True, + num_workers=num_workers, + #collate_fn=collate_fn, + drop_last = True) + + + return train_loader + + +def cycle(iterable): + while True: + for x in iterable: + yield x + + diff --git a/VQ-Trans/dataset/dataset_VQ.py b/VQ-Trans/dataset/dataset_VQ.py new file mode 100644 index 0000000000000000000000000000000000000000..2342de946f2cbdf64729a5145168df1bdda54fa0 --- /dev/null +++ b/VQ-Trans/dataset/dataset_VQ.py @@ -0,0 +1,109 @@ +import torch +from torch.utils import data +import numpy as np +from os.path import join as pjoin +import random +import codecs as cs +from tqdm import tqdm + + + +class VQMotionDataset(data.Dataset): + def __init__(self, dataset_name, window_size = 64, unit_length = 4): + self.window_size = window_size + self.unit_length = unit_length + self.dataset_name = dataset_name + + if dataset_name == 't2m': + self.data_root = './dataset/HumanML3D' + self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') + self.text_dir = pjoin(self.data_root, 'texts') + self.joints_num = 22 + self.max_motion_length = 196 + self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' + + elif dataset_name == 'kit': + self.data_root = './dataset/KIT-ML' + self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') + self.text_dir = pjoin(self.data_root, 'texts') + self.joints_num = 21 + + self.max_motion_length = 196 + self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' + + joints_num = self.joints_num + + mean = np.load(pjoin(self.meta_dir, 'mean.npy')) + std = np.load(pjoin(self.meta_dir, 'std.npy')) + + split_file = pjoin(self.data_root, 'train.txt') + + self.data = [] + self.lengths = [] + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + + for name in tqdm(id_list): + try: + motion = np.load(pjoin(self.motion_dir, name + '.npy')) + if motion.shape[0] < self.window_size: + continue + self.lengths.append(motion.shape[0] - self.window_size) + self.data.append(motion) + except: + # Some motion may not exist in KIT dataset + pass + + + self.mean = mean + self.std = std + print("Total number of motions {}".format(len(self.data))) + + def inv_transform(self, data): + return data * self.std + self.mean + + def compute_sampling_prob(self) : + + prob = np.array(self.lengths, dtype=np.float32) + prob /= np.sum(prob) + return prob + + def __len__(self): + return len(self.data) + + def __getitem__(self, item): + motion = self.data[item] + + idx = 
random.randint(0, len(motion) - self.window_size) + + motion = motion[idx:idx+self.window_size] + "Z Normalization" + motion = (motion - self.mean) / self.std + + return motion + +def DATALoader(dataset_name, + batch_size, + num_workers = 8, + window_size = 64, + unit_length = 4): + + trainSet = VQMotionDataset(dataset_name, window_size=window_size, unit_length=unit_length) + prob = trainSet.compute_sampling_prob() + sampler = torch.utils.data.WeightedRandomSampler(prob, num_samples = len(trainSet) * 1000, replacement=True) + train_loader = torch.utils.data.DataLoader(trainSet, + batch_size, + shuffle=True, + #sampler=sampler, + num_workers=num_workers, + #collate_fn=collate_fn, + drop_last = True) + + return train_loader + +def cycle(iterable): + while True: + for x in iterable: + yield x diff --git a/VQ-Trans/dataset/dataset_tokenize.py b/VQ-Trans/dataset/dataset_tokenize.py new file mode 100644 index 0000000000000000000000000000000000000000..641a02a75f2cfaadea45851cad2a95b39bfa1eae --- /dev/null +++ b/VQ-Trans/dataset/dataset_tokenize.py @@ -0,0 +1,117 @@ +import torch +from torch.utils import data +import numpy as np +from os.path import join as pjoin +import random +import codecs as cs +from tqdm import tqdm + + + +class VQMotionDataset(data.Dataset): + def __init__(self, dataset_name, feat_bias = 5, window_size = 64, unit_length = 8): + self.window_size = window_size + self.unit_length = unit_length + self.feat_bias = feat_bias + + self.dataset_name = dataset_name + min_motion_len = 40 if dataset_name =='t2m' else 24 + + if dataset_name == 't2m': + self.data_root = './dataset/HumanML3D' + self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') + self.text_dir = pjoin(self.data_root, 'texts') + self.joints_num = 22 + radius = 4 + fps = 20 + self.max_motion_length = 196 + dim_pose = 263 + self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' + #kinematic_chain = paramUtil.t2m_kinematic_chain + elif dataset_name == 'kit': + self.data_root = './dataset/KIT-ML' + self.motion_dir = pjoin(self.data_root, 'new_joint_vecs') + self.text_dir = pjoin(self.data_root, 'texts') + self.joints_num = 21 + radius = 240 * 8 + fps = 12.5 + dim_pose = 251 + self.max_motion_length = 196 + self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta' + #kinematic_chain = paramUtil.kit_kinematic_chain + + joints_num = self.joints_num + + mean = np.load(pjoin(self.meta_dir, 'mean.npy')) + std = np.load(pjoin(self.meta_dir, 'std.npy')) + + split_file = pjoin(self.data_root, 'train.txt') + + data_dict = {} + id_list = [] + with cs.open(split_file, 'r') as f: + for line in f.readlines(): + id_list.append(line.strip()) + + new_name_list = [] + length_list = [] + for name in tqdm(id_list): + try: + motion = np.load(pjoin(self.motion_dir, name + '.npy')) + if (len(motion)) < min_motion_len or (len(motion) >= 200): + continue + + data_dict[name] = {'motion': motion, + 'length': len(motion), + 'name': name} + new_name_list.append(name) + length_list.append(len(motion)) + except: + # Some motion may not exist in KIT dataset + pass + + + self.mean = mean + self.std = std + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.name_list = new_name_list + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, item): + name = self.name_list[item] + data = self.data_dict[name] + motion, m_length = data['motion'], data['length'] + + m_length = (m_length // self.unit_length) * 
self.unit_length + + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx+m_length] + + "Z Normalization" + motion = (motion - self.mean) / self.std + + return motion, name + +def DATALoader(dataset_name, + batch_size = 1, + num_workers = 8, unit_length = 4) : + + train_loader = torch.utils.data.DataLoader(VQMotionDataset(dataset_name, unit_length=unit_length), + batch_size, + shuffle=True, + num_workers=num_workers, + #collate_fn=collate_fn, + drop_last = True) + + return train_loader + +def cycle(iterable): + while True: + for x in iterable: + yield x diff --git a/VQ-Trans/dataset/prepare/download_extractor.sh b/VQ-Trans/dataset/prepare/download_extractor.sh new file mode 100644 index 0000000000000000000000000000000000000000..b1c456e8311a59a1c8d86e85da5ddd3aa7e1f9a4 --- /dev/null +++ b/VQ-Trans/dataset/prepare/download_extractor.sh @@ -0,0 +1,15 @@ +rm -rf checkpoints +mkdir checkpoints +cd checkpoints +echo -e "Downloading extractors" +gdown --fuzzy https://drive.google.com/file/d/1o7RTDQcToJjTm9_mNWTyzvZvjTWpZfug/view +gdown --fuzzy https://drive.google.com/file/d/1tX79xk0fflp07EZ660Xz1RAFE33iEyJR/view + + +unzip t2m.zip +unzip kit.zip + +echo -e "Cleaning\n" +rm t2m.zip +rm kit.zip +echo -e "Downloading done!" \ No newline at end of file diff --git a/VQ-Trans/dataset/prepare/download_glove.sh b/VQ-Trans/dataset/prepare/download_glove.sh new file mode 100644 index 0000000000000000000000000000000000000000..058599aa32c9c97e0e3fc0a9658822e9c904955a --- /dev/null +++ b/VQ-Trans/dataset/prepare/download_glove.sh @@ -0,0 +1,9 @@ +echo -e "Downloading glove (in use by the evaluators)" +gdown --fuzzy https://drive.google.com/file/d/1bCeS6Sh_mLVTebxIgiUHgdPrroW06mb6/view?usp=sharing +rm -rf glove + +unzip glove.zip +echo -e "Cleaning\n" +rm glove.zip + +echo -e "Downloading done!" \ No newline at end of file diff --git a/VQ-Trans/dataset/prepare/download_model.sh b/VQ-Trans/dataset/prepare/download_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..da32436f6efa93e0c14e1dd52f97068bd75956ab --- /dev/null +++ b/VQ-Trans/dataset/prepare/download_model.sh @@ -0,0 +1,12 @@ + +mkdir -p pretrained +cd pretrained/ + +echo -e "The pretrained model files will be stored in the 'pretrained' folder\n" +gdown 1LaOvwypF-jM2Axnq5dc-Iuvv3w_G-WDE + +unzip VQTrans_pretrained.zip +echo -e "Cleaning\n" +rm VQTrans_pretrained.zip + +echo -e "Downloading done!" \ No newline at end of file diff --git a/VQ-Trans/dataset/prepare/download_smpl.sh b/VQ-Trans/dataset/prepare/download_smpl.sh new file mode 100644 index 0000000000000000000000000000000000000000..411325b509e891d96b859bf28f7b983005ca360a --- /dev/null +++ b/VQ-Trans/dataset/prepare/download_smpl.sh @@ -0,0 +1,13 @@ + +mkdir -p body_models +cd body_models/ + +echo -e "The smpl files will be stored in the 'body_models/smpl/' folder\n" +gdown 1INYlGA76ak_cKGzvpOV2Pe6RkYTlXTW2 +rm -rf smpl + +unzip smpl.zip +echo -e "Cleaning\n" +rm smpl.zip + +echo -e "Downloading done!" 
\ No newline at end of file diff --git a/VQ-Trans/environment.yml b/VQ-Trans/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..cb0abb7f5c278d1eaee782c02abb3a46da736f90 --- /dev/null +++ b/VQ-Trans/environment.yml @@ -0,0 +1,121 @@ +name: VQTrans +channels: + - pytorch + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - blas=1.0=mkl + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2021.7.5=h06a4308_1 + - certifi=2021.5.30=py38h06a4308_0 + - cudatoolkit=10.1.243=h6bb024c_0 + - ffmpeg=4.3=hf484d3e_0 + - freetype=2.10.4=h5ab3b9f_0 + - gmp=6.2.1=h2531618_2 + - gnutls=3.6.15=he1e5248_0 + - intel-openmp=2021.3.0=h06a4308_3350 + - jpeg=9b=h024ee3a_2 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgomp=9.3.0=h5101ec6_17 + - libiconv=1.15=h63c8f33_5 + - libidn2=2.3.2=h7f8727e_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.2.0=h85742a9_0 + - libunistring=0.9.10=h27cfd23_0 + - libuv=1.40.0=h7b6447c_0 + - libwebp-base=1.2.0=h27cfd23_0 + - lz4-c=1.9.3=h295c915_1 + - mkl=2021.3.0=h06a4308_520 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.0=py38h42c9631_2 + - mkl_random=1.2.2=py38h51133e4_0 + - ncurses=6.2=he6710b0_1 + - nettle=3.7.3=hbbd107a_1 + - ninja=1.10.2=hff7bd54_1 + - numpy=1.20.3=py38hf144106_0 + - numpy-base=1.20.3=py38h74d4b33_0 + - olefile=0.46=py_0 + - openh264=2.1.0=hd408876_0 + - openjpeg=2.3.0=h05c96fa_1 + - openssl=1.1.1k=h27cfd23_0 + - pillow=8.3.1=py38h2c7a002_0 + - pip=21.0.1=py38h06a4308_0 + - python=3.8.11=h12debd9_0_cpython + - pytorch=1.8.1=py3.8_cuda10.1_cudnn7.6.3_0 + - readline=8.1=h27cfd23_0 + - setuptools=52.0.0=py38h06a4308_0 + - six=1.16.0=pyhd3eb1b0_0 + - sqlite=3.36.0=hc218d9a_0 + - tk=8.6.10=hbc83047_0 + - torchaudio=0.8.1=py38 + - torchvision=0.9.1=py38_cu101 + - typing_extensions=3.10.0.0=pyh06a4308_0 + - wheel=0.37.0=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 + - pip: + - absl-py==0.13.0 + - backcall==0.2.0 + - cachetools==4.2.2 + - charset-normalizer==2.0.4 + - chumpy==0.70 + - cycler==0.10.0 + - decorator==5.0.9 + - google-auth==1.35.0 + - google-auth-oauthlib==0.4.5 + - grpcio==1.39.0 + - idna==3.2 + - imageio==2.9.0 + - ipdb==0.13.9 + - ipython==7.26.0 + - ipython-genutils==0.2.0 + - jedi==0.18.0 + - joblib==1.0.1 + - kiwisolver==1.3.1 + - markdown==3.3.4 + - matplotlib==3.4.3 + - matplotlib-inline==0.1.2 + - oauthlib==3.1.1 + - pandas==1.3.2 + - parso==0.8.2 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - prompt-toolkit==3.0.20 + - protobuf==3.17.3 + - ptyprocess==0.7.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pygments==2.10.0 + - pyparsing==2.4.7 + - python-dateutil==2.8.2 + - pytz==2021.1 + - pyyaml==5.4.1 + - requests==2.26.0 + - requests-oauthlib==1.3.0 + - rsa==4.7.2 + - scikit-learn==0.24.2 + - scipy==1.7.1 + - sklearn==0.0 + - smplx==0.1.28 + - tensorboard==2.6.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.0 + - threadpoolctl==2.2.0 + - toml==0.10.2 + - tqdm==4.62.2 + - traitlets==5.0.5 + - urllib3==1.26.6 + - wcwidth==0.2.5 + - werkzeug==2.0.1 + - git+https://github.com/openai/CLIP.git + - git+https://github.com/nghorbani/human_body_prior + - gdown + - moviepy \ No newline at end of file diff --git a/VQ-Trans/models/encdec.py b/VQ-Trans/models/encdec.py new file mode 100644 index 
0000000000000000000000000000000000000000..ae72afaa5aa59ad67cadb38e0d83e420fc6edb09 --- /dev/null +++ b/VQ-Trans/models/encdec.py @@ -0,0 +1,67 @@ +import torch.nn as nn +from models.resnet import Resnet1D + +class Encoder(nn.Module): + def __init__(self, + input_emb_width = 3, + output_emb_width = 512, + down_t = 3, + stride_t = 2, + width = 512, + depth = 3, + dilation_growth_rate = 3, + activation='relu', + norm=None): + super().__init__() + + blocks = [] + filter_t, pad_t = stride_t * 2, stride_t // 2 + blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + + for i in range(down_t): + input_dim = width + block = nn.Sequential( + nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t), + Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm), + ) + blocks.append(block) + blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1)) + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) + +class Decoder(nn.Module): + def __init__(self, + input_emb_width = 3, + output_emb_width = 512, + down_t = 3, + stride_t = 2, + width = 512, + depth = 3, + dilation_growth_rate = 3, + activation='relu', + norm=None): + super().__init__() + blocks = [] + + filter_t, pad_t = stride_t * 2, stride_t // 2 + blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + for i in range(down_t): + out_dim = width + block = nn.Sequential( + Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv1d(width, out_dim, 3, 1, 1) + ) + blocks.append(block) + blocks.append(nn.Conv1d(width, width, 3, 1, 1)) + blocks.append(nn.ReLU()) + blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1)) + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) + diff --git a/VQ-Trans/models/evaluator_wrapper.py b/VQ-Trans/models/evaluator_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4558a13ccc2ce0579540b8b77f958096e9984c --- /dev/null +++ b/VQ-Trans/models/evaluator_wrapper.py @@ -0,0 +1,92 @@ + +import torch +from os.path import join as pjoin +import numpy as np +from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo +from utils.word_vectorizer import POS_enumerator + +def build_models(opt): + movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent) + text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word, + pos_size=opt.dim_pos_ohot, + hidden_size=opt.dim_text_hidden, + output_size=opt.dim_coemb_hidden, + device=opt.device) + + motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent, + hidden_size=opt.dim_motion_hidden, + output_size=opt.dim_coemb_hidden, + device=opt.device) + + checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'), + map_location=opt.device) + movement_enc.load_state_dict(checkpoint['movement_encoder']) + text_enc.load_state_dict(checkpoint['text_encoder']) + motion_enc.load_state_dict(checkpoint['motion_encoder']) + print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' 
% (checkpoint['epoch'])) + return text_enc, motion_enc, movement_enc + + +class EvaluatorModelWrapper(object): + + def __init__(self, opt): + + if opt.dataset_name == 't2m': + opt.dim_pose = 263 + elif opt.dataset_name == 'kit': + opt.dim_pose = 251 + else: + raise KeyError('Dataset not Recognized!!!') + + opt.dim_word = 300 + opt.max_motion_length = 196 + opt.dim_pos_ohot = len(POS_enumerator) + opt.dim_motion_hidden = 1024 + opt.max_text_len = 20 + opt.dim_text_hidden = 512 + opt.dim_coemb_hidden = 512 + + # print(opt) + + self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt) + self.opt = opt + self.device = opt.device + + self.text_encoder.to(opt.device) + self.motion_encoder.to(opt.device) + self.movement_encoder.to(opt.device) + + self.text_encoder.eval() + self.motion_encoder.eval() + self.movement_encoder.eval() + + # Please note that the results does not following the order of inputs + def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): + with torch.no_grad(): + word_embs = word_embs.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + motions = motions.detach().to(self.device).float() + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = m_lens // self.opt.unit_length + motion_embedding = self.motion_encoder(movements, m_lens) + + '''Text Encoding''' + text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) + return text_embedding, motion_embedding + + # Please note that the results does not following the order of inputs + def get_motion_embeddings(self, motions, m_lens): + with torch.no_grad(): + motions = motions.detach().to(self.device).float() + + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + m_lens = m_lens[align_idx] + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = m_lens // self.opt.unit_length + motion_embedding = self.motion_encoder(movements, m_lens) + return motion_embedding diff --git a/VQ-Trans/models/modules.py b/VQ-Trans/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4f06cd98d4f6029bd3df073095cf50498483d54a --- /dev/null +++ b/VQ-Trans/models/modules.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + +def init_weight(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.xavier_normal_(m.weight) + # m.bias.data.fill_(0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class MovementConvEncoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MovementConvEncoder, self).__init__() + self.main = nn.Sequential( + nn.Conv1d(input_size, hidden_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(hidden_size, output_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + self.out_net.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return self.out_net(outputs) + + + +class TextEncoderBiGRUCo(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, output_size, device): + super(TextEncoderBiGRUCo, self).__init__() + self.device = device + + self.pos_emb = 
nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + self.input_emb.apply(init_weight) + self.pos_emb.apply(init_weight) + self.output_net.apply(init_weight) + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + + +class MotionEncoderBiGRUCo(nn.Module): + def __init__(self, input_size, hidden_size, output_size, device): + super(MotionEncoderBiGRUCo, self).__init__() + self.device = device + + self.input_emb = nn.Linear(input_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size*2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + self.input_emb.apply(init_weight) + self.output_net.apply(init_weight) + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, inputs, m_lens): + num_samples = inputs.shape[0] + + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = m_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True, enforce_sorted=False) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) diff --git a/VQ-Trans/models/pos_encoding.py b/VQ-Trans/models/pos_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..066be3e1f8a1636f7eaabd1c534b9c618ee3e9f8 --- /dev/null +++ b/VQ-Trans/models/pos_encoding.py @@ -0,0 +1,43 @@ +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +def PE1d_sincos(seq_length, dim): + """ + :param d_model: dimension of the model + :param length: length of positions + :return: length*d_model position matrix + """ + if dim % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(dim)) + pe = torch.zeros(seq_length, dim) + position = torch.arange(0, seq_length).unsqueeze(1) + div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * + -(math.log(10000.0) / dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + + return pe.unsqueeze(1) + + +class PositionEmbedding(nn.Module): + """ + Absolute pos embedding (standard), learned. 
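+ Note: with grad=False (the default used below), the table keeps the fixed sinusoidal values produced by PE1d_sincos, so it is only actually learned when grad=True.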
+ """ + def __init__(self, seq_length, dim, dropout, grad=False): + super().__init__() + self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x): + # x.shape: bs, seq_len, feat_dim + l = x.shape[1] + x = x.permute(1, 0, 2) + self.embed[:l].expand(x.permute(1, 0, 2).shape) + x = self.dropout(x.permute(1, 0, 2)) + return x + + \ No newline at end of file diff --git a/VQ-Trans/models/quantize_cnn.py b/VQ-Trans/models/quantize_cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..b796772749efda9a225bdcb0e7262791a972a710 --- /dev/null +++ b/VQ-Trans/models/quantize_cnn.py @@ -0,0 +1,415 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class QuantizeEMAReset(nn.Module): + def __init__(self, nb_code, code_dim, args): + super().__init__() + self.nb_code = nb_code + self.code_dim = code_dim + self.mu = args.mu + self.reset_codebook() + + def reset_codebook(self): + self.init = False + self.code_sum = None + self.code_count = None + if torch.cuda.is_available(): + self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda()) + else: + self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim)) + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else : + out = x + return out + + def init_codebook(self, x): + out = self._tile(x) + self.codebook = out[:self.nb_code] + self.code_sum = self.codebook.clone() + self.code_count = torch.ones(self.nb_code, device=self.codebook.device) + self.init = True + + @torch.no_grad() + def compute_perplexity(self, code_idx) : + # Calculate new centres + code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + return perplexity + + @torch.no_grad() + def update_codebook(self, x, code_idx): + + code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_sum = torch.matmul(code_onehot, x) # nb_code, w + code_count = code_onehot.sum(dim=-1) # nb_code + + out = self._tile(x) + code_rand = out[:self.nb_code] + + # Update centres + self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code + self.code_count = self.mu * self.code_count + (1. 
- self.mu) * code_count # nb_code + + usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() + code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) + + self.codebook = usage * code_update + (1 - usage) * code_rand + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + + return perplexity + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + def quantize(self, x): + # Calculate latent code x_l + k_w = self.codebook.t() + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, + keepdim=True) # (N * L, b) + _, code_idx = torch.min(distance, dim=-1) + return code_idx + + def dequantize(self, code_idx): + x = F.embedding(code_idx, self.codebook) + return x + + + def forward(self, x): + N, width, T = x.shape + + # Preprocess + x = self.preprocess(x) + + # Init codebook if not inited + if self.training and not self.init: + self.init_codebook(x) + + # quantize and dequantize through bottleneck + code_idx = self.quantize(x) + x_d = self.dequantize(code_idx) + + # Update embeddings + if self.training: + perplexity = self.update_codebook(x, code_idx) + else : + perplexity = self.compute_perplexity(code_idx) + + # Loss + commit_loss = F.mse_loss(x, x_d.detach()) + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) + + return x_d, commit_loss, perplexity + + + +class Quantizer(nn.Module): + def __init__(self, n_e, e_dim, beta): + super(Quantizer, self).__init__() + + self.e_dim = e_dim + self.n_e = n_e + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + + N, width, T = z.shape + z = self.preprocess(z) + assert z.shape[-1] == self.e_dim + z_flattened = z.contiguous().view(-1, self.e_dim) + + # B x V + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + # B x 1 + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + + # compute loss for embedding + loss = torch.mean((z_q - z.detach())**2) + self.beta * \ + torch.mean((z_q.detach() - z)**2) + + # preserve gradients + z_q = z + (z_q - z).detach() + z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) + + min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype) + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10))) + return z_q, loss, perplexity + + def quantize(self, z): + + assert z.shape[-1] == self.e_dim + + # B x V + d = torch.sum(z ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.matmul(z, self.embedding.weight.t()) + # B x 1 + min_encoding_indices = torch.argmin(d, dim=1) + return min_encoding_indices + + def dequantize(self, indices): + + index_flattened = indices.view(-1) + z_q = self.embedding(index_flattened) + z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous() + return z_q + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + + +class QuantizeReset(nn.Module): + def __init__(self, nb_code, code_dim, args): 
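+ # Unlike the EMA variants, QuantizeReset stores the codebook as a plain nn.Parameter; update_codebook below leaves codes that were assigned in the current batch unchanged and re-initializes unused codes from random batch features (usage-based reset, no EMA smoothing).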
+ super().__init__() + self.nb_code = nb_code + self.code_dim = code_dim + self.reset_codebook() + self.codebook = nn.Parameter(torch.randn(nb_code, code_dim)) + + def reset_codebook(self): + self.init = False + self.code_count = None + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else : + out = x + return out + + def init_codebook(self, x): + out = self._tile(x) + self.codebook = nn.Parameter(out[:self.nb_code]) + self.code_count = torch.ones(self.nb_code, device=self.codebook.device) + self.init = True + + @torch.no_grad() + def compute_perplexity(self, code_idx) : + # Calculate new centres + code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + return perplexity + + def update_codebook(self, x, code_idx): + + code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + + out = self._tile(x) + code_rand = out[:self.nb_code] + + # Update centres + self.code_count = code_count # nb_code + usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() + + self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + + return perplexity + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + def quantize(self, x): + # Calculate latent code x_l + k_w = self.codebook.t() + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, keepdim=True) # (N * L, b) + _, code_idx = torch.min(distance, dim=-1) + return code_idx + + def dequantize(self, code_idx): + x = F.embedding(code_idx, self.codebook) + return x + + + def forward(self, x): + N, width, T = x.shape + # Preprocess + x = self.preprocess(x) + # Init codebook if not inited + if self.training and not self.init: + self.init_codebook(x) + # quantize and dequantize through bottleneck + code_idx = self.quantize(x) + x_d = self.dequantize(code_idx) + # Update embeddings + if self.training: + perplexity = self.update_codebook(x, code_idx) + else : + perplexity = self.compute_perplexity(code_idx) + + # Loss + commit_loss = F.mse_loss(x, x_d.detach()) + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) + + return x_d, commit_loss, perplexity + +class QuantizeEMA(nn.Module): + def __init__(self, nb_code, code_dim, args): + super().__init__() + self.nb_code = nb_code + self.code_dim = code_dim + self.mu = 0.99 + self.reset_codebook() + + def reset_codebook(self): + self.init = False + self.code_sum = None + self.code_count = None + self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, device='cuda' if torch.cuda.is_available() else 'cpu')) # guard the device so CPU-only machines do not crash, as in QuantizeEMAReset + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = 
x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else : + out = x + return out + + def init_codebook(self, x): + out = self._tile(x) + self.codebook = out[:self.nb_code] + self.code_sum = self.codebook.clone() + self.code_count = torch.ones(self.nb_code, device=self.codebook.device) + self.init = True + + @torch.no_grad() + def compute_perplexity(self, code_idx) : + # Calculate new centres + code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + return perplexity + + @torch.no_grad() + def update_codebook(self, x, code_idx): + + code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_sum = torch.matmul(code_onehot, x) # nb_code, w + code_count = code_onehot.sum(dim=-1) # nb_code + + # Update centres + self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code + self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code + + code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) + + self.codebook = code_update + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + return perplexity + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) + return x + + def quantize(self, x): + # Calculate latent code x_l + k_w = self.codebook.t() + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, + keepdim=True) # (N * L, b) + _, code_idx = torch.min(distance, dim=-1) + return code_idx + + def dequantize(self, code_idx): + x = F.embedding(code_idx, self.codebook) + return x + + + def forward(self, x): + N, width, T = x.shape + + # Preprocess + x = self.preprocess(x) + + # Init codebook if not inited + if self.training and not self.init: + self.init_codebook(x) + + # quantize and dequantize through bottleneck + code_idx = self.quantize(x) + x_d = self.dequantize(code_idx) + + # Update embeddings + if self.training: + perplexity = self.update_codebook(x, code_idx) + else : + perplexity = self.compute_perplexity(code_idx) + + # Loss + commit_loss = F.mse_loss(x, x_d.detach()) + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T) + + return x_d, commit_loss, perplexity \ No newline at end of file diff --git a/VQ-Trans/models/resnet.py b/VQ-Trans/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..062346e3ba2fc4d6ae5636f228c5b7565bdb62b7 --- /dev/null +++ b/VQ-Trans/models/resnet.py @@ -0,0 +1,82 @@ +import torch.nn as nn +import torch + +class nonlinearity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # swish + return x * torch.sigmoid(x) + +class ResConv1DBlock(nn.Module): + def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None): + super().__init__() + padding = dilation + self.norm = norm + if norm == "LN": + self.norm1 = nn.LayerNorm(n_in) + self.norm2 = nn.LayerNorm(n_in) + elif norm == "GN": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, 
affine=True) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True) + elif norm == "BN": + self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) + self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True) + + else: + self.norm1 = nn.Identity() + self.norm2 = nn.Identity() + + if activation == "relu": + self.activation1 = nn.ReLU() + self.activation2 = nn.ReLU() + + elif activation == "silu": + self.activation1 = nonlinearity() + self.activation2 = nonlinearity() + + elif activation == "gelu": + self.activation1 = nn.GELU() + self.activation2 = nn.GELU() + + + + self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation) + self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,) + + + def forward(self, x): + x_orig = x + if self.norm == "LN": + x = self.norm1(x.transpose(-2, -1)) + x = self.activation1(x.transpose(-2, -1)) + else: + x = self.norm1(x) + x = self.activation1(x) + + x = self.conv1(x) + + if self.norm == "LN": + x = self.norm2(x.transpose(-2, -1)) + x = self.activation2(x.transpose(-2, -1)) + else: + x = self.norm2(x) + x = self.activation2(x) + + x = self.conv2(x) + x = x + x_orig + return x + +class Resnet1D(nn.Module): + def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None): + super().__init__() + + blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)] + if reverse_dilation: + blocks = blocks[::-1] + + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) \ No newline at end of file diff --git a/VQ-Trans/models/rotation2xyz.py b/VQ-Trans/models/rotation2xyz.py new file mode 100644 index 0000000000000000000000000000000000000000..44f6cb6c3fd0fd263bd6256803b908e9e2b4184b --- /dev/null +++ b/VQ-Trans/models/rotation2xyz.py @@ -0,0 +1,92 @@ +# This code is based on https://github.com/Mathux/ACTOR.git +import torch +import utils.rotation_conversions as geometry + + +from models.smpl import SMPL, JOINTSTYPE_ROOT +# from .get_model import JOINTSTYPES +JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] + + +class Rotation2xyz: + def __init__(self, device, dataset='amass'): + self.device = device + self.dataset = dataset + self.smpl_model = SMPL().eval().to(device) + + def __call__(self, x, mask, pose_rep, translation, glob, + jointstype, vertstrans, betas=None, beta=0, + glob_rot=None, get_rotations_back=False, **kwargs): + if pose_rep == "xyz": + return x + + if mask is None: + mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device) + + if not glob and glob_rot is None: + raise TypeError("You must specify global rotation if glob is False") + + if jointstype not in JOINTSTYPES: + raise NotImplementedError("This jointstype is not implemented.") + + if translation: + x_translations = x[:, -1, :3] + x_rotations = x[:, :-1] + else: + x_rotations = x + + x_rotations = x_rotations.permute(0, 3, 1, 2) + nsamples, time, njoints, feats = x_rotations.shape + + # Compute rotations (convert only masked sequences output) + if pose_rep == "rotvec": + rotations = geometry.axis_angle_to_matrix(x_rotations[mask]) + elif pose_rep == "rotmat": + rotations = x_rotations[mask].view(-1, njoints, 3, 3) + elif pose_rep == "rotquat": + rotations = geometry.quaternion_to_matrix(x_rotations[mask]) + elif pose_rep == "rot6d": + rotations = geometry.rotation_6d_to_matrix(x_rotations[mask]) + else: + raise NotImplementedError("No geometry for this one.") + + if not 
glob: + global_orient = torch.tensor(glob_rot, device=x.device) + global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3) + global_orient = global_orient.repeat(len(rotations), 1, 1, 1) + else: + global_orient = rotations[:, 0] + rotations = rotations[:, 1:] + + if betas is None: + betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas], + dtype=rotations.dtype, device=rotations.device) + betas[:, 1] = beta + # import ipdb; ipdb.set_trace() + out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas) + + # get the desirable joints + joints = out[jointstype] + + x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype) + x_xyz[~mask] = 0 + x_xyz[mask] = joints + + x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous() + + # the first translation root at the origin on the prediction + if jointstype != "vertices": + rootindex = JOINTSTYPE_ROOT[jointstype] + x_xyz = x_xyz - x_xyz[:, [rootindex], :, :] + + if translation and vertstrans: + # the first translation root at the origin + x_translations = x_translations - x_translations[:, :, [0]] + + # add the translation to all the joints + x_xyz = x_xyz + x_translations[:, None, :, :] + + if get_rotations_back: + return x_xyz, rotations, global_orient + else: + return x_xyz diff --git a/VQ-Trans/models/smpl.py b/VQ-Trans/models/smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..587f5419601a74df92c1e37263b28d4aa6a7c0a9 --- /dev/null +++ b/VQ-Trans/models/smpl.py @@ -0,0 +1,97 @@ +# This code is based on https://github.com/Mathux/ACTOR.git +import numpy as np +import torch + +import contextlib + +from smplx import SMPLLayer as _SMPLLayer +from smplx.lbs import vertices2joints + + +# action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38] +# change 0 and 8 +action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38] + +from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA + +JOINTSTYPE_ROOT = {"a2m": 0, # action2motion + "smpl": 0, + "a2mpl": 0, # set(smpl, a2m) + "vibe": 8} # 0 is the 8 position: OP MidHip below + +JOINT_MAP = { + 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, + 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, + 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, + 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, + 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, + 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, + 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, + 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, + 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, + 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, + 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, + 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, + 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, + 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, + 'Spine (H36M)': 51, 'Jaw (H36M)': 52, + 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, + 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 +} + +JOINT_NAMES = [ + 'OP Nose', 'OP Neck', 'OP RShoulder', + 'OP RElbow', 'OP RWrist', 'OP LShoulder', + 'OP LElbow', 'OP LWrist', 'OP MidHip', + 'OP RHip', 'OP RKnee', 'OP RAnkle', + 'OP LHip', 'OP LKnee', 'OP LAnkle', + 'OP REye', 'OP LEye', 'OP REar', + 'OP LEar', 'OP LBigToe', 'OP LSmallToe', + 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', + 'Right Ankle', 'Right Knee', 'Right Hip', + 'Left Hip', 'Left Knee', 'Left Ankle', + 'Right Wrist', 'Right Elbow', 'Right 
Shoulder', + 'Left Shoulder', 'Left Elbow', 'Left Wrist', + 'Neck (LSP)', 'Top of Head (LSP)', + 'Pelvis (MPII)', 'Thorax (MPII)', + 'Spine (H36M)', 'Jaw (H36M)', + 'Head (H36M)', 'Nose', 'Left Eye', + 'Right Eye', 'Left Ear', 'Right Ear' +] + + +# adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints +class SMPL(_SMPLLayer): + """ Extension of the official SMPL implementation to support more joints """ + + def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs): + kwargs["model_path"] = model_path + + # remove the verbosity for the 10-shapes beta parameters + with contextlib.redirect_stdout(None): + super(SMPL, self).__init__(**kwargs) + + J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA) + self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) + vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES]) + a2m_indexes = vibe_indexes[action2motion_joints] + smpl_indexes = np.arange(24) + a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes]) + + self.maps = {"vibe": vibe_indexes, + "a2m": a2m_indexes, + "smpl": smpl_indexes, + "a2mpl": a2mpl_indexes} + + def forward(self, *args, **kwargs): + smpl_output = super(SMPL, self).forward(*args, **kwargs) + + extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) + all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1) + + output = {"vertices": smpl_output.vertices} + + for joinstype, indexes in self.maps.items(): + output[joinstype] = all_joints[:, indexes] + + return output \ No newline at end of file diff --git a/VQ-Trans/models/t2m_trans.py b/VQ-Trans/models/t2m_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..54bd0a485d7e8dbeaaac91d049f63ebd136cb074 --- /dev/null +++ b/VQ-Trans/models/t2m_trans.py @@ -0,0 +1,211 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.distributions import Categorical +import models.pos_encoding as pos_encoding + +class Text2Motion_Transformer(nn.Module): + + def __init__(self, + num_vq=1024, + embed_dim=512, + clip_dim=512, + block_size=16, + num_layers=2, + n_head=8, + drop_out_rate=0.1, + fc_rate=4): + super().__init__() + self.trans_base = CrossCondTransBase(num_vq, embed_dim, clip_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate) + self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate) + self.block_size = block_size + self.num_vq = num_vq + + def get_block_size(self): + return self.block_size + + def forward(self, idxs, clip_feature): + feat = self.trans_base(idxs, clip_feature) + logits = self.trans_head(feat) + return logits + + def sample(self, clip_feature, if_categorial=False): + for k in range(self.block_size): + if k == 0: + x = [] + else: + x = xs + logits = self.forward(x, clip_feature) + logits = logits[:, -1, :] + probs = F.softmax(logits, dim=-1) + if if_categorial: + dist = Categorical(probs) + idx = dist.sample() + if idx == self.num_vq: + break + idx = idx.unsqueeze(-1) + else: + _, idx = torch.topk(probs, k=1, dim=-1) + if idx[0] == self.num_vq: + break + # append to the sequence and continue + if k == 0: + xs = idx + else: + xs = torch.cat((xs, idx), dim=1) + + if k == self.block_size - 1: + return xs[:, :-1] + return xs + +class CausalCrossConditionalSelfAttention(nn.Module): + + def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1): + super().__init__() + assert embed_dim % 8 == 0 + # key, query, value 
projections for all heads + self.key = nn.Linear(embed_dim, embed_dim) + self.query = nn.Linear(embed_dim, embed_dim) + self.value = nn.Linear(embed_dim, embed_dim) + + self.attn_drop = nn.Dropout(drop_out_rate) + self.resid_drop = nn.Dropout(drop_out_rate) + + self.proj = nn.Linear(embed_dim, embed_dim) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)) + self.n_head = n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + +class Block(nn.Module): + + def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4): + super().__init__() + self.ln1 = nn.LayerNorm(embed_dim) + self.ln2 = nn.LayerNorm(embed_dim) + self.attn = CausalCrossConditionalSelfAttention(embed_dim, block_size, n_head, drop_out_rate) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, fc_rate * embed_dim), + nn.GELU(), + nn.Linear(fc_rate * embed_dim, embed_dim), + nn.Dropout(drop_out_rate), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class CrossCondTransBase(nn.Module): + + def __init__(self, + num_vq=1024, + embed_dim=512, + clip_dim=512, + block_size=16, + num_layers=2, + n_head=8, + drop_out_rate=0.1, + fc_rate=4): + super().__init__() + self.tok_emb = nn.Embedding(num_vq + 2, embed_dim) + self.cond_emb = nn.Linear(clip_dim, embed_dim) + self.pos_embedding = nn.Embedding(block_size, embed_dim) + self.drop = nn.Dropout(drop_out_rate) + # transformer block + self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)]) + self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False) + + self.block_size = block_size + + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, idx, clip_feature): + if len(idx) == 0: + token_embeddings = self.cond_emb(clip_feature).unsqueeze(1) + else: + b, t = idx.size() + assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
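+ # The CLIP text embedding is prepended as token 0, so the blocks see [condition, motion tokens] and the output at position i predicts motion token i.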
+ # forward the Trans model + token_embeddings = self.tok_emb(idx) + token_embeddings = torch.cat([self.cond_emb(clip_feature).unsqueeze(1), token_embeddings], dim=1) + + x = self.pos_embed(token_embeddings) + x = self.blocks(x) + + return x + + +class CrossCondTransHead(nn.Module): + + def __init__(self, + num_vq=1024, + embed_dim=512, + block_size=16, + num_layers=2, + n_head=8, + drop_out_rate=0.1, + fc_rate=4): + super().__init__() + + self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)]) + self.ln_f = nn.LayerNorm(embed_dim) + self.head = nn.Linear(embed_dim, num_vq + 1, bias=False) + self.block_size = block_size + + self.apply(self._init_weights) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x): + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + return logits + + + + + + diff --git a/VQ-Trans/models/vqvae.py b/VQ-Trans/models/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..7e6c940674d460853e8418514bf2306f774689fd --- /dev/null +++ b/VQ-Trans/models/vqvae.py @@ -0,0 +1,118 @@ +import torch.nn as nn +from models.encdec import Encoder, Decoder +from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset + + +class VQVAE_251(nn.Module): + def __init__(self, + args, + nb_code=1024, + code_dim=512, + output_emb_width=512, + down_t=3, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + activation='relu', + norm=None): + + super().__init__() + self.code_dim = code_dim + self.num_code = nb_code + self.quant = args.quantizer + self.encoder = Encoder(251 if args.dataname == 'kit' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) + self.decoder = Decoder(251 if args.dataname == 'kit' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) + if args.quantizer == "ema_reset": + self.quantizer = QuantizeEMAReset(nb_code, code_dim, args) + elif args.quantizer == "orig": + self.quantizer = Quantizer(nb_code, code_dim, 1.0) + elif args.quantizer == "ema": + self.quantizer = QuantizeEMA(nb_code, code_dim, args) + elif args.quantizer == "reset": + self.quantizer = QuantizeReset(nb_code, code_dim, args) + + + def preprocess(self, x): + # (bs, T, Jx3) -> (bs, Jx3, T) + x = x.permute(0,2,1).float() + return x + + + def postprocess(self, x): + # (bs, Jx3, T) -> (bs, T, Jx3) + x = x.permute(0,2,1) + return x + + + def encode(self, x): + N, T, _ = x.shape + x_in = self.preprocess(x) + x_encoder = self.encoder(x_in) + x_encoder = self.postprocess(x_encoder) + x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1]) # (NT, C) + code_idx = self.quantizer.quantize(x_encoder) + code_idx = code_idx.view(N, -1) + return code_idx + + + def forward(self, x): + + x_in = self.preprocess(x) + # Encode + x_encoder = self.encoder(x_in) + + ## quantization + x_quantized, loss, perplexity = self.quantizer(x_encoder) + + ## decoder + x_decoder = self.decoder(x_quantized) + x_out = self.postprocess(x_decoder) + return x_out, loss, perplexity + + + def forward_decoder(self, x): + x_d = 
self.quantizer.dequantize(x) + x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous() + + # decoder + x_decoder = self.decoder(x_d) + x_out = self.postprocess(x_decoder) + return x_out + + + +class HumanVQVAE(nn.Module): + def __init__(self, + args, + nb_code=512, + code_dim=512, + output_emb_width=512, + down_t=3, + stride_t=2, + width=512, + depth=3, + dilation_growth_rate=3, + activation='relu', + norm=None): + + super().__init__() + + self.nb_joints = 21 if args.dataname == 'kit' else 22 + self.vqvae = VQVAE_251(args, nb_code, code_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm) + + def encode(self, x): + b, t, c = x.size() + quants = self.vqvae.encode(x) # (N, T) + return quants + + def forward(self, x): + + x_out, loss, perplexity = self.vqvae(x) + + return x_out, loss, perplexity + + def forward_decoder(self, x): + x_out = self.vqvae.forward_decoder(x) + return x_out + \ No newline at end of file diff --git a/VQ-Trans/options/get_eval_option.py b/VQ-Trans/options/get_eval_option.py new file mode 100644 index 0000000000000000000000000000000000000000..d0989ba1a8116068753ada2cb1806744e4512447 --- /dev/null +++ b/VQ-Trans/options/get_eval_option.py @@ -0,0 +1,83 @@ +from argparse import Namespace +import re +from os.path import join as pjoin + + +def is_float(numStr): + flag = False + numStr = str(numStr).strip().lstrip('-').lstrip('+') + try: + reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$') + res = reg.match(str(numStr)) + if res: + flag = True + except Exception as ex: + print("is_float() - error: " + str(ex)) + return flag + + +def is_number(numStr): + flag = False + numStr = str(numStr).strip().lstrip('-').lstrip('+') + if str(numStr).isdigit(): + flag = True + return flag + + +def get_opt(opt_path, device): + opt = Namespace() + opt_dict = vars(opt) + + skip = ('-------------- End ----------------', + '------------ Options -------------', + '\n') + print('Reading', opt_path) + with open(opt_path) as f: + for line in f: + if line.strip() not in skip: + # print(line.strip()) + key, value = line.strip().split(': ') + if value in ('True', 'False'): + opt_dict[key] = (value == 'True') + # print(key, value) + elif is_float(value): + opt_dict[key] = float(value) + elif is_number(value): + opt_dict[key] = int(value) + else: + opt_dict[key] = str(value) + + # print(opt) + opt_dict['which_epoch'] = 'finest' + opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) + opt.model_dir = pjoin(opt.save_root, 'model') + opt.meta_dir = pjoin(opt.save_root, 'meta') + + if opt.dataset_name == 't2m': + opt.data_root = './dataset/HumanML3D/' + opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') + opt.text_dir = pjoin(opt.data_root, 'texts') + opt.joints_num = 22 + opt.dim_pose = 263 + opt.max_motion_length = 196 + opt.max_motion_frame = 196 + opt.max_motion_token = 55 + elif opt.dataset_name == 'kit': + opt.data_root = './dataset/KIT-ML/' + opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') + opt.text_dir = pjoin(opt.data_root, 'texts') + opt.joints_num = 21 + opt.dim_pose = 251 + opt.max_motion_length = 196 + opt.max_motion_frame = 196 + opt.max_motion_token = 55 + else: + raise KeyError('Dataset not recognized') + + opt.dim_word = 300 + opt.num_classes = 200 // opt.unit_length + opt.is_train = False + opt.is_continue = False + opt.device = device + + return opt \ No newline at end of file diff --git a/VQ-Trans/options/option_transformer.py b/VQ-Trans/options/option_transformer.py new file mode 100644 index 
0000000000000000000000000000000000000000..cf48ce1fdac663ec44419d67721ac268806f8127 --- /dev/null +++ b/VQ-Trans/options/option_transformer.py @@ -0,0 +1,68 @@ +import argparse + +def get_args_parser(): + parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for Amass', + add_help=True, + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + ## dataloader + + parser.add_argument('--dataname', type=str, default='kit', help='dataset name') + parser.add_argument('--batch-size', default=128, type=int, help='batch size') + parser.add_argument('--fps', default=[20], nargs="+", type=int, help='frames per second') + parser.add_argument('--seq-len', type=int, default=64, help='training motion length') + + ## optimization + parser.add_argument('--total-iter', default=100000, type=int, help='number of total iterations to run') + parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup') + parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate') + parser.add_argument('--lr-scheduler', default=[60000], nargs="+", type=int, help="learning rate schedule (iterations)") + parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay") + + parser.add_argument('--weight-decay', default=1e-6, type=float, help='weight decay') + parser.add_argument('--decay-option',default='all', type=str, choices=['all', 'noVQ'], help='apply weight decay to all parameters or exclude the VQ codebook') + parser.add_argument('--optimizer',default='adamw', type=str, choices=['adam', 'adamw'], help='optimizer type') + + ## vqvae arch + parser.add_argument("--code-dim", type=int, default=512, help="embedding dimension") + parser.add_argument("--nb-code", type=int, default=512, help="nb of embeddings") + parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook") + parser.add_argument("--down-t", type=int, default=3, help="downsampling rate") + parser.add_argument("--stride-t", type=int, default=2, help="stride size") + parser.add_argument("--width", type=int, default=512, help="width of the network") + parser.add_argument("--depth", type=int, default=3, help="depth of the network") + parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate") + parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width") + parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function of the VQ-VAE') + + ## gpt arch + parser.add_argument("--block-size", type=int, default=25, help="seq len") + parser.add_argument("--embed-dim-gpt", type=int, default=512, help="embedding dimension") + parser.add_argument("--clip-dim", type=int, default=512, help="latent dimension in the clip feature") + parser.add_argument("--num-layers", type=int, default=2, help="nb of transformer layers") + parser.add_argument("--n-head-gpt", type=int, default=8, help="nb of heads") + parser.add_argument("--ff-rate", type=int, default=4, help="feedforward size") + parser.add_argument("--drop-out-rate", type=float, default=0.1, help="dropout ratio in the pos encoding") + + ## quantizer + parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="quantizer type") + parser.add_argument('--quantbeta', type=float, default=1.0, help='commitment beta for the standard VQ quantizer') + + ## resume + parser.add_argument("--resume-pth", type=str, default=None, help='resume vq pth') + parser.add_argument("--resume-trans", type=str, default=None, help='resume gpt pth') + + + ## output directory + parser.add_argument('--out-dir', type=str, default='output_GPT_Final/', help='output directory') + parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir') + parser.add_argument('--vq-name', type=str, default='exp_debug', help='name of the generated dataset .npy, will create a file inside out-dir') + ## other + parser.add_argument('--print-iter', default=200, type=int, help='print frequency') + parser.add_argument('--eval-iter', default=5000, type=int, help='evaluation frequency') + parser.add_argument('--seed', default=123, type=int, help='seed for initializing training. ') + parser.add_argument("--if-maxtest", action='store_true', help="evaluate with argmax instead of sampling") + parser.add_argument('--pkeep', type=float, default=1.0, help='keep rate for gpt training') + + + return parser.parse_args() \ No newline at end of file diff --git a/VQ-Trans/options/option_vq.py b/VQ-Trans/options/option_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..08a53ff1270facc10ab44ec0647e673ed1336d0d --- /dev/null +++ b/VQ-Trans/options/option_vq.py @@ -0,0 +1,61 @@ +import argparse + +def get_args_parser(): + parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for AIST', + add_help=True, + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + ## dataloader + parser.add_argument('--dataname', type=str, default='kit', help='dataset name') + parser.add_argument('--batch-size', default=128, type=int, help='batch size') + parser.add_argument('--window-size', type=int, default=64, help='training motion length') + + ## optimization + parser.add_argument('--total-iter', default=200000, type=int, help='number of total iterations to run') + parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup') + parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate') + parser.add_argument('--lr-scheduler', default=[50000, 400000], nargs="+", type=int, help="learning rate schedule (iterations)") + parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay") + + parser.add_argument('--weight-decay', default=0.0, type=float, help='weight decay') + parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss") + parser.add_argument('--loss-vel', type=float, default=0.1, help='hyper-parameter for the velocity loss') + parser.add_argument('--recons-loss', type=str, default='l2', help='reconstruction loss') + + ## vqvae arch + parser.add_argument("--code-dim", type=int, default=512, help="embedding dimension") + parser.add_argument("--nb-code", type=int, default=512, help="nb of embeddings") + parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook") + parser.add_argument("--down-t", type=int, default=2, help="downsampling rate") + parser.add_argument("--stride-t", type=int, default=2, help="stride size") + parser.add_argument("--width", type=int, default=512, help="width of the network") + parser.add_argument("--depth", type=int, default=3, help="depth of the network") + parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate") + parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width") + parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function of the VQ-VAE') + parser.add_argument('--vq-norm', type=str, default=None, help='normalization layer of the VQ-VAE') + + ## quantizer + parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="quantizer type") + parser.add_argument('--beta', type=float, default=1.0, help='weight of the commitment loss in standard VQ') + + ## resume + parser.add_argument("--resume-pth", type=str, default=None, help='resume pth for VQ') + parser.add_argument("--resume-gpt", type=str, default=None, help='resume pth for GPT') + + + ## output directory + parser.add_argument('--out-dir', type=str, default='output_vqfinal/', help='output directory') + parser.add_argument('--results-dir', type=str, default='visual_results/', help='output directory') + parser.add_argument('--visual-name', type=str, default='baseline', help='name prefix for visualization outputs') + parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir') + ## other + parser.add_argument('--print-iter', default=200, type=int, help='print frequency') + parser.add_argument('--eval-iter', default=1000, type=int, help='evaluation frequency') + parser.add_argument('--seed', default=123, type=int, help='seed for initializing training.') + + parser.add_argument('--vis-gt', action='store_true', help='whether to visualize GT motions') + parser.add_argument('--nb-vis', default=20, type=int, help='nb of visualizations') + + + return parser.parse_args() \ No newline at end of file diff --git a/VQ-Trans/output/23cb7d0e26bb1646b3d386331971449c_pred.pt b/VQ-Trans/output/23cb7d0e26bb1646b3d386331971449c_pred.pt new file mode 100644 index 0000000000000000000000000000000000000000..9a076444c9a447919ab0249568c34bbbbed7a005 --- /dev/null +++ b/VQ-Trans/output/23cb7d0e26bb1646b3d386331971449c_pred.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e05998e1d4ac1eebcba89ec989112b6fcdb55c8cceeef2faea7cb564a381525 +size 16206213 diff --git a/VQ-Trans/output/90dd3007b93da07eca7527c836b4d6d0_pred.pt b/VQ-Trans/output/90dd3007b93da07eca7527c836b4d6d0_pred.pt new file mode 100644 index 0000000000000000000000000000000000000000..da6dec12adce86c06255d273dac9e7018119edbd --- /dev/null +++ b/VQ-Trans/output/90dd3007b93da07eca7527c836b4d6d0_pred.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a24b9eecef62fc4a498c45e9342b265304191b9b0bd0f673bf3e029a6524bedd +size 16206213 diff --git a/VQ-Trans/output/c3785325ba8f17ce7427b43d49903e51_pred.pt b/VQ-Trans/output/c3785325ba8f17ce7427b43d49903e51_pred.pt new file mode 100644 index 0000000000000000000000000000000000000000..d72827a2475f3fb6c5be9f5f32d0c7eb6f2a87cf --- /dev/null +++ b/VQ-Trans/output/c3785325ba8f17ce7427b43d49903e51_pred.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3513af257c3ffa7b2168f02288a9d3349648efbfba4dcf88e00e9bca1e850855 +size 10584005 diff --git a/VQ-Trans/pyrender b/VQ-Trans/pyrender new file mode 160000 index 0000000000000000000000000000000000000000..a59963ef890891656fd17c90e12d663233dcaa99 --- /dev/null +++ b/VQ-Trans/pyrender @@ -0,0 +1 @@ +Subproject commit a59963ef890891656fd17c90e12d663233dcaa99 diff --git a/VQ-Trans/render_final.py b/VQ-Trans/render_final.py new file mode 100644 index 0000000000000000000000000000000000000000..41b3bfdb2e6bff74aeaceb8f1a7ebac9dc1acaba --- /dev/null +++ b/VQ-Trans/render_final.py @@ -0,0 +1,194 @@ +from 
models.rotation2xyz import Rotation2xyz +import numpy as np +from trimesh import Trimesh +import os +os.environ['PYOPENGL_PLATFORM'] = "osmesa" + +import torch +from visualize.simplify_loc2rot import joints2smpl +import pyrender +import matplotlib.pyplot as plt + +import io +import imageio +from shapely import geometry +import trimesh +from pyrender.constants import RenderFlags +import math +# import ffmpeg +from PIL import Image + +class WeakPerspectiveCamera(pyrender.Camera): + def __init__(self, + scale, + translation, + znear=pyrender.camera.DEFAULT_Z_NEAR, + zfar=None, + name=None): + super(WeakPerspectiveCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + self.scale = scale + self.translation = translation + + def get_projection_matrix(self, width=None, height=None): + P = np.eye(4) + P[0, 0] = self.scale[0] + P[1, 1] = self.scale[1] + P[0, 3] = self.translation[0] * self.scale[0] + P[1, 3] = -self.translation[1] * self.scale[1] + P[2, 2] = -1 + return P + +def render(motions, outdir='test_vis', device_id=0, name=None, pred=True): + frames, njoints, nfeats = motions.shape + MINS = motions.min(axis=0).min(axis=0) + MAXS = motions.max(axis=0).max(axis=0) + + height_offset = MINS[1] + motions[:, :, 1] -= height_offset + trajec = motions[:, 0, [0, 2]] + + j2s = joints2smpl(num_frames=frames, device_id=0, cuda=True) + rot2xyz = Rotation2xyz(device=torch.device("cuda:0")) + faces = rot2xyz.smpl_model.faces + + if (not os.path.exists(outdir + name+'_pred.pt') and pred) or (not os.path.exists(outdir + name+'_gt.pt') and not pred): + print(f'Running SMPLify, it may take a few minutes.') + motion_tensor, opt_dict = j2s.joint2smpl(motions) # [nframes, njoints, 3] + + vertices = rot2xyz(torch.tensor(motion_tensor).clone(), mask=None, + pose_rep='rot6d', translation=True, glob=True, + jointstype='vertices', + vertstrans=True) + + if pred: + torch.save(vertices, outdir + name+'_pred.pt') + else: + torch.save(vertices, outdir + name+'_gt.pt') + else: + if pred: + vertices = torch.load(outdir + name+'_pred.pt') + else: + vertices = torch.load(outdir + name+'_gt.pt') + frames = vertices.shape[3] # shape: 1, nb_frames, 3, nb_joints + print (vertices.shape) + MINS = torch.min(torch.min(vertices[0], axis=0)[0], axis=1)[0] + MAXS = torch.max(torch.max(vertices[0], axis=0)[0], axis=1)[0] + # vertices[:,:,1,:] -= MINS[1] + 1e-5 + + + out_list = [] + + minx = MINS[0] - 0.5 + maxx = MAXS[0] + 0.5 + minz = MINS[2] - 0.5 + maxz = MAXS[2] + 0.5 + polygon = geometry.Polygon([[minx, minz], [minx, maxz], [maxx, maxz], [maxx, minz]]) + polygon_mesh = trimesh.creation.extrude_polygon(polygon, 1e-5) + + vid = [] + for i in range(frames): + if i % 10 == 0: + print(i) + + mesh = Trimesh(vertices=vertices[0, :, :, i].squeeze().tolist(), faces=faces) + + base_color = (0.11, 0.53, 0.8, 0.5) + ## OPAQUE rendering without alpha + ## BLEND rendering consider alpha + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.7, + alphaMode='OPAQUE', + baseColorFactor=base_color + ) + + + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + polygon_mesh.visual.face_colors = [0, 0, 0, 0.21] + polygon_render = pyrender.Mesh.from_trimesh(polygon_mesh, smooth=False) + + bg_color = [1, 1, 1, 0.8] + scene = pyrender.Scene(bg_color=bg_color, ambient_light=(0.4, 0.4, 0.4)) + + sx, sy, tx, ty = [0.75, 0.75, 0, 0.10] + + camera = pyrender.PerspectiveCamera(yfov=(np.pi / 3.0)) + + light = pyrender.DirectionalLight(color=[1,1,1], intensity=300) + + scene.add(mesh) + + c = np.pi / 2 + + 
scene.add(polygon_render, pose=np.array([[ 1, 0, 0, 0], + + [ 0, np.cos(c), -np.sin(c), MINS[1].cpu().numpy()], + + [ 0, np.sin(c), np.cos(c), 0], + + [ 0, 0, 0, 1]])) + + light_pose = np.eye(4) + light_pose[:3, 3] = [0, -1, 1] + scene.add(light, pose=light_pose.copy()) + + light_pose[:3, 3] = [0, 1, 1] + scene.add(light, pose=light_pose.copy()) + + light_pose[:3, 3] = [1, 1, 2] + scene.add(light, pose=light_pose.copy()) + + + c = -np.pi / 6 + + scene.add(camera, pose=[[ 1, 0, 0, (minx+maxx).cpu().numpy()/2], + + [ 0, np.cos(c), -np.sin(c), 1.5], + + [ 0, np.sin(c), np.cos(c), max(4, minz.cpu().numpy()+(1.5-MINS[1].cpu().numpy())*2, (maxx-minx).cpu().numpy())], + + [ 0, 0, 0, 1] + ]) + + # render scene + r = pyrender.OffscreenRenderer(960, 960) + + color, _ = r.render(scene, flags=RenderFlags.RGBA) + # Image.fromarray(color).save(outdir+'/'+name+'_'+str(i)+'.png') + + vid.append(color) + + r.delete() + + out = np.stack(vid, axis=0) + if pred: + imageio.mimsave(outdir + name+'_pred.gif', out, fps=20) + else: + imageio.mimsave(outdir + name+'_gt.gif', out, fps=20) + + + + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--filedir", type=str, default=None, help='motion npy file dir') + parser.add_argument('--motion-list', default=None, nargs="+", type=str, help="motion name list") + args = parser.parse_args() + + filename_list = args.motion_list + filedir = args.filedir + + for filename in filename_list: + motions = np.load(filedir + filename+'_pred.npy') + print('pred', motions.shape, filename) + render(motions[0], outdir=filedir, device_id=0, name=filename, pred=True) + + motions = np.load(filedir + filename+'_gt.npy') + print('gt', motions.shape, filename) + render(motions[0], outdir=filedir, device_id=0, name=filename, pred=False) diff --git a/VQ-Trans/train_t2m_trans.py b/VQ-Trans/train_t2m_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..8da444f87aa7ca71cd8bc3604868cf30a6c70e02 --- /dev/null +++ b/VQ-Trans/train_t2m_trans.py @@ -0,0 +1,191 @@ +import os +import torch +import numpy as np + +from torch.utils.tensorboard import SummaryWriter +from os.path import join as pjoin +from torch.distributions import Categorical +import json +import clip + +import options.option_transformer as option_trans +import models.vqvae as vqvae +import utils.utils_model as utils_model +import utils.eval_trans as eval_trans +from dataset import dataset_TM_train +from dataset import dataset_TM_eval +from dataset import dataset_tokenize +import models.t2m_trans as trans +from options.get_eval_option import get_opt +from models.evaluator_wrapper import EvaluatorModelWrapper +import warnings +warnings.filterwarnings('ignore') + +##### ---- Exp dirs ---- ##### +args = option_trans.get_args_parser() +torch.manual_seed(args.seed) + +args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}') +args.vq_dir= os.path.join("./dataset/KIT-ML" if args.dataname == 'kit' else "./dataset/HumanML3D", f'{args.vq_name}') +os.makedirs(args.out_dir, exist_ok = True) +os.makedirs(args.vq_dir, exist_ok = True) + +##### ---- Logger ---- ##### +logger = utils_model.get_logger(args.out_dir) +writer = SummaryWriter(args.out_dir) +logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) + +##### ---- Dataloader ---- ##### +train_loader_token = dataset_tokenize.DATALoader(args.dataname, 1, unit_length=2**args.down_t) + +from utils.word_vectorizer import WordVectorizer +w_vectorizer = WordVectorizer('./glove', 'our_vab') +val_loader = 
dataset_TM_eval.DATALoader(args.dataname, False, 32, w_vectorizer) + +dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataname == 'kit' else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' + +wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) +eval_wrapper = EvaluatorModelWrapper(wrapper_opt) + +##### ---- Network ---- ##### +clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False, download_root='/apdcephfs_cq2/share_1290939/maelyszhang/.cache/clip') # Must set jit=False for training +clip.model.convert_weights(clip_model) # Note: this line is actually unnecessary, since CLIP already loads in float16 by default +clip_model.eval() +for p in clip_model.parameters(): + p.requires_grad = False + +net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers + args.nb_code, + args.code_dim, + args.output_emb_width, + args.down_t, + args.stride_t, + args.width, + args.depth, + args.dilation_growth_rate) + + +trans_encoder = trans.Text2Motion_Transformer(num_vq=args.nb_code, + embed_dim=args.embed_dim_gpt, + clip_dim=args.clip_dim, + block_size=args.block_size, + num_layers=args.num_layers, + n_head=args.n_head_gpt, + drop_out_rate=args.drop_out_rate, + fc_rate=args.ff_rate) + + +print ('loading checkpoint from {}'.format(args.resume_pth)) +ckpt = torch.load(args.resume_pth, map_location='cpu') +net.load_state_dict(ckpt['net'], strict=True) +net.eval() +net.cuda() + +if args.resume_trans is not None: + print ('loading transformer checkpoint from {}'.format(args.resume_trans)) + ckpt = torch.load(args.resume_trans, map_location='cpu') + trans_encoder.load_state_dict(ckpt['trans'], strict=True) +trans_encoder.train() +trans_encoder.cuda() + +##### ---- Optimizer & Scheduler ---- ##### +optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer) +scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma) + +##### ---- Optimization goals ---- ##### +loss_ce = torch.nn.CrossEntropyLoss() + +nb_iter, avg_loss_cls, avg_acc = 0, 0., 0. 
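+# Token-level accuracy bookkeeping: right_num counts correct token predictions and nb_sample_train counts predicted tokens between logging steps.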
+right_num = 0 +nb_sample_train = 0 + +##### ---- get code ---- ##### +for batch in train_loader_token: + pose, name = batch + bs, seq = pose.shape[0], pose.shape[1] + + pose = pose.cuda().float() # bs, nb_joints, joints_dim, seq_len + target = net.encode(pose) + target = target.cpu().numpy() + np.save(pjoin(args.vq_dir, name[0] +'.npy'), target) + + +train_loader = dataset_TM_train.DATALoader(args.dataname, args.batch_size, args.nb_code, args.vq_name, unit_length=2**args.down_t) +train_loader_iter = dataset_TM_train.cycle(train_loader) + + +##### ---- Training ---- ##### +best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_transformer(args.out_dir, val_loader, net, trans_encoder, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, clip_model=clip_model, eval_wrapper=eval_wrapper) +while nb_iter <= args.total_iter: + + batch = next(train_loader_iter) + clip_text, m_tokens, m_tokens_len = batch + m_tokens, m_tokens_len = m_tokens.cuda(), m_tokens_len.cuda() + bs = m_tokens.shape[0] + target = m_tokens # (bs, 26) + target = target.cuda() + + text = clip.tokenize(clip_text, truncate=True).cuda() + + feat_clip_text = clip_model.encode_text(text).float() + + input_index = target[:,:-1] + + if args.pkeep == -1: + proba = np.random.rand(1)[0] + mask = torch.bernoulli(proba * torch.ones(input_index.shape, + device=input_index.device)) + else: + mask = torch.bernoulli(args.pkeep * torch.ones(input_index.shape, + device=input_index.device)) + mask = mask.round().to(dtype=torch.int64) + r_indices = torch.randint_like(input_index, args.nb_code) + a_indices = mask*input_index+(1-mask)*r_indices + + cls_pred = trans_encoder(a_indices, feat_clip_text) + cls_pred = cls_pred.contiguous() + + loss_cls = 0.0 + for i in range(bs): + # loss function (26), (26, 513) + loss_cls += loss_ce(cls_pred[i][:m_tokens_len[i] + 1], target[i][:m_tokens_len[i] + 1]) / bs + + # Accuracy + probs = torch.softmax(cls_pred[i][:m_tokens_len[i] + 1], dim=-1) + + if args.if_maxtest: + _, cls_pred_index = torch.max(probs, dim=-1) + + else: + dist = Categorical(probs) + cls_pred_index = dist.sample() + right_num += (cls_pred_index.flatten(0) == target[i][:m_tokens_len[i] + 1].flatten(0)).sum().item() + + ## global loss + optimizer.zero_grad() + loss_cls.backward() + optimizer.step() + scheduler.step() + + avg_loss_cls = avg_loss_cls + loss_cls.item() + nb_sample_train = nb_sample_train + (m_tokens_len + 1).sum().item() + + nb_iter += 1 + if nb_iter % args.print_iter == 0 : + avg_loss_cls = avg_loss_cls / args.print_iter + avg_acc = right_num * 100 / nb_sample_train + writer.add_scalar('./Loss/train', avg_loss_cls, nb_iter) + writer.add_scalar('./ACC/train', avg_acc, nb_iter) + msg = f"Train. Iter {nb_iter} : Loss. {avg_loss_cls:.5f}, ACC. {avg_acc:.4f}" + logger.info(msg) + avg_loss_cls = 0. + right_num = 0 + nb_sample_train = 0 + + if nb_iter % args.eval_iter == 0: + best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_transformer(args.out_dir, val_loader, net, trans_encoder, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, clip_model=clip_model, eval_wrapper=eval_wrapper) + + if nb_iter == args.total_iter: + msg_final = f"Train. Iter {best_iter} : FID. {best_fid:.5f}, Diversity. {best_div:.4f}, TOP1. {best_top1:.4f}, TOP2. {best_top2:.4f}, TOP3. 
{best_top3:.4f}" + logger.info(msg_final) + break \ No newline at end of file diff --git a/VQ-Trans/train_vq.py b/VQ-Trans/train_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..d89b9930ba1262747542df3d5b2f03f8fab1b04a --- /dev/null +++ b/VQ-Trans/train_vq.py @@ -0,0 +1,171 @@ +import os +import json + +import torch +import torch.optim as optim +from torch.utils.tensorboard import SummaryWriter + +import models.vqvae as vqvae +import utils.losses as losses +import options.option_vq as option_vq +import utils.utils_model as utils_model +from dataset import dataset_VQ, dataset_TM_eval +import utils.eval_trans as eval_trans +from options.get_eval_option import get_opt +from models.evaluator_wrapper import EvaluatorModelWrapper +import warnings +warnings.filterwarnings('ignore') +from utils.word_vectorizer import WordVectorizer + +def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr): + + current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1) + for param_group in optimizer.param_groups: + param_group["lr"] = current_lr + + return optimizer, current_lr + +##### ---- Exp dirs ---- ##### +args = option_vq.get_args_parser() +torch.manual_seed(args.seed) + +args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}') +os.makedirs(args.out_dir, exist_ok = True) + +##### ---- Logger ---- ##### +logger = utils_model.get_logger(args.out_dir) +writer = SummaryWriter(args.out_dir) +logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) + + + +w_vectorizer = WordVectorizer('./glove', 'our_vab') + +if args.dataname == 'kit' : + dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' + args.nb_joints = 21 + +else : + dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt' + args.nb_joints = 22 + +logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints') + +wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda')) +eval_wrapper = EvaluatorModelWrapper(wrapper_opt) + + +##### ---- Dataloader ---- ##### +train_loader = dataset_VQ.DATALoader(args.dataname, + args.batch_size, + window_size=args.window_size, + unit_length=2**args.down_t) + +train_loader_iter = dataset_VQ.cycle(train_loader) + +val_loader = dataset_TM_eval.DATALoader(args.dataname, False, + 32, + w_vectorizer, + unit_length=2**args.down_t) + +##### ---- Network ---- ##### +net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers + args.nb_code, + args.code_dim, + args.output_emb_width, + args.down_t, + args.stride_t, + args.width, + args.depth, + args.dilation_growth_rate, + args.vq_act, + args.vq_norm) + + +if args.resume_pth : + logger.info('loading checkpoint from {}'.format(args.resume_pth)) + ckpt = torch.load(args.resume_pth, map_location='cpu') + net.load_state_dict(ckpt['net'], strict=True) +net.train() +net.cuda() + +##### ---- Optimizer & Scheduler ---- ##### +optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma) + + +Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints) + +##### ------ warm-up ------- ##### +avg_recons, avg_perplexity, avg_commit = 0., 0., 0. 
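+
+# Editor's sketch (hypothetical helper, not used by training): update_lr_warm_up
+# above ramps the learning rate linearly as lr * (t + 1) / (warm_up_iter + 1),
+# reaching the full lr at the last warm-up iteration. Sampling a few points
+# makes the schedule explicit (the lr and warm_up_iter defaults are made up).
+def _demo_warmup_lr(lr=2e-4, warm_up_iter=1000):
+    return [lr * (t + 1) / (warm_up_iter + 1) for t in (0, warm_up_iter // 2, warm_up_iter - 1)]
+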
+ +for nb_iter in range(1, args.warm_up_iter): + + optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr) + + gt_motion = next(train_loader_iter) + gt_motion = gt_motion.cuda().float() # (bs, 64, dim) + + pred_motion, loss_commit, perplexity = net(gt_motion) + loss_motion = Loss(pred_motion, gt_motion) + loss_vel = Loss.forward_vel(pred_motion, gt_motion) + + loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + avg_recons += loss_motion.item() + avg_perplexity += perplexity.item() + avg_commit += loss_commit.item() + + if nb_iter % args.print_iter == 0 : + avg_recons /= args.print_iter + avg_perplexity /= args.print_iter + avg_commit /= args.print_iter + + logger.info(f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}") + + avg_recons, avg_perplexity, avg_commit = 0., 0., 0. + +##### ---- Training ---- ##### +avg_recons, avg_perplexity, avg_commit = 0., 0., 0. +best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper) + +for nb_iter in range(1, args.total_iter + 1): + + gt_motion = next(train_loader_iter) + gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len + + pred_motion, loss_commit, perplexity = net(gt_motion) + loss_motion = Loss(pred_motion, gt_motion) + loss_vel = Loss.forward_vel(pred_motion, gt_motion) + + loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel + + optimizer.zero_grad() + loss.backward() + optimizer.step() + scheduler.step() + + avg_recons += loss_motion.item() + avg_perplexity += perplexity.item() + avg_commit += loss_commit.item() + + if nb_iter % args.print_iter == 0 : + avg_recons /= args.print_iter + avg_perplexity /= args.print_iter + avg_commit /= args.print_iter + + writer.add_scalar('./Train/L1', avg_recons, nb_iter) + writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter) + writer.add_scalar('./Train/Commit', avg_commit, nb_iter) + + logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. 
{avg_recons:.5f}") + + avg_recons, avg_perplexity, avg_commit = 0., 0., 0., + + if nb_iter % args.eval_iter==0 : + best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper) + \ No newline at end of file diff --git a/VQ-Trans/utils/config.py b/VQ-Trans/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..091d790e963959c326917688ee267e6a4ec136d1 --- /dev/null +++ b/VQ-Trans/utils/config.py @@ -0,0 +1,17 @@ +import os + +SMPL_DATA_PATH = "./body_models/smpl" + +SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl") +SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl") +JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy') + +ROT_CONVENTION_TO_ROT_NUMBER = { + 'legacy': 23, + 'no_hands': 21, + 'full_hands': 51, + 'mitten_hands': 33, +} + +GENDERS = ['neutral', 'male', 'female'] +NUM_BETAS = 10 \ No newline at end of file diff --git a/VQ-Trans/utils/eval_trans.py b/VQ-Trans/utils/eval_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..8778bb8cb7e7a320e5f7f2f3b43c7ba0b4c285ab --- /dev/null +++ b/VQ-Trans/utils/eval_trans.py @@ -0,0 +1,580 @@ +import os + +import clip +import numpy as np +import torch +from scipy import linalg + +import visualization.plot_3d_global as plot_3d +from utils.motion_process import recover_from_ric + + +def tensorborad_add_video_xyz(writer, xyz, nb_iter, tag, nb_vis=4, title_batch=None, outname=None): + xyz = xyz[:1] + bs, seq = xyz.shape[:2] + xyz = xyz.reshape(bs, seq, -1, 3) + plot_xyz = plot_3d.draw_to_batch(xyz.cpu().numpy(),title_batch, outname) + plot_xyz =np.transpose(plot_xyz, (0, 1, 4, 2, 3)) + writer.add_video(tag, plot_xyz, nb_iter, fps = 20) + +@torch.no_grad() +def evaluation_vqvae(out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper, draw = True, save = True, savegif=False, savenpy=False) : + net.eval() + nb_sample = 0 + + draw_org = [] + draw_pred = [] + draw_text = [] + + + motion_annotation_list = [] + motion_pred_list = [] + + R_precision_real = 0 + R_precision = 0 + + nb_sample = 0 + matching_score_real = 0 + matching_score_pred = 0 + for batch in val_loader: + word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token, name = batch + + motion = motion.cuda() + et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length) + bs, seq = motion.shape[0], motion.shape[1] + + num_joints = 21 if motion.shape[-1] == 251 else 22 + + pred_pose_eval = torch.zeros((bs, seq, motion.shape[-1])).cuda() + + for i in range(bs): + pose = val_loader.dataset.inv_transform(motion[i:i+1, :m_length[i], :].detach().cpu().numpy()) + pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints) + + + pred_pose, loss_commit, perplexity = net(motion[i:i+1, :m_length[i]]) + pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy()) + pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints) + + if savenpy: + np.save(os.path.join(out_dir, name[i]+'_gt.npy'), pose_xyz[:, :m_length[i]].cpu().numpy()) + np.save(os.path.join(out_dir, name[i]+'_pred.npy'), pred_xyz.detach().cpu().numpy()) + + pred_pose_eval[i:i+1,:m_length[i],:] = 
pred_pose + + if i < min(4, bs): + draw_org.append(pose_xyz) + draw_pred.append(pred_xyz) + draw_text.append(caption[i]) + + et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, m_length) + + motion_pred_list.append(em_pred) + motion_annotation_list.append(em) + + temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True) + R_precision_real += temp_R + matching_score_real += temp_match + temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True) + R_precision += temp_R + matching_score_pred += temp_match + + nb_sample += bs + + motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy() + motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy() + gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np) + mu, cov= calculate_activation_statistics(motion_pred_np) + + diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100) + diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100) + + R_precision_real = R_precision_real / nb_sample + R_precision = R_precision / nb_sample + + matching_score_real = matching_score_real / nb_sample + matching_score_pred = matching_score_pred / nb_sample + + fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) + + msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}" + logger.info(msg) + + if draw: + writer.add_scalar('./Test/FID', fid, nb_iter) + writer.add_scalar('./Test/Diversity', diversity, nb_iter) + writer.add_scalar('./Test/top1', R_precision[0], nb_iter) + writer.add_scalar('./Test/top2', R_precision[1], nb_iter) + writer.add_scalar('./Test/top3', R_precision[2], nb_iter) + writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter) + + + if nb_iter % 5000 == 0 : + for ii in range(4): + tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None) + + if nb_iter % 5000 == 0 : + for ii in range(4): + tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None) + + + if fid < best_fid : + msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!" + logger.info(msg) + best_fid, best_iter = fid, nb_iter + if save: + torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_fid.pth')) + + if abs(diversity_real - diversity) < abs(diversity_real - best_div) : + msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!" + logger.info(msg) + best_div = diversity + if save: + torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_div.pth')) + + if R_precision[0] > best_top1 : + msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!" + logger.info(msg) + best_top1 = R_precision[0] + if save: + torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_top1.pth')) + + if R_precision[1] > best_top2 : + msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!" 
+ logger.info(msg) + best_top2 = R_precision[1] + + if R_precision[2] > best_top3 : + msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!" + logger.info(msg) + best_top3 = R_precision[2] + + if matching_score_pred < best_matching : + msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!" + logger.info(msg) + best_matching = matching_score_pred + if save: + torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_matching.pth')) + + if save: + torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_last.pth')) + + net.train() + return best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger + + +@torch.no_grad() +def evaluation_transformer(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, clip_model, eval_wrapper, draw = True, save = True, savegif=False) : + + trans.eval() + nb_sample = 0 + + draw_org = [] + draw_pred = [] + draw_text = [] + draw_text_pred = [] + + motion_annotation_list = [] + motion_pred_list = [] + R_precision_real = 0 + R_precision = 0 + matching_score_real = 0 + matching_score_pred = 0 + + nb_sample = 0 + for i in range(1): + for batch in val_loader: + word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch + + bs, seq = pose.shape[:2] + num_joints = 21 if pose.shape[-1] == 251 else 22 + + text = clip.tokenize(clip_text, truncate=True).cuda() + + feat_clip_text = clip_model.encode_text(text).float() + pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda() + pred_len = torch.ones(bs).long() + + for k in range(bs): + try: + index_motion = trans.sample(feat_clip_text[k:k+1], False) + except: + index_motion = torch.ones(1,1).cuda().long() + + pred_pose = net.forward_decoder(index_motion) + cur_len = pred_pose.shape[1] + + pred_len[k] = min(cur_len, seq) + pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq] + + if draw: + pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy()) + pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints) + + if i == 0 and k < 4: + draw_pred.append(pred_xyz) + draw_text_pred.append(clip_text[k]) + + et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len) + + if i == 0: + pose = pose.cuda().float() + + et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length) + motion_annotation_list.append(em) + motion_pred_list.append(em_pred) + + if draw: + pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy()) + pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints) + + + for j in range(min(4, bs)): + draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0)) + draw_text.append(clip_text[j]) + + temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True) + R_precision_real += temp_R + matching_score_real += temp_match + temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True) + R_precision += temp_R + matching_score_pred += temp_match + + nb_sample += bs + + motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy() + motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy() + gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np) + mu, cov= 
calculate_activation_statistics(motion_pred_np) + + diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100) + diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100) + + R_precision_real = R_precision_real / nb_sample + R_precision = R_precision / nb_sample + + matching_score_real = matching_score_real / nb_sample + matching_score_pred = matching_score_pred / nb_sample + + + fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) + + msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}" + logger.info(msg) + + + if draw: + writer.add_scalar('./Test/FID', fid, nb_iter) + writer.add_scalar('./Test/Diversity', diversity, nb_iter) + writer.add_scalar('./Test/top1', R_precision[0], nb_iter) + writer.add_scalar('./Test/top2', R_precision[1], nb_iter) + writer.add_scalar('./Test/top3', R_precision[2], nb_iter) + writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter) + + + if nb_iter % 10000 == 0 : + for ii in range(4): + tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None) + + if nb_iter % 10000 == 0 : + for ii in range(4): + tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None) + + + if fid < best_fid : + msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!" + logger.info(msg) + best_fid, best_iter = fid, nb_iter + if save: + torch.save({'trans' : trans.state_dict()}, os.path.join(out_dir, 'net_best_fid.pth')) + + if matching_score_pred < best_matching : + msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!" + logger.info(msg) + best_matching = matching_score_pred + + if abs(diversity_real - diversity) < abs(diversity_real - best_div) : + msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!" + logger.info(msg) + best_div = diversity + + if R_precision[0] > best_top1 : + msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!" + logger.info(msg) + best_top1 = R_precision[0] + + if R_precision[1] > best_top2 : + msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!" + logger.info(msg) + best_top2 = R_precision[1] + + if R_precision[2] > best_top3 : + msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!" 
+ logger.info(msg) + best_top3 = R_precision[2] + + if save: + torch.save({'trans' : trans.state_dict()}, os.path.join(out_dir, 'net_last.pth')) + + trans.train() + return best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger + + +@torch.no_grad() +def evaluation_transformer_test(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, clip_model, eval_wrapper, draw = True, save = True, savegif=False, savenpy=False) : + + trans.eval() + nb_sample = 0 + + draw_org = [] + draw_pred = [] + draw_text = [] + draw_text_pred = [] + draw_name = [] + + motion_annotation_list = [] + motion_pred_list = [] + motion_multimodality = [] + R_precision_real = 0 + R_precision = 0 + matching_score_real = 0 + matching_score_pred = 0 + + nb_sample = 0 + + for batch in val_loader: + + word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch + bs, seq = pose.shape[:2] + num_joints = 21 if pose.shape[-1] == 251 else 22 + + text = clip.tokenize(clip_text, truncate=True).cuda() + + feat_clip_text = clip_model.encode_text(text).float() + motion_multimodality_batch = [] + for i in range(30): + pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda() + pred_len = torch.ones(bs).long() + + for k in range(bs): + try: + index_motion = trans.sample(feat_clip_text[k:k+1], True) + except: + index_motion = torch.ones(1,1).cuda().long() + + pred_pose = net.forward_decoder(index_motion) + cur_len = pred_pose.shape[1] + + pred_len[k] = min(cur_len, seq) + pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq] + + if i == 0 and (draw or savenpy): + pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy()) + pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints) + + if savenpy: + np.save(os.path.join(out_dir, name[k]+'_pred.npy'), pred_xyz.detach().cpu().numpy()) + + if draw: + if i == 0: + draw_pred.append(pred_xyz) + draw_text_pred.append(clip_text[k]) + draw_name.append(name[k]) + + et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len) + + motion_multimodality_batch.append(em_pred.reshape(bs, 1, -1)) + + if i == 0: + pose = pose.cuda().float() + + et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length) + motion_annotation_list.append(em) + motion_pred_list.append(em_pred) + + if draw or savenpy: + pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy()) + pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints) + + if savenpy: + for j in range(bs): + np.save(os.path.join(out_dir, name[j]+'_gt.npy'), pose_xyz[j][:m_length[j]].unsqueeze(0).cpu().numpy()) + + if draw: + for j in range(bs): + draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0)) + draw_text.append(clip_text[j]) + + temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True) + R_precision_real += temp_R + matching_score_real += temp_match + temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True) + R_precision += temp_R + matching_score_pred += temp_match + + nb_sample += bs + + motion_multimodality.append(torch.cat(motion_multimodality_batch, dim=1)) + + motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy() + motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy() + gt_mu, gt_cov = 
calculate_activation_statistics(motion_annotation_np) + mu, cov= calculate_activation_statistics(motion_pred_np) + + diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100) + diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100) + + R_precision_real = R_precision_real / nb_sample + R_precision = R_precision / nb_sample + + matching_score_real = matching_score_real / nb_sample + matching_score_pred = matching_score_pred / nb_sample + + multimodality = 0 + motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy() + multimodality = calculate_multimodality(motion_multimodality, 10) + + fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) + + msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}, multimodality. {multimodality:.4f}" + logger.info(msg) + + + if draw: + for ii in range(len(draw_org)): + tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/'+draw_name[ii]+'_org', nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, draw_name[ii]+'_skel_gt.gif')] if savegif else None) + + tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/'+draw_name[ii]+'_pred', nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, draw_name[ii]+'_skel_pred.gif')] if savegif else None) + + trans.train() + return fid, best_iter, diversity, R_precision[0], R_precision[1], R_precision[2], matching_score_pred, multimodality, writer, logger + +# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train +def euclidean_distance_matrix(matrix1, matrix2): + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dist: N1 x N2 + dist[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) + d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) + dists = np.sqrt(d1 + d2 + d3) # broadcasting + return dists + + + +def calculate_top_k(mat, top_k): + size = mat.shape[0] + gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) + bool_mat = (mat == gt_mat) + correct_vec = False + top_k_list = [] + for i in range(top_k): +# print(correct_vec, bool_mat[:, i]) + correct_vec = (correct_vec | bool_mat[:, i]) + # print(correct_vec) + top_k_list.append(correct_vec[:, None]) + top_k_mat = np.concatenate(top_k_list, axis=1) + return top_k_mat + + +def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False): + dist_mat = euclidean_distance_matrix(embedding1, embedding2) + matching_score = dist_mat.trace() + argmax = np.argsort(dist_mat, axis=1) + top_k_mat = calculate_top_k(argmax, top_k) + if sum_all: + return top_k_mat.sum(axis=0), matching_score + else: + return top_k_mat, matching_score + +def calculate_multimodality(activation, multimodality_times): + assert len(activation.shape) == 3 + assert activation.shape[1] > multimodality_times + num_per_sent = activation.shape[1] + + first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) + second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) + dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) + return 
dist.mean() + + +def calculate_diversity(activation, diversity_times): + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, diversity_times, replace=False) + second_indices = np.random.choice(num_samples, diversity_times, replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) + return dist.mean() + + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + + +def calculate_activation_statistics(activations): + + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +def calculate_frechet_feature_distance(feature_list1, feature_list2): + feature_list1 = np.stack(feature_list1) + feature_list2 = np.stack(feature_list2) + + # normalize the scale + mean = np.mean(feature_list1, axis=0) + std = np.std(feature_list1, axis=0) + 1e-10 + feature_list1 = (feature_list1 - mean) / std + feature_list2 = (feature_list2 - mean) / std + + dist = calculate_frechet_distance( + mu1=np.mean(feature_list1, axis=0), + sigma1=np.cov(feature_list1, rowvar=False), + mu2=np.mean(feature_list2, axis=0), + sigma2=np.cov(feature_list2, rowvar=False), + ) + return dist \ No newline at end of file diff --git a/VQ-Trans/utils/losses.py b/VQ-Trans/utils/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..1998161032731fc2c3edae701700679c00fd00d0 --- /dev/null +++ b/VQ-Trans/utils/losses.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +class ReConsLoss(nn.Module): + def __init__(self, recons_loss, nb_joints): + super(ReConsLoss, self).__init__() + + if recons_loss == 'l1': + self.Loss = torch.nn.L1Loss() + elif recons_loss == 'l2' : + self.Loss = torch.nn.MSELoss() + elif recons_loss == 'l1_smooth' : + self.Loss = torch.nn.SmoothL1Loss() + + # 4 global motion associated to root + # 12 local motion (3 local xyz, 3 vel xyz, 6 rot6d) + # 3 global vel xyz + # 4 foot contact + self.nb_joints = nb_joints + self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4 + + def forward(self, motion_pred, motion_gt) : + loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim]) + return loss + + def forward_vel(self, motion_pred, motion_gt) : + loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4]) + return loss + + \ No newline 
at end of file diff --git a/VQ-Trans/utils/motion_process.py b/VQ-Trans/utils/motion_process.py new file mode 100644 index 0000000000000000000000000000000000000000..7819c8b3cc61b6e48c65d1a456342119060696ea --- /dev/null +++ b/VQ-Trans/utils/motion_process.py @@ -0,0 +1,59 @@ +import torch +from utils.quaternion import quaternion_to_cont6d, qrot, qinv + +def recover_root_rot_pos(data): + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + + +def recover_from_rot(data, joints_num, skeleton): + r_rot_quat, r_pos = recover_root_rot_pos(data) + + r_rot_cont6d = quaternion_to_cont6d(r_rot_quat) + + start_indx = 1 + 2 + 1 + (joints_num - 1) * 3 + end_indx = start_indx + (joints_num - 1) * 6 + cont6d_params = data[..., start_indx:end_indx] + # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape) + cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1) + cont6d_params = cont6d_params.view(-1, joints_num, 6) + + positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos) + + return positions + + +def recover_from_ric(data, joints_num): + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concate root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + return positions + \ No newline at end of file diff --git a/VQ-Trans/utils/paramUtil.py b/VQ-Trans/utils/paramUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f1708b85ca80a9051cb3675cec9b999a0d0e2b --- /dev/null +++ b/VQ-Trans/utils/paramUtil.py @@ -0,0 +1,63 @@ +import numpy as np + +# Define a kinematic tree for the skeletal struture +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0,0,0], + [1,0,0], + [-1,0,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,0,1], + [0,0,1], + [0,1,0], + [1,0,0], + [-1,0,0], + [0,0,1], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0]]) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 
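+
+# Editor's sketch (hypothetical helper): each kinematic chain here is a list of
+# joint indices running outward from the root, so consecutive pairs are the
+# bones of the skeleton — which is how these chains are typically consumed
+# when drawing stick figures.
+def _demo_chain_to_bones(chain):
+    # e.g. [0, 2, 5, 8, 11] -> [(0, 2), (2, 5), (5, 8), (8, 11)]
+    return list(zip(chain[:-1], chain[1:]))
+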
+t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
+
+
+kit_tgt_skel_id = '03950'
+
+t2m_tgt_skel_id = '000021'
+
diff --git a/VQ-Trans/utils/quaternion.py b/VQ-Trans/utils/quaternion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2daa00aef1df60e43775864d1dd3d551f89ded8
--- /dev/null
+++ b/VQ-Trans/utils/quaternion.py
@@ -0,0 +1,423 @@
+# Copyright (c) 2018-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import torch
+import numpy as np
+
+_EPS4 = np.finfo(float).eps * 4.0
+
+_FLOAT_EPS = np.finfo(np.float64).eps  # np.float was removed from NumPy; float64 matches the old alias
+
+# PyTorch-backed implementations
+def qinv(q):
+    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
+    mask = torch.ones_like(q)
+    mask[..., 1:] = -mask[..., 1:]
+    return q * mask
+
+
+def qinv_np(q):
+    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
+    return qinv(torch.from_numpy(q).float()).numpy()
+
+
+def qnormalize(q):
+    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
+    return q / torch.norm(q, dim=-1, keepdim=True)
+
+
+def qmul(q, r):
+    """
+    Multiply quaternion(s) q with quaternion(s) r.
+    Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
+    Returns q*r as a tensor of shape (*, 4).
+    """
+    assert q.shape[-1] == 4
+    assert r.shape[-1] == 4
+
+    original_shape = q.shape
+
+    # Compute outer product
+    terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
+
+    w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
+    x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
+    y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
+    z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
+    return torch.stack((w, x, y, z), dim=1).view(original_shape)
+
+
+def qrot(q, v):
+    """
+    Rotate vector(s) v about the rotation described by quaternion(s) q.
+    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
+    where * denotes any number of dimensions.
+    Returns a tensor of shape (*, 3).
+    """
+    assert q.shape[-1] == 4
+    assert v.shape[-1] == 3
+    assert q.shape[:-1] == v.shape[:-1]
+
+    original_shape = list(v.shape)
+    # print(q.shape)
+    q = q.contiguous().view(-1, 4)
+    v = v.contiguous().view(-1, 3)
+
+    qvec = q[:, 1:]
+    uv = torch.cross(qvec, v, dim=1)
+    uuv = torch.cross(qvec, uv, dim=1)
+    return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
+
+
+def qeuler(q, order, epsilon=0, deg=True):
+    """
+    Convert quaternion(s) q to Euler angles.
+    Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
+    Returns a tensor of shape (*, 3).
+ """ + assert q.shape[-1] == 4 + + original_shape = list(q.shape) + original_shape[-1] = 3 + q = q.view(-1, 4) + + q0 = q[:, 0] + q1 = q[:, 1] + q2 = q[:, 2] + q3 = q[:, 3] + + if order == 'xyz': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) + elif order == 'zxy': + x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'xzy': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) + elif order == 'yxz': + x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'zyx': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + else: + raise + + if deg: + return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi + else: + return torch.stack((x, y, z), dim=1).view(original_shape) + + +# Numpy-backed implementations + +def qmul_np(q, r): + q = torch.from_numpy(q).contiguous().float() + r = torch.from_numpy(r).contiguous().float() + return qmul(q, r).numpy() + + +def qrot_np(q, v): + q = torch.from_numpy(q).contiguous().float() + v = torch.from_numpy(v).contiguous().float() + return qrot(q, v).numpy() + + +def qeuler_np(q, order, epsilon=0, use_gpu=False): + if use_gpu: + q = torch.from_numpy(q).cuda().float() + return qeuler(q, order, epsilon).cpu().numpy() + else: + q = torch.from_numpy(q).contiguous().float() + return qeuler(q, order, epsilon).numpy() + + +def qfix(q): + """ + Enforce quaternion continuity across the time dimension by selecting + the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) + between two consecutive frames. + + Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. + Returns a tensor of the same shape. + """ + assert len(q.shape) == 3 + assert q.shape[-1] == 4 + + result = q.copy() + dot_products = np.sum(q[1:] * q[:-1], axis=2) + mask = dot_products < 0 + mask = (np.cumsum(mask, axis=0) % 2).astype(bool) + result[1:][mask] *= -1 + return result + + +def euler2quat(e, order, deg=True): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.view(-1, 3) + + ## if euler angles in degrees + if deg: + e = e * np.pi / 180. 
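+        # e is now in radians; the per-axis quaternions built below assume radians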
+ + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) + ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) + rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.view(original_shape) + + +def expmap_to_quaternion(e): + """ + Convert axis-angle rotations (aka exponential maps) to quaternions. + Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". + Expects a tensor of shape (*, 3), where * denotes any number of dimensions. + Returns a tensor of shape (*, 4). + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + e = e.reshape(-1, 3) + + theta = np.linalg.norm(e, axis=1).reshape(-1, 1) + w = np.cos(0.5 * theta).reshape(-1, 1) + xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e + return np.concatenate((w, xyz), axis=1).reshape(original_shape) + + +def euler_to_quaternion(e, order): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.reshape(-1, 3) + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) + ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) + rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul_np(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.reshape(original_shape) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def quaternion_to_matrix_np(quaternions): + q = torch.from_numpy(quaternions).contiguous().float() + return quaternion_to_matrix(q).numpy() + + +def quaternion_to_cont6d_np(quaternions): + rotation_mat = quaternion_to_matrix_np(quaternions) + cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) + return cont_6d + + +def quaternion_to_cont6d(quaternions): + rotation_mat = quaternion_to_matrix(quaternions) + cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) + return cont_6d + + +def cont6d_to_matrix(cont6d): + assert cont6d.shape[-1] == 6, "The last dimension must be 6" + x_raw = cont6d[..., 0:3] + y_raw = cont6d[..., 3:6] + + x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) + z = torch.cross(x, y_raw, dim=-1) + z = z / torch.norm(z, dim=-1, keepdim=True) + + y = torch.cross(z, x, dim=-1) + + x = x[..., None] + y = y[..., None] + z = z[..., None] + + mat = torch.cat([x, y, z], dim=-1) + return mat + + +def cont6d_to_matrix_np(cont6d): + q = torch.from_numpy(cont6d).contiguous().float() + return cont6d_to_matrix(q).numpy() + + +def qpow(q0, t, dtype=torch.float): + ''' q0 : tensor of quaternions + t: tensor of powers + ''' + q0 = qnormalize(q0) + theta0 = torch.acos(q0[..., 0]) + + ## if theta0 is close to zero, add epsilon to avoid NaNs + mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) + theta0 = (1 - mask) * theta0 + mask * 10e-10 + v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) + + if isinstance(t, torch.Tensor): + q = torch.zeros(t.shape + q0.shape) + theta = t.view(-1, 1) * theta0.view(1, -1) + else: ## if t is a number + q = torch.zeros(q0.shape) + theta = t * theta0 + + q[..., 0] = torch.cos(theta) + q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) + + return q.to(dtype) + + +def qslerp(q0, q1, t): + ''' + q0: starting quaternion + q1: ending quaternion + t: array of points along the way + + Returns: + Tensor of Slerps: t.shape + q0.shape + ''' + + q0 = qnormalize(q0) + q1 = qnormalize(q1) + q_ = qpow(qmul(q1, qinv(q0)), t) + + return qmul(q_, + q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) + + +def qbetween(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v = torch.cross(v0, v1) + w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, + keepdim=True) + return qnormalize(torch.cat([w, v], dim=-1)) + + +def qbetween_np(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v0 = torch.from_numpy(v0).float() + v1 = torch.from_numpy(v1).float() + return qbetween(v0, v1).numpy() + + +def lerp(p0, p1, t): + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]) + + new_shape = t.shape + p0.shape + new_view_t = t.shape + torch.Size([1] * len(p0.shape)) + new_view_p = torch.Size([1] * 
len(t.shape)) + p0.shape + p0 = p0.view(new_view_p).expand(new_shape) + p1 = p1.view(new_view_p).expand(new_shape) + t = t.view(new_view_t).expand(new_shape) + + return p0 + t * (p1 - p0) diff --git a/VQ-Trans/utils/rotation_conversions.py b/VQ-Trans/utils/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..1006e8a3117b231a7a456d5b826e76347fe0bfd4 --- /dev/null +++ b/VQ-Trans/utils/rotation_conversions.py @@ -0,0 +1,532 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Check PYTORCH3D_LICENCE before use + +import functools +from typing import Optional + +import torch +import torch.nn.functional as F + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) +This matrix can be applied to column vectors by post multiplication +by the points e.g. + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a, b): + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x): + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix): + """ + Convert rotations given as rotation matrices to quaternions. + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + Returns: + quaternions with real part first, as tensor of shape (..., 4). 
+ """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + m00 = matrix[..., 0, 0] + m11 = matrix[..., 1, 1] + m22 = matrix[..., 2, 2] + o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) + x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) + y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) + z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) + o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) + o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) + o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) + return torch.stack((o0, o1, o2, o3), -1) + + +def _axis_angle_rotation(axis: str, angle): + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + if axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + if axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles, convention: str): + """ + Convert rotations given as Euler angles in radians to rotation matrices. + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) + return functools.reduce(torch.matmul, matrices) + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +): + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). 
+ """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str): + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + + +def matrix_to_euler_angles(matrix, convention: str): + """ + Convert rotations given as rotation matrices to Euler angles in radians. + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def random_quaternions( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + Returns: + Quaternions as tensor of shape (N, 4). + """ + o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations( + n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate random rotations as 3x3 rotation matrices. + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + requires_grad: Whether the resulting tensor should have the gradient + flag set. + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions( + n, dtype=dtype, device=device, requires_grad=requires_grad + ) + return quaternion_to_matrix(quaternions) + + +def random_rotation( + dtype: Optional[torch.dtype] = None, device=None, requires_grad=False +): + """ + Generate a single random 3x3 rotation matrix. + Args: + dtype: Type to return + device: Device of returned tensor. 
+            uses the current device for the default tensor type.
+        requires_grad: Whether the resulting tensor should have the gradient
+            flag set.
+    Returns:
+        Rotation matrix as tensor of shape (3, 3).
+    """
+    return random_rotations(1, dtype, device, requires_grad)[0]
+
+
+def standardize_quaternion(quaternions):
+    """
+    Convert a unit quaternion to a standard form: one in which the real
+    part is non-negative.
+    Args:
+        quaternions: Quaternions with real part first,
+            as tensor of shape (..., 4).
+    Returns:
+        Standardized quaternions as tensor of shape (..., 4).
+    """
+    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+
+def quaternion_raw_multiply(a, b):
+    """
+    Multiply two quaternions.
+    Usual torch rules for broadcasting apply.
+    Args:
+        a: Quaternions as tensor of shape (..., 4), real part first.
+        b: Quaternions as tensor of shape (..., 4), real part first.
+    Returns:
+        The product of a and b, a tensor of quaternions of shape (..., 4).
+    """
+    aw, ax, ay, az = torch.unbind(a, -1)
+    bw, bx, by, bz = torch.unbind(b, -1)
+    ow = aw * bw - ax * bx - ay * by - az * bz
+    ox = aw * bx + ax * bw + ay * bz - az * by
+    oy = aw * by - ax * bz + ay * bw + az * bx
+    oz = aw * bz + ax * by - ay * bx + az * bw
+    return torch.stack((ow, ox, oy, oz), -1)
+
+
+def quaternion_multiply(a, b):
+    """
+    Multiply two quaternions representing rotations, returning the quaternion
+    representing their composition, i.e. the versor with nonnegative real part.
+    Usual torch rules for broadcasting apply.
+    Args:
+        a: Quaternions as tensor of shape (..., 4), real part first.
+        b: Quaternions as tensor of shape (..., 4), real part first.
+    Returns:
+        The product of a and b, a tensor of quaternions of shape (..., 4).
+    """
+    ab = quaternion_raw_multiply(a, b)
+    return standardize_quaternion(ab)
+
+
+def quaternion_invert(quaternion):
+    """
+    Given a quaternion representing rotation, get the quaternion representing
+    its inverse.
+    Args:
+        quaternion: Quaternions as tensor of shape (..., 4), with real part
+            first, which must be versors (unit quaternions).
+    Returns:
+        The inverse, a tensor of quaternions of shape (..., 4).
+    """
+
+    return quaternion * quaternion.new_tensor([1, -1, -1, -1])
+
+
+def quaternion_apply(quaternion, point):
+    """
+    Apply the rotation given by a quaternion to a 3D point.
+    Usual torch rules for broadcasting apply.
+    Args:
+        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
+        point: Tensor of 3D points of shape (..., 3).
+    Returns:
+        Tensor of rotated points of shape (..., 3).
+    """
+    if point.size(-1) != 3:
+        raise ValueError(f"Points are not in 3D, {point.shape}.")
+    real_parts = point.new_zeros(point.shape[:-1] + (1,))
+    point_as_quaternion = torch.cat((real_parts, point), -1)
+    out = quaternion_raw_multiply(
+        quaternion_raw_multiply(quaternion, point_as_quaternion),
+        quaternion_invert(quaternion),
+    )
+    return out[..., 1:]
+
+
+def axis_angle_to_matrix(axis_angle):
+    """
+    Convert rotations given as axis/angle to rotation matrices.
+    Args:
+        axis_angle: Rotations given as a vector in axis angle form,
+            as a tensor of shape (..., 3), where the magnitude is
+            the angle turned anticlockwise in radians around the
+            vector's direction.
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def matrix_to_axis_angle(matrix):
+    """
+    Convert rotations given as rotation matrices to axis/angle.
+ Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle): + """ + Convert rotations given as axis/angle to quaternions. + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = 0.5 * angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +def quaternion_to_axis_angle(quaternions): + """ + Convert rotations given as quaternions to axis/angle. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + Returns: + batch of rotation matrices of size (*, 3, 3) + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. 
+ Args: + matrix: batch of rotation matrices of size (*, 3, 3) + Returns: + 6D rotation representation, of size (*, 6) + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) + +def canonicalize_smplh(poses, trans = None): + bs, nframes, njoints = poses.shape[:3] + + global_orient = poses[:, :, 0] + + # first global rotations + rot2d = matrix_to_axis_angle(global_orient[:, 0]) + #rot2d[:, :2] = 0 # Remove the rotation along the vertical axis + rot2d = axis_angle_to_matrix(rot2d) + + # Rotate the global rotation to eliminate Z rotations + global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient) + + # Construct canonicalized version of x + xc = torch.cat((global_orient[:, :, None], poses[:, :, 1:]), dim=2) + + if trans is not None: + vel = trans[:, 1:] - trans[:, :-1] + # Turn the translation as well + vel = torch.einsum("ikj,ilk->ilj", rot2d, vel) + trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device), + torch.cumsum(vel, 1)), 1) + return xc, trans + else: + return xc + + \ No newline at end of file diff --git a/VQ-Trans/utils/skeleton.py b/VQ-Trans/utils/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..6de56af0c29ae7cccbd7178f912459413f87c646 --- /dev/null +++ b/VQ-Trans/utils/skeleton.py @@ -0,0 +1,199 @@ +from utils.quaternion import * +import scipy.ndimage.filters as filters + +class Skeleton(object): + def __init__(self, offset, kinematic_tree, device): + self.device = device + self._raw_offset_np = offset.numpy() + self._raw_offset = offset.clone().detach().to(device).float() + self._kinematic_tree = kinematic_tree + self._offset = None + self._parents = [0] * len(self._raw_offset) + self._parents[0] = -1 + for chain in self._kinematic_tree: + for j in range(1, len(chain)): + self._parents[chain[j]] = chain[j-1] + + def njoints(self): + return len(self._raw_offset) + + def offset(self): + return self._offset + + def set_offset(self, offsets): + self._offset = offsets.clone().detach().to(self.device).float() + + def kinematic_tree(self): + return self._kinematic_tree + + def parents(self): + return self._parents + + # joints (batch_size, joints_num, 3) + def get_offsets_joints_batch(self, joints): + assert len(joints.shape) == 3 + _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() + for i in range(1, self._raw_offset.shape[0]): + _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] + + self._offset = _offsets.detach() + return _offsets + + # joints (joints_num, 3) + def get_offsets_joints(self, joints): + assert len(joints.shape) == 2 + _offsets = self._raw_offset.clone() + for i in range(1, self._raw_offset.shape[0]): + # print(joints.shape) + _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] + + self._offset = _offsets.detach() + return _offsets + + # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder + # joints (batch_size, joints_num, 3) + def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): + assert len(face_joint_idx) == 4 + '''Get Forward Direction''' + l_hip, r_hip, sdr_r, sdr_l = face_joint_idx + across1 = joints[:, r_hip] - joints[:, l_hip] + across2 = joints[:, sdr_r] - joints[:, sdr_l] + 
across = across1 + across2 + across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] + # print(across1.shape, across2.shape) + + # forward (batch_size, 3) + forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + if smooth_forward: + forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') + # forward (batch_size, 3) + forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + + '''Get Root Rotation''' + target = np.array([[0,0,1]]).repeat(len(forward), axis=0) + root_quat = qbetween_np(forward, target) + + '''Inverse Kinematics''' + # quat_params (batch_size, joints_num, 4) + # print(joints.shape[:-1]) + quat_params = np.zeros(joints.shape[:-1] + (4,)) + # print(quat_params.shape) + root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + quat_params[:, 0] = root_quat + # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + for chain in self._kinematic_tree: + R = root_quat + for j in range(len(chain) - 1): + # (batch, 3) + u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) + # print(u.shape) + # (batch, 3) + v = joints[:, chain[j+1]] - joints[:, chain[j]] + v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] + # print(u.shape, v.shape) + rot_u_v = qbetween_np(u, v) + + R_loc = qmul_np(qinv_np(R), rot_u_v) + + quat_params[:,chain[j + 1], :] = R_loc + R = qmul_np(R, R_loc) + + return quat_params + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) + for i in range(1, len(chain)): + R = qmul(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] + return joints + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(quat_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) + for i in range(1, len(chain)): + R = qmul_np(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] + return joints + + def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = 
torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(cont6d_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix_np(cont6d_params[:, 0]) + else: + matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) + for i in range(1, len(chain)): + matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]][..., np.newaxis] + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + # skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) + joints[..., 0, :] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix(cont6d_params[:, 0]) + else: + matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) + for i in range(1, len(chain)): + matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]].unsqueeze(-1) + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + + + + diff --git a/VQ-Trans/utils/utils_model.py b/VQ-Trans/utils/utils_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b3653a47ddb96f2ba27aae73b4eef8be904e9bf0 --- /dev/null +++ b/VQ-Trans/utils/utils_model.py @@ -0,0 +1,66 @@ +import numpy as np +import torch +import torch.optim as optim +import logging +import os +import sys + +def getCi(accLog): + + mean = np.mean(accLog) + std = np.std(accLog) + ci95 = 1.96*std/np.sqrt(len(accLog)) + + return mean, ci95 + +def get_logger(out_dir): + logger = logging.getLogger('Exp') + logger.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") + + file_path = os.path.join(out_dir, "run.log") + file_hdlr = logging.FileHandler(file_path) + file_hdlr.setFormatter(formatter) + + strm_hdlr = logging.StreamHandler(sys.stdout) + strm_hdlr.setFormatter(formatter) + + logger.addHandler(file_hdlr) + logger.addHandler(strm_hdlr) + return logger + +## Optimizer +def initial_optim(decay_option, lr, weight_decay, net, optimizer) : + + if optimizer == 'adamw' : + optimizer_adam_family = optim.AdamW + elif optimizer == 'adam' : + optimizer_adam_family = optim.Adam + if decay_option == 'all': + #optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay) + optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.5, 0.9), weight_decay=weight_decay) + + elif decay_option == 'noVQ': + all_params = set(net.parameters()) + no_decay = set([net.vq_layer]) + + decay = all_params - no_decay + optimizer = optimizer_adam_family([ + {'params': list(no_decay), 'weight_decay': 0}, + {'params': list(decay), 'weight_decay' : weight_decay}], lr=lr) + 
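+    # Usage sketch (the hyperparameter values below are illustrative, not the
+    # repo's defaults):
+    #   optimizer = initial_optim('all', 2e-4, 1e-5, net, 'adamw')
+    # The 'noVQ' branch is intended to exempt the VQ codebook parameters from
+    # weight decay by placing them in a parameter group with weight_decay=0.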
+    return optimizer
+
+
+def get_motion_with_trans(motion, velocity):
+    '''
+    Re-attach the global trajectory to a motion whose translation was zeroed.
+    motion : torch.tensor, shape (batch_size, T, 63), local joint positions
+             (21 joints x 3) with the global translation set to 0
+    velocity : torch.tensor, shape (batch_size, T, 3), root velocity per frame
+    '''
+    trans = torch.cumsum(velocity, dim=1)
+    trans = trans - trans[:, :1]  # the first root position is initialized at 0 (just for visualization)
+    trans = trans.repeat((1, 1, 21))  # tile the root translation across the 21 joints
+    motion_with_trans = motion + trans
+    return motion_with_trans
+    
\ No newline at end of file
diff --git a/VQ-Trans/utils/word_vectorizer.py b/VQ-Trans/utils/word_vectorizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..557ff97a9539c084167f3eca51fb50f53f33c8ea
--- /dev/null
+++ b/VQ-Trans/utils/word_vectorizer.py
@@ -0,0 +1,99 @@
+import numpy as np
+import pickle
+from os.path import join as pjoin
+
+POS_enumerator = {
+    'VERB': 0,
+    'NOUN': 1,
+    'DET': 2,
+    'ADP': 3,
+    'NUM': 4,
+    'AUX': 5,
+    'PRON': 6,
+    'ADJ': 7,
+    'ADV': 8,
+    'Loc_VIP': 9,
+    'Body_VIP': 10,
+    'Obj_VIP': 11,
+    'Act_VIP': 12,
+    'Desc_VIP': 13,
+    'OTHER': 14,
+}
+
+Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
+            'up', 'down', 'straight', 'curve')
+
+Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
+
+Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
+
+Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
+            'stumble', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
+            'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
+
+Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
+             'angrily', 'sadly')
+
+VIP_dict = {
+    'Loc_VIP': Loc_list,
+    'Body_VIP': Body_list,
+    'Obj_VIP': Obj_List,
+    'Act_VIP': Act_list,
+    'Desc_VIP': Desc_list,
+}
+
+
+class WordVectorizer(object):
+    def __init__(self, meta_root, prefix):
+        vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix))
+        words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb'))
+        self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb'))
+        self.word2vec = {w: vectors[self.word2idx[w]] for w in words}
+
+    def _get_pos_ohot(self, pos):
+        pos_vec = np.zeros(len(POS_enumerator))
+        if pos in POS_enumerator:
+            pos_vec[POS_enumerator[pos]] = 1
+        else:
+            pos_vec[POS_enumerator['OTHER']] = 1
+        return pos_vec
+
+    def __len__(self):
+        return len(self.word2vec)
+
+    def __getitem__(self, item):
+        word, pos = item.split('/')
+        if word in self.word2vec:
+            word_vec = self.word2vec[word]
+            vip_pos = None
+            for key, values in VIP_dict.items():
+                if word in values:
+                    vip_pos = key
+                    break
+            if vip_pos is not None:
+                pos_vec = self._get_pos_ohot(vip_pos)
+            else:
+                pos_vec = self._get_pos_ohot(pos)
+        else:
+            word_vec = self.word2vec['unk']
+            pos_vec = self._get_pos_ohot('OTHER')
+        return word_vec, pos_vec
+
+
+class WordVectorizerV2(WordVectorizer):
+    def __init__(self, meta_root, prefix):
+        super(WordVectorizerV2, self).__init__(meta_root, prefix)
+        self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
+
+    def __getitem__(self, item):
+        word_vec, pose_vec = super(WordVectorizerV2, self).__getitem__(item)
+        word, pos = item.split('/')
+        if word in self.word2vec:
+            return word_vec, pose_vec, self.word2idx[word]
+
else: + return word_vec, pose_vec, self.word2idx['unk'] + + def itos(self, idx): + if idx == len(self.idx2word): + return "pad" + return self.idx2word[idx] \ No newline at end of file diff --git a/VQ-Trans/visualization/plot_3d_global.py b/VQ-Trans/visualization/plot_3d_global.py new file mode 100644 index 0000000000000000000000000000000000000000..42fea4efd366397e17bc74470d72d3313ae228d8 --- /dev/null +++ b/VQ-Trans/visualization/plot_3d_global.py @@ -0,0 +1,129 @@ +import torch +import matplotlib.pyplot as plt +import numpy as np +import io +import matplotlib +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import mpl_toolkits.mplot3d.axes3d as p3 +from textwrap import wrap +import imageio + +def plot_3d_motion(args, figsize=(10, 10), fps=120, radius=4): + matplotlib.use('Agg') + + + joints, out_name, title = args + + data = joints.copy().reshape(len(joints), -1, 3) + + nb_joints = joints.shape[1] + smpl_kinetic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] if nb_joints == 21 else [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] + limits = 1000 if nb_joints == 21 else 2 + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + colors = ['red', 'blue', 'black', 'red', 'blue', + 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue', + 'darkred', 'darkred', 'darkred', 'darkred', 'darkred'] + frame_number = data.shape[0] + # print(data.shape) + + height_offset = MINS[1] + data[:, :, 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + def update(index): + + def init(): + ax.set_xlim(-limits, limits) + ax.set_ylim(-limits, limits) + ax.set_zlim(0, limits) + ax.grid(b=False) + def plot_xzPlane(minx, maxx, miny, minz, maxz): + ## Plot a plane XZ + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + fig = plt.figure(figsize=(480/96., 320/96.), dpi=96) if nb_joints == 21 else plt.figure(figsize=(10, 10), dpi=96) + if title is not None : + wraped_title = '\n'.join(wrap(title, 40)) + fig.suptitle(wraped_title, fontsize=16) + ax = p3.Axes3D(fig) + + init() + + ax.lines = [] + ax.collections = [] + ax.view_init(elev=110, azim=-90) + ax.dist = 7.5 + # ax = + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + # ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3) + + if index > 1: + ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]), + trajec[:index, 1] - trajec[index, 1], linewidth=1.0, + color='blue') + # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2]) + + for i, (chain, color) in enumerate(zip(smpl_kinetic_chain, colors)): + # print(color) + if i < 5: + linewidth = 4.0 + else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, + color=color) + # print(trajec[:index, 0].shape) + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + if out_name is not None : + plt.savefig(out_name, dpi=96) + plt.close() + + else : + io_buf = io.BytesIO() + fig.savefig(io_buf, format='raw', dpi=96) + io_buf.seek(0) + # print(fig.bbox.bounds) + arr = 
np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8), + newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)) + io_buf.close() + plt.close() + return arr + + out = [] + for i in range(frame_number) : + out.append(update(i)) + out = np.stack(out, axis=0) + return torch.from_numpy(out) + + +def draw_to_batch(smpl_joints_batch, title_batch=None, outname=None) : + + batch_size = len(smpl_joints_batch) + out = [] + for i in range(batch_size) : + out.append(plot_3d_motion([smpl_joints_batch[i], None, title_batch[i] if title_batch is not None else None])) + if outname is not None: + imageio.mimsave(outname[i], np.array(out[-1]), fps=20) + out = torch.stack(out, axis=0) + return out + + + + + diff --git a/VQ-Trans/visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl b/VQ-Trans/visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7bb54c4f1e03340ad58b60485abaed1641d68d47 --- /dev/null +++ b/VQ-Trans/visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5b783c1677079397ee4bc26df5c72d73b8bb393bea41fa295b951187443daec +size 3556 diff --git a/VQ-Trans/visualize/joints2smpl/smpl_models/gmm_08.pkl b/VQ-Trans/visualize/joints2smpl/smpl_models/gmm_08.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c97a1d7ef396581e56ce74a12cc39175680ce028 --- /dev/null +++ b/VQ-Trans/visualize/joints2smpl/smpl_models/gmm_08.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1374908aae055a2afa01a2cd9a169bc6cfec1ceb7aa590e201a47b383060491 +size 839127 diff --git a/VQ-Trans/visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 b/VQ-Trans/visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 new file mode 100644 index 0000000000000000000000000000000000000000..b6ecce2a748128cfde09b219ccc74307de50bbae --- /dev/null +++ b/VQ-Trans/visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac9b474c74daec0253ed084720f662059336e976850f08a4a9a3f76d06613776 +size 4848 diff --git a/VQ-Trans/visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl b/VQ-Trans/visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl new file mode 100644 index 0000000000000000000000000000000000000000..77ce98631741ba3887d689077baf35422d39299d --- /dev/null +++ b/VQ-Trans/visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb69c10801205c9cfb5353fdeb1b9cc5ade53d14c265c3339421cdde8b9c91e7 +size 1323168 diff --git a/VQ-Trans/visualize/joints2smpl/src/config.py b/VQ-Trans/visualize/joints2smpl/src/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1021115a53f19974fbea3d3768c25874a4ae5d38 --- /dev/null +++ b/VQ-Trans/visualize/joints2smpl/src/config.py @@ -0,0 +1,40 @@ +import numpy as np + +# Map joints Name to SMPL joints idx +JOINT_MAP = { +'MidHip': 0, +'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, +'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, +'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 'LHand': 22, +'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 'RHand': 23, +'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, +'LCollar':13, 'Rcollar' :14, +'Nose':24, 'REye':26, 'LEye':26, 'REar':27, 'LEar':28, +'LHeel': 31, 'RHeel': 34, +'OP RShoulder': 17, 'OP LShoulder': 16, +'OP RHip': 2, 'OP LHip': 1, +'OP Neck': 12, +} + +full_smpl_idx = 
range(24)
+key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20]
+
+
+AMASS_JOINT_MAP = {
+'MidHip': 0,
+'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10,
+'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11,
+'LShoulder': 16, 'LElbow': 18, 'LWrist': 20,
+'RShoulder': 17, 'RElbow': 19, 'RWrist': 21,
+'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15,
+'LCollar': 13, 'Rcollar': 14,
+}
+amass_idx = range(22)
+amass_smpl_idx = range(22)
+
+
+SMPL_MODEL_DIR = "./body_models/"
+GMM_MODEL_DIR = "./visualize/joints2smpl/smpl_models/"
+SMPL_MEAN_FILE = "./visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5"
+# for collision
+Part_Seg_DIR = "./visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl"
\ No newline at end of file
diff --git a/VQ-Trans/visualize/joints2smpl/src/customloss.py b/VQ-Trans/visualize/joints2smpl/src/customloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..880ab4861c58cec9faeb086e430fde7387c5cc9e
--- /dev/null
+++ b/VQ-Trans/visualize/joints2smpl/src/customloss.py
@@ -0,0 +1,222 @@
+import torch
+import torch.nn.functional as F
+from visualize.joints2smpl.src import config
+
+# Gaussian-derived robust penalty
+def gmof(x, sigma):
+    """
+    Geman-McClure error function
+    """
+    x_squared = x ** 2
+    sigma_squared = sigma ** 2
+    return (sigma_squared * x_squared) / (sigma_squared + x_squared)
+
+# angle prior
+def angle_prior(pose):
+    """
+    Angle prior that penalizes unnatural bending of the knees and elbows
+    """
+    # We subtract 3 because pose does not include the global rotation of the model
+    return torch.exp(
+        pose[:, [55 - 3, 58 - 3, 12 - 3, 15 - 3]] * torch.tensor([1., -1., -1, -1.], device=pose.device)) ** 2
+
+
+def perspective_projection(points, rotation, translation,
+                           focal_length, camera_center):
+    """
+    This function computes the perspective projection of a set of points.
+    Input:
+        points (bs, N, 3): 3D points
+        rotation (bs, 3, 3): Camera rotation
+        translation (bs, 3): Camera translation
+        focal_length (bs,) or scalar: Focal length
+        camera_center (bs, 2): Camera center
+    """
+    batch_size = points.shape[0]
+    K = torch.zeros([batch_size, 3, 3], device=points.device)
+    K[:, 0, 0] = focal_length
+    K[:, 1, 1] = focal_length
+    K[:, 2, 2] = 1.
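+    # K assembled here is the batch of pinhole intrinsics matrices:
+    #     [[f, 0, cx],
+    #      [0, f, cy],
+    #      [0, 0,  1]]
+    # with the principal point (cx, cy) written into the last column next.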
+ K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + return projected_points[:, :, :-1] + + +def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center, + joints_2d, joints_conf, pose_prior, + focal_length=5000, sigma=100, pose_prior_weight=4.78, + shape_prior_weight=5, angle_prior_weight=15.2, + output='sum'): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(batch_size, -1, -1) + + projected_joints = perspective_projection(model_joints, rotation, camera_t, + focal_length, camera_center) + + # Weighted robust reprojection error + reprojection_error = gmof(projected_joints - joints_2d, sigma) + reprojection_loss = (joints_conf ** 2) * reprojection_error.sum(dim=-1) + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + total_loss = reprojection_loss.sum(dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss + + if output == 'sum': + return total_loss.sum() + elif output == 'reprojection': + return reprojection_loss + + +# --- get camera fitting loss ----- +def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center, + joints_2d, joints_conf, + focal_length=5000, depth_loss_weight=100): + """ + Loss function for camera optimization. 
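+    Matches the projected torso joints (hips and shoulders) to their 2D
+    detections, preferring the OpenPose detections when all four are valid,
+    while penalizing deviation of the optimized depth from the initial
+    estimate camera_t_est.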
+ """ + # Project model joints + batch_size = model_joints.shape[0] + rotation = torch.eye(3, device=model_joints.device).unsqueeze(0).expand(batch_size, -1, -1) + projected_joints = perspective_projection(model_joints, rotation, camera_t, + focal_length, camera_center) + + # get the indexed four + op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] + op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + reprojection_error_op = (joints_2d[:, op_joints_ind] - + projected_joints[:, op_joints_ind]) ** 2 + reprojection_error_gt = (joints_2d[:, gt_joints_ind] - + projected_joints[:, gt_joints_ind]) ** 2 + + # Check if for each example in the batch all 4 OpenPose detections are valid, otherwise use the GT detections + # OpenPose joints are more reliable for this task, so we prefer to use them if possible + is_valid = (joints_conf[:, op_joints_ind].min(dim=-1)[0][:, None, None] > 0).float() + reprojection_loss = (is_valid * reprojection_error_op + (1 - is_valid) * reprojection_error_gt).sum(dim=(1, 2)) + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight ** 2) * (camera_t[:, 2] - camera_t_est[:, 2]) ** 2 + + total_loss = reprojection_loss + depth_loss + return total_loss.sum() + + + + # #####--- body fitiing loss ----- +def body_fitting_loss_3d(body_pose, preserve_pose, + betas, model_joints, camera_translation, + j3d, pose_prior, + joints3d_conf, + sigma=100, pose_prior_weight=4.78*1.5, + shape_prior_weight=5.0, angle_prior_weight=15.2, + joint_loss_weight=500.0, + pose_preserve_weight=0.0, + use_collision=False, + model_vertices=None, model_faces=None, + search_tree=None, pen_distance=None, filter_faces=None, + collision_loss_weight=1000 + ): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + + #joint3d_loss = (joint_loss_weight ** 2) * gmof((model_joints + camera_translation) - j3d, sigma).sum(dim=-1) + + joint3d_error = gmof((model_joints + camera_translation) - j3d, sigma) + + joint3d_loss_part = (joints3d_conf ** 2) * joint3d_error.sum(dim=-1) + joint3d_loss = ((joint_loss_weight ** 2) * joint3d_loss_part).sum(dim=-1) + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + collision_loss = 0.0 + # Calculate the loss due to interpenetration + if use_collision: + triangles = torch.index_select( + model_vertices, 1, + model_faces).view(batch_size, -1, 3, 3) + + with torch.no_grad(): + collision_idxs = search_tree(triangles) + + # Remove unwanted collisions + if filter_faces is not None: + collision_idxs = filter_faces(collision_idxs) + + if collision_idxs.ge(0).sum().item() > 0: + collision_loss = torch.sum(collision_loss_weight * pen_distance(triangles, collision_idxs)) + + pose_preserve_loss = (pose_preserve_weight ** 2) * ((body_pose - preserve_pose) ** 2).sum(dim=-1) + + # print('joint3d_loss', joint3d_loss.shape) + # print('pose_prior_loss', pose_prior_loss.shape) + # print('angle_prior_loss', angle_prior_loss.shape) + # print('shape_prior_loss', shape_prior_loss.shape) + # print('collision_loss', collision_loss) + # print('pose_preserve_loss', pose_preserve_loss.shape) + + 
total_loss = joint3d_loss + pose_prior_loss + angle_prior_loss + shape_prior_loss + collision_loss + pose_preserve_loss + + return total_loss.sum() + + +# #####--- get camera fitting loss ----- +def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est, + j3d, joints_category="orig", depth_loss_weight=100.0): + """ + Loss function for camera optimization. + """ + model_joints = model_joints + camera_t + # # get the indexed four + # op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder'] + # op_joints_ind = [config.JOINT_MAP[joint] for joint in op_joints] + # + # j3d_error_loss = (j3d[:, op_joints_ind] - + # model_joints[:, op_joints_ind]) ** 2 + + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category=="orig": + select_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="AMASS": + select_joints_ind = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + j3d_error_loss = (j3d[:, select_joints_ind] - + model_joints[:, gt_joints_ind]) ** 2 + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight**2) * (camera_t - camera_t_est)**2 + + total_loss = j3d_error_loss + depth_loss + return total_loss.sum() diff --git a/VQ-Trans/visualize/joints2smpl/src/prior.py b/VQ-Trans/visualize/joints2smpl/src/prior.py new file mode 100644 index 0000000000000000000000000000000000000000..7f13806dd1f6607507b0c7e5ad463b3fb0026be8 --- /dev/null +++ b/VQ-Trans/visualize/joints2smpl/src/prior.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. 
+#
+# Contact: ps-license@tuebingen.mpg.de

+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import sys
+import os
+
+import time
+import pickle
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+
+DEFAULT_DTYPE = torch.float32
+
+
+def create_prior(prior_type, **kwargs):
+    if prior_type == 'gmm':
+        prior = MaxMixturePrior(**kwargs)
+    elif prior_type == 'l2':
+        return L2Prior(**kwargs)
+    elif prior_type == 'angle':
+        return SMPLifyAnglePrior(**kwargs)
+    elif prior_type == 'none' or prior_type is None:
+        # Don't use any pose prior
+        def no_prior(*args, **kwargs):
+            return 0.0
+        prior = no_prior
+    else:
+        raise ValueError('Prior {}'.format(prior_type) + ' is not implemented')
+    return prior
+
+
+class SMPLifyAnglePrior(nn.Module):
+    def __init__(self, dtype=torch.float32, **kwargs):
+        super(SMPLifyAnglePrior, self).__init__()
+
+        # Indices for the rotation angles of
+        # 55: left elbow, 90deg bend at -np.pi/2
+        # 58: right elbow, 90deg bend at np.pi/2
+        # 12: left knee, 90deg bend at np.pi/2
+        # 15: right knee, 90deg bend at np.pi/2
+        angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64)
+        angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long)
+        self.register_buffer('angle_prior_idxs', angle_prior_idxs)
+
+        angle_prior_signs = np.array([1, -1, -1, -1],
+                                     dtype=np.float32 if dtype == torch.float32
+                                     else np.float64)
+        angle_prior_signs = torch.tensor(angle_prior_signs,
+                                         dtype=dtype)
+        self.register_buffer('angle_prior_signs', angle_prior_signs)
+
+    def forward(self, pose, with_global_pose=False):
+        ''' Returns the angle prior loss for the given pose
+
+        Args:
+            pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle
+            representation of the rotations of the joints of the SMPL model.
+        Kwargs:
+            with_global_pose: Whether the pose vector also contains the global
+            orientation of the SMPL model. If not then the indices must be
+            corrected.
+        Returns:
+            A size (B) tensor containing the angle prior loss for each element
+            in the batch.
+ ''' + angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3 + return torch.exp(pose[:, angle_prior_idxs] * + self.angle_prior_signs).pow(2) + + +class L2Prior(nn.Module): + def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs): + super(L2Prior, self).__init__() + + def forward(self, module_input, *args): + return torch.sum(module_input.pow(2)) + + +class MaxMixturePrior(nn.Module): + + def __init__(self, prior_folder='prior', + num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16, + use_merged=True, + **kwargs): + super(MaxMixturePrior, self).__init__() + + if dtype == DEFAULT_DTYPE: + np_dtype = np.float32 + elif dtype == torch.float64: + np_dtype = np.float64 + else: + print('Unknown float type {}, exiting!'.format(dtype)) + sys.exit(-1) + + self.num_gaussians = num_gaussians + self.epsilon = epsilon + self.use_merged = use_merged + gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians) + + full_gmm_fn = os.path.join(prior_folder, gmm_fn) + if not os.path.exists(full_gmm_fn): + print('The path to the mixture prior "{}"'.format(full_gmm_fn) + + ' does not exist, exiting!') + sys.exit(-1) + + with open(full_gmm_fn, 'rb') as f: + gmm = pickle.load(f, encoding='latin1') + + if type(gmm) == dict: + means = gmm['means'].astype(np_dtype) + covs = gmm['covars'].astype(np_dtype) + weights = gmm['weights'].astype(np_dtype) + elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)): + means = gmm.means_.astype(np_dtype) + covs = gmm.covars_.astype(np_dtype) + weights = gmm.weights_.astype(np_dtype) + else: + print('Unknown type for the prior: {}, exiting!'.format(type(gmm))) + sys.exit(-1) + + self.register_buffer('means', torch.tensor(means, dtype=dtype)) + + self.register_buffer('covs', torch.tensor(covs, dtype=dtype)) + + precisions = [np.linalg.inv(cov) for cov in covs] + precisions = np.stack(precisions).astype(np_dtype) + + self.register_buffer('precisions', + torch.tensor(precisions, dtype=dtype)) + + # The constant term: + sqrdets = np.array([(np.sqrt(np.linalg.det(c))) + for c in gmm['covars']]) + const = (2 * np.pi)**(69 / 2.) 
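+        # 69 = 23 SMPL body joints x 3 axis-angle values, i.e. the
+        # dimensionality D of the body pose vector, so const is the Gaussian
+        # normalization factor (2*pi)^(D/2).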
+ + nll_weights = np.asarray(gmm['weights'] / (const * + (sqrdets / sqrdets.min()))) + nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) + self.register_buffer('nll_weights', nll_weights) + + weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) + self.register_buffer('weights', weights) + + self.register_buffer('pi_term', + torch.log(torch.tensor(2 * np.pi, dtype=dtype))) + + cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) + for cov in covs] + self.register_buffer('cov_dets', + torch.tensor(cov_dets, dtype=dtype)) + + # The dimensionality of the random variable + self.random_var_dim = self.means.shape[1] + + def get_mean(self): + ''' Returns the mean of the mixture ''' + mean_pose = torch.matmul(self.weights, self.means) + return mean_pose + + def merged_log_likelihood(self, pose, betas): + diff_from_mean = pose.unsqueeze(dim=1) - self.means + + prec_diff_prod = torch.einsum('mij,bmj->bmi', + [self.precisions, diff_from_mean]) + diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) + + curr_loglikelihood = 0.5 * diff_prec_quadratic - \ + torch.log(self.nll_weights) + # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + + # self.random_var_dim * self.pi_term + + # diff_prec_quadratic + # ) - torch.log(self.weights) + + min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) + return min_likelihood + + def log_likelihood(self, pose, betas, *args, **kwargs): + ''' Create graph operation for negative log-likelihood calculation + ''' + likelihoods = [] + + for idx in range(self.num_gaussians): + mean = self.means[idx] + prec = self.precisions[idx] + cov = self.covs[idx] + diff_from_mean = pose - mean + + curr_loglikelihood = torch.einsum('bj,ji->bi', + [diff_from_mean, prec]) + curr_loglikelihood = torch.einsum('bi,bi->b', + [curr_loglikelihood, + diff_from_mean]) + cov_term = torch.log(torch.det(cov) + self.epsilon) + curr_loglikelihood += 0.5 * (cov_term + + self.random_var_dim * + self.pi_term) + likelihoods.append(curr_loglikelihood) + + log_likelihoods = torch.stack(likelihoods, dim=1) + min_idx = torch.argmin(log_likelihoods, dim=1) + weight_component = self.nll_weights[:, min_idx] + weight_component = -torch.log(weight_component) + + return weight_component + log_likelihoods[:, min_idx] + + def forward(self, pose, betas): + if self.use_merged: + return self.merged_log_likelihood(pose, betas) + else: + return self.log_likelihood(pose, betas) \ No newline at end of file diff --git a/VQ-Trans/visualize/joints2smpl/src/smplify.py b/VQ-Trans/visualize/joints2smpl/src/smplify.py new file mode 100644 index 0000000000000000000000000000000000000000..580efef98dfdcf6e7486b7f5c5436820edfb6c4b --- /dev/null +++ b/VQ-Trans/visualize/joints2smpl/src/smplify.py @@ -0,0 +1,279 @@ +import torch +import os, sys +import pickle +import smplx +import numpy as np + +sys.path.append(os.path.dirname(__file__)) +from customloss import (camera_fitting_loss, + body_fitting_loss, + camera_fitting_loss_3d, + body_fitting_loss_3d, + ) +from prior import MaxMixturePrior +from visualize.joints2smpl.src import config + + + +@torch.no_grad() +def guess_init_3d(model_joints, + j3d, + joints_category="orig"): + """Initialize the camera translation via triangle similarity, by using the torso joints . 
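+    Concretely, this is the mean offset between the four target torso joints
+    and the corresponding model joints, computed per batch element.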
+ :param model_joints: SMPL model with pre joints + :param j3d: 25x3 array of Kinect Joints + :returns: 3D vector corresponding to the estimated camera translation + """ + # get the indexed four + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category=="orig": + joints_ind_category = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category=="AMASS": + joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1) + init_t = sum_init_t / 4.0 + return init_t + + +# SMPLIfy 3D +class SMPLify3D(): + """Implementation of SMPLify, use 3D joints.""" + + def __init__(self, + smplxmodel, + step_size=1e-2, + batch_size=1, + num_iters=100, + use_collision=False, + use_lbfgs=True, + joints_category="orig", + device=torch.device('cuda:0'), + ): + + # Store options + self.batch_size = batch_size + self.device = device + self.step_size = step_size + + self.num_iters = num_iters + # --- choose optimizer + self.use_lbfgs = use_lbfgs + # GMM pose prior + self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR, + num_gaussians=8, + dtype=torch.float32).to(device) + # collision part + self.use_collision = use_collision + if self.use_collision: + self.part_segm_fn = config.Part_Seg_DIR + + # reLoad SMPL-X model + self.smpl = smplxmodel + + self.model_faces = smplxmodel.faces_tensor.view(-1) + + # select joint joint_category + self.joints_category = joints_category + + if joints_category=="orig": + self.smpl_index = config.full_smpl_idx + self.corr_index = config.full_smpl_idx + elif joints_category=="AMASS": + self.smpl_index = config.amass_smpl_idx + self.corr_index = config.amass_idx + else: + self.smpl_index = None + self.corr_index = None + print("NO SUCH JOINTS CATEGORY!") + + # ---- get the man function here ------ + def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0): + """Perform body fitting. 
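+        Optimization runs in two stages (Step 1 and Step 2 below): first the
+        camera translation and global orientation, then the full body pose,
+        with betas left free only for the first frame of a sequence
+        (seq_ind == 0).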
+ Input: + init_pose: SMPL pose estimate + init_betas: SMPL betas estimate + init_cam_t: Camera translation estimate + j3d: joints 3d aka keypoints + conf_3d: confidence for 3d joints + seq_ind: index of the sequence + Returns: + vertices: Vertices of optimized shape + joints: 3D joints of optimized shape + pose: SMPL pose parameters of optimized shape + betas: SMPL beta parameters of optimized shape + camera_translation: Camera translation + """ + + # # # add the mesh inter-section to avoid + search_tree = None + pen_distance = None + filter_faces = None + + if self.use_collision: + from mesh_intersection.bvh_search_tree import BVH + import mesh_intersection.loss as collisions_loss + from mesh_intersection.filter_faces import FilterFaces + + search_tree = BVH(max_collisions=8) + + pen_distance = collisions_loss.DistanceFieldPenetrationLoss( + sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True) + + if self.part_segm_fn: + # Read the part segmentation + part_segm_fn = os.path.expandvars(self.part_segm_fn) + with open(part_segm_fn, 'rb') as faces_parents_file: + face_segm_data = pickle.load(faces_parents_file, encoding='latin1') + faces_segm = face_segm_data['segm'] + faces_parents = face_segm_data['parents'] + # Create the module used to filter invalid collision pairs + filter_faces = FilterFaces( + faces_segm=faces_segm, faces_parents=faces_parents, + ign_part_pairs=None).to(device=self.device) + + + # Split SMPL pose to body pose and global orientation + body_pose = init_pose[:, 3:].detach().clone() + global_orient = init_pose[:, :3].detach().clone() + betas = init_betas.detach().clone() + + # use guess 3d to get the initial + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).unsqueeze(1).detach() + camera_translation = init_cam_t.clone() + + preserve_pose = init_pose[:, 3:].detach().clone() + # -------------Step 1: Optimize camera translation and body orientation-------- + # Optimize only camera translation and body orientation + body_pose.requires_grad = False + betas.requires_grad = False + global_orient.requires_grad = True + camera_translation.requires_grad = True + + camera_opt_params = [global_orient, camera_translation] + + if self.use_lbfgs: + camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + for i in range(10): + def closure(): + camera_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + # print('model_joints', model_joints.shape) + # print('camera_translation', camera_translation.shape) + # print('init_cam_t', init_cam_t.shape) + # print('j3d', j3d.shape) + loss = camera_fitting_loss_3d(model_joints, camera_translation, + init_cam_t, j3d, self.joints_category) + loss.backward() + return loss + + camera_optimizer.step(closure) + else: + camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(20): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation, + init_cam_t, j3d[:, self.corr_index], self.joints_category) + camera_optimizer.zero_grad() + loss.backward() + camera_optimizer.step() + + # Fix camera translation after 
optimizing camera + # --------Step 2: Optimize body joints -------------------------- + # Optimize only the body pose and global orientation of the body + body_pose.requires_grad = True + global_orient.requires_grad = True + camera_translation.requires_grad = True + + # --- if we use the sequence, fix the shape + if seq_ind == 0: + betas.requires_grad = True + body_opt_params = [body_pose, betas, global_orient, camera_translation] + else: + betas.requires_grad = False + body_opt_params = [body_pose, global_orient, camera_translation] + + if self.use_lbfgs: + body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + for i in range(self.num_iters): + def closure(): + body_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + pose_preserve_weight=5.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + loss.backward() + return loss + + body_optimizer.step(closure) + else: + body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(self.num_iters): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + body_optimizer.zero_grad() + loss.backward() + body_optimizer.step() + + # Get final loss value + with torch.no_grad(): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas, return_full_pose=True) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) + + vertices = smpl_output.vertices.detach() + joints = smpl_output.joints.detach() + pose = torch.cat([global_orient, body_pose], dim=-1).detach() + betas = betas.detach() + + return vertices, joints, pose, betas, camera_translation, final_loss diff --git a/VQ-Trans/visualize/render_mesh.py b/VQ-Trans/visualize/render_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..d44d04f551ccb4f1ffc9efb4cb1a44c407ede836 --- /dev/null +++ b/VQ-Trans/visualize/render_mesh.py @@ -0,0 +1,33 @@ +import argparse +import os +from visualize import vis_utils +import shutil +from tqdm import tqdm + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + 
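+    # Example invocation (hypothetical output path; the file name must follow
+    # the sample##_rep##.mp4 pattern parsed below), run from the repo root:
+    #   python -m visualize.render_mesh --input_path save/exp/sample00_rep00.mp4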
parser.add_argument("--input_path", type=str, required=True, help='stick figure mp4 file to be rendered.') + parser.add_argument("--cuda", type=bool, default=True, help='') + parser.add_argument("--device", type=int, default=0, help='') + params = parser.parse_args() + + assert params.input_path.endswith('.mp4') + parsed_name = os.path.basename(params.input_path).replace('.mp4', '').replace('sample', '').replace('rep', '') + sample_i, rep_i = [int(e) for e in parsed_name.split('_')] + npy_path = os.path.join(os.path.dirname(params.input_path), 'results.npy') + out_npy_path = params.input_path.replace('.mp4', '_smpl_params.npy') + assert os.path.exists(npy_path) + results_dir = params.input_path.replace('.mp4', '_obj') + if os.path.exists(results_dir): + shutil.rmtree(results_dir) + os.makedirs(results_dir) + + npy2obj = vis_utils.npy2obj(npy_path, sample_i, rep_i, + device=params.device, cuda=params.cuda) + + print('Saving obj files to [{}]'.format(os.path.abspath(results_dir))) + for frame_i in tqdm(range(npy2obj.real_num_frames)): + npy2obj.save_obj(os.path.join(results_dir, 'frame{:03d}.obj'.format(frame_i)), frame_i) + + print('Saving SMPL params to [{}]'.format(os.path.abspath(out_npy_path))) + npy2obj.save_npy(out_npy_path) diff --git a/VQ-Trans/visualize/simplify_loc2rot.py b/VQ-Trans/visualize/simplify_loc2rot.py new file mode 100644 index 0000000000000000000000000000000000000000..5d3d4411310876033cb50d998ad64557a9c4b0c1 --- /dev/null +++ b/VQ-Trans/visualize/simplify_loc2rot.py @@ -0,0 +1,131 @@ +import numpy as np +import os +import torch +from visualize.joints2smpl.src import config +import smplx +import h5py +from visualize.joints2smpl.src.smplify import SMPLify3D +from tqdm import tqdm +import utils.rotation_conversions as geometry +import argparse + + +class joints2smpl: + + def __init__(self, num_frames, device_id, cuda=True): + self.device = torch.device("cuda:" + str(device_id) if cuda else "cpu") + # self.device = torch.device("cpu") + self.batch_size = num_frames + self.num_joints = 22 # for HumanML3D + self.joint_category = "AMASS" + self.num_smplify_iters = 150 + self.fix_foot = False + print(config.SMPL_MODEL_DIR) + smplmodel = smplx.create(config.SMPL_MODEL_DIR, + model_type="smpl", gender="neutral", ext="pkl", + batch_size=self.batch_size).to(self.device) + + # ## --- load the mean pose as original ---- + smpl_mean_file = config.SMPL_MEAN_FILE + + file = h5py.File(smpl_mean_file, 'r') + self.init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) + self.init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) + self.cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(self.device) + # + + # # #-------------initialize SMPLify + self.smplify = SMPLify3D(smplxmodel=smplmodel, + batch_size=self.batch_size, + joints_category=self.joint_category, + num_iters=self.num_smplify_iters, + device=self.device) + + + def npy2smpl(self, npy_path): + out_path = npy_path.replace('.npy', '_rot.npy') + motions = np.load(npy_path, allow_pickle=True)[None][0] + # print_batch('', motions) + n_samples = motions['motion'].shape[0] + all_thetas = [] + for sample_i in tqdm(range(n_samples)): + thetas, _ = self.joint2smpl(motions['motion'][sample_i].transpose(2, 0, 1)) # [nframes, njoints, 3] + all_thetas.append(thetas.cpu().numpy()) + motions['motion'] = np.concatenate(all_thetas, axis=0) + print('motions', motions['motion'].shape) + + print(f'Saving [{out_path}]') 
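+        # The saved dict mirrors the layout of the input results.npy, with
+        # 'motion' replaced by the fitted SMPL rotation parameters.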
+ np.save(out_path, motions) + + + def joint2smpl(self, input_joints, init_params=None): + _smplify = self.smplify # if init_params is None else self.smplify_fast + pred_pose = torch.zeros(self.batch_size, 72).to(self.device) + pred_betas = torch.zeros(self.batch_size, 10).to(self.device) + pred_cam_t = torch.zeros(self.batch_size, 3).to(self.device) + keypoints_3d = torch.zeros(self.batch_size, self.num_joints, 3).to(self.device) + + # run the whole sequence as one batch + num_seqs = input_joints.shape[0] + + keypoints_3d = torch.Tensor(input_joints).to(self.device).float() + + if init_params is None: + pred_betas = self.init_mean_shape + pred_pose = self.init_mean_pose + pred_cam_t = self.cam_trans_zero + else: + pred_betas = init_params['betas'] + pred_pose = init_params['pose'] + pred_cam_t = init_params['cam'] + + if self.joint_category == "AMASS": + confidence_input = torch.ones(self.num_joints) + # optionally up-weight the feet and ankles + if self.fix_foot: + confidence_input[7] = 1.5 + confidence_input[8] = 1.5 + confidence_input[10] = 1.5 + confidence_input[11] = 1.5 + else: + print("Unknown joint category: " + self.joint_category) + + new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ + new_opt_cam_t, new_opt_joint_loss = _smplify( + pred_pose.detach(), + pred_betas.detach(), + pred_cam_t.detach(), + keypoints_3d, + conf_3d=confidence_input.to(self.device), + # seq_ind=idx + ) + + thetas = new_opt_pose.reshape(self.batch_size, 24, 3) + thetas = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(thetas)) # [bs, 24, 6] + root_loc = keypoints_3d[:, 0].clone() # [bs, 3]; .clone() avoids the torch.tensor(tensor) copy-construct warning + root_loc = torch.cat([root_loc, torch.zeros_like(root_loc)], dim=-1).unsqueeze(1) # [bs, 1, 6] + thetas = torch.cat([thetas, root_loc], dim=1).unsqueeze(0).permute(0, 2, 3, 1) # [1, 25, 6, nframes] + + return thetas.clone().detach(), {'pose': new_opt_joints[0, :24].flatten().clone().detach(), 'betas': new_opt_betas.clone().detach(), 'cam': new_opt_cam_t.clone().detach()} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--input_path", type=str, required=True, help='.npy motion file or a directory of .npy files') + parser.add_argument("--cuda", type=lambda s: s.lower() in ('true', '1'), default=True, help='use CUDA (pass "false" to run on CPU)') + parser.add_argument("--device", type=int, default=0, help='GPU device id') + params = parser.parse_args() + + # joints2smpl needs the sequence length up front; 196 (the HumanML3D maximum) is assumed here + simplify = joints2smpl(num_frames=196, device_id=params.device, cuda=params.cuda) + + if os.path.isfile(params.input_path) and params.input_path.endswith('.npy'): + simplify.npy2smpl(params.input_path) + elif os.path.isdir(params.input_path): + files = [os.path.join(params.input_path, f) for f in os.listdir(params.input_path) if f.endswith('.npy')] + for f in files: + simplify.npy2smpl(f) \ No newline at end of file diff --git a/VQ-Trans/visualize/vis_utils.py b/VQ-Trans/visualize/vis_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05728b38e3d6be4bfd83324907e3fa7a3f358071 --- /dev/null +++ b/VQ-Trans/visualize/vis_utils.py @@ -0,0 +1,66 @@ +from model.rotation2xyz import Rotation2xyz +import numpy as np +from trimesh import Trimesh +import os +import torch +from visualize.simplify_loc2rot import joints2smpl + +class npy2obj: + def __init__(self, npy_path, sample_idx, rep_idx, device=0, cuda=True): + self.npy_path = npy_path + self.motions = np.load(self.npy_path, allow_pickle=True) + if self.npy_path.endswith('.npz'): + self.motions = self.motions['arr_0'] + self.motions = self.motions[None][0] + self.rot2xyz =
Rotation2xyz(device='cpu') + self.faces = self.rot2xyz.smpl_model.faces + self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape + self.opt_cache = {} + self.sample_idx = sample_idx + self.total_num_samples = self.motions['num_samples'] + self.rep_idx = rep_idx + self.absl_idx = self.rep_idx*self.total_num_samples + self.sample_idx + self.num_frames = self.motions['motion'][self.absl_idx].shape[-1] + self.j2s = joints2smpl(num_frames=self.num_frames, device_id=device, cuda=cuda) + + if self.nfeats == 3: + print(f'Running SMPLify For sample [{sample_idx}], repetition [{rep_idx}], it may take a few minutes.') + motion_tensor, opt_dict = self.j2s.joint2smpl(self.motions['motion'][self.absl_idx].transpose(2, 0, 1)) # [nframes, njoints, 3] + self.motions['motion'] = motion_tensor.cpu().numpy() + elif self.nfeats == 6: + self.motions['motion'] = self.motions['motion'][[self.absl_idx]] + self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape + self.real_num_frames = self.motions['lengths'][self.absl_idx] + + self.vertices = self.rot2xyz(torch.tensor(self.motions['motion']), mask=None, + pose_rep='rot6d', translation=True, glob=True, + jointstype='vertices', + # jointstype='smpl', # for joint locations + vertstrans=True) + self.root_loc = self.motions['motion'][:, -1, :3, :].reshape(1, 1, 3, -1) + self.vertices += self.root_loc + + def get_vertices(self, sample_i, frame_i): + return self.vertices[sample_i, :, :, frame_i].squeeze().tolist() + + def get_trimesh(self, sample_i, frame_i): + return Trimesh(vertices=self.get_vertices(sample_i, frame_i), + faces=self.faces) + + def save_obj(self, save_path, frame_i): + mesh = self.get_trimesh(0, frame_i) + with open(save_path, 'w') as fw: + mesh.export(fw, 'obj') + return save_path + + def save_npy(self, save_path): + data_dict = { + 'motion': self.motions['motion'][0, :, :, :self.real_num_frames], + 'thetas': self.motions['motion'][0, :-1, :, :self.real_num_frames], + 'root_translation': self.motions['motion'][0, -1, :3, :self.real_num_frames], + 'faces': self.faces, + 'vertices': self.vertices[0, :, :, :self.real_num_frames], + 'text': self.motions['text'][0], + 'length': self.real_num_frames, + } + np.save(save_path, data_dict) diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..58c1cc635a5a4e8e6e00680a2ab5413668bdbe20 --- /dev/null +++ b/app.py @@ -0,0 +1,319 @@ +import sys +import os +import OpenGL.GL as gl +os.environ["PYOPENGL_PLATFORM"] = "egl" +os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1" +os.system('pip install /home/user/app/pyrender') + +sys.argv = ['VQ-Trans/GPT_eval_multi.py'] +os.chdir('VQ-Trans') + +sys.path.append('/home/user/app/VQ-Trans') +sys.path.append('/home/user/app/pyrender') + +import options.option_transformer as option_trans +from huggingface_hub import snapshot_download +model_path = snapshot_download(repo_id="vumichien/T2M-GPT") + +args = option_trans.get_args_parser() + +args.dataname = 't2m' +args.resume_pth = f'{model_path}/VQVAE/net_last.pth' +args.resume_trans = f'{model_path}/VQTransformer_corruption05/net_best_fid.pth' +args.down_t = 2 +args.depth = 3 +args.block_size = 51 + +import clip +import torch +import numpy as np +import models.vqvae as vqvae +import models.t2m_trans as trans +from utils.motion_process import recover_from_ric +import visualization.plot_3d_global as plot_3d +from models.rotation2xyz import Rotation2xyz +import numpy as np +from trimesh import Trimesh +import gc + +import torch +from 
visualize.simplify_loc2rot import joints2smpl +import pyrender +# import matplotlib.pyplot as plt + +import io +import imageio +from shapely import geometry +import trimesh +from pyrender.constants import RenderFlags +import math +# import ffmpeg +# from PIL import Image +import hashlib +import gradio as gr +import moviepy.editor as mp + +## load clip model and datasets +is_cuda = torch.cuda.is_available() +device = torch.device("cuda" if is_cuda else "cpu") +print(device) +clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False, download_root='./') # Must set jit=False for training + +if is_cuda: + clip.model.convert_weights(clip_model) + +clip_model.eval() +for p in clip_model.parameters(): + p.requires_grad = False + +net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers + args.nb_code, + args.code_dim, + args.output_emb_width, + args.down_t, + args.stride_t, + args.width, + args.depth, + args.dilation_growth_rate) + + +trans_encoder = trans.Text2Motion_Transformer(num_vq=args.nb_code, + embed_dim=1024, + clip_dim=args.clip_dim, + block_size=args.block_size, + num_layers=9, + n_head=16, + drop_out_rate=args.drop_out_rate, + fc_rate=args.ff_rate) + + +print('loading checkpoint from {}'.format(args.resume_pth)) +ckpt = torch.load(args.resume_pth, map_location='cpu') +net.load_state_dict(ckpt['net'], strict=True) +net.eval() + +print('loading transformer checkpoint from {}'.format(args.resume_trans)) +ckpt = torch.load(args.resume_trans, map_location='cpu') +trans_encoder.load_state_dict(ckpt['trans'], strict=True) +trans_encoder.eval() + +mean = torch.from_numpy(np.load(f'{model_path}/meta/mean.npy')) +std = torch.from_numpy(np.load(f'{model_path}/meta/std.npy')) + +if is_cuda: + net.cuda() + trans_encoder.cuda() + mean = mean.cuda() + std = std.cuda() + +def render(motions, device_id=0, name='test_vis'): + frames, njoints, nfeats = motions.shape + MINS = motions.min(axis=0).min(axis=0) + MAXS = motions.max(axis=0).max(axis=0) + + height_offset = MINS[1] + motions[:, :, 1] -= height_offset + trajec = motions[:, 0, [0, 2]] + is_cuda = torch.cuda.is_available() + # device = torch.device("cuda" if is_cuda else "cpu") + j2s = joints2smpl(num_frames=frames, device_id=0, cuda=is_cuda) + rot2xyz = Rotation2xyz(device=device) + faces = rot2xyz.smpl_model.faces + + if not os.path.exists(f'output/{name}_pred.pt'): + print(f'Running SMPLify, it may take a few minutes.') + motion_tensor, opt_dict = j2s.joint2smpl(motions) # [nframes, njoints, 3] + + vertices = rot2xyz(torch.tensor(motion_tensor).clone(), mask=None, + pose_rep='rot6d', translation=True, glob=True, + jointstype='vertices', + vertstrans=True) + vertices = vertices.detach().cpu() + torch.save(vertices, f'output/{name}_pred.pt') + else: + vertices = torch.load(f'output/{name}_pred.pt') + frames = vertices.shape[3] # shape: 1, nb_frames, 3, nb_joints + print(vertices.shape) + MINS = torch.min(torch.min(vertices[0], axis=0)[0], axis=1)[0] + MAXS = torch.max(torch.max(vertices[0], axis=0)[0], axis=1)[0] + + out_list = [] + + minx = MINS[0] - 0.5 + maxx = MAXS[0] + 0.5 + minz = MINS[2] - 0.5 + maxz = MAXS[2] + 0.5 + polygon = geometry.Polygon([[minx, minz], [minx, maxz], [maxx, maxz], [maxx, minz]]) + polygon_mesh = trimesh.creation.extrude_polygon(polygon, 1e-5) + + vid = [] + for i in range(frames): + if i % 10 == 0: + print(i) + + mesh = Trimesh(vertices=vertices[0, :, :, i].squeeze().tolist(), faces=faces) + + base_color = (0.11, 0.53, 0.8, 0.5) + ## OPAQUE rendering without 
alpha + ## BLEND rendering consider alpha + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.7, + alphaMode='OPAQUE', + baseColorFactor=base_color + ) + + + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + polygon_mesh.visual.face_colors = [0, 0, 0, 0.21] + polygon_render = pyrender.Mesh.from_trimesh(polygon_mesh, smooth=False) + + bg_color = [1, 1, 1, 0.8] + scene = pyrender.Scene(bg_color=bg_color, ambient_light=(0.4, 0.4, 0.4)) + + sx, sy, tx, ty = [0.75, 0.75, 0, 0.10] + + camera = pyrender.PerspectiveCamera(yfov=(np.pi / 3.0)) + + light = pyrender.DirectionalLight(color=[1,1,1], intensity=300) + + scene.add(mesh) + + c = np.pi / 2 + + scene.add(polygon_render, pose=np.array([[ 1, 0, 0, 0], + + [ 0, np.cos(c), -np.sin(c), MINS[1].cpu().numpy()], + + [ 0, np.sin(c), np.cos(c), 0], + + [ 0, 0, 0, 1]])) + + light_pose = np.eye(4) + light_pose[:3, 3] = [0, -1, 1] + scene.add(light, pose=light_pose.copy()) + + light_pose[:3, 3] = [0, 1, 1] + scene.add(light, pose=light_pose.copy()) + + light_pose[:3, 3] = [1, 1, 2] + scene.add(light, pose=light_pose.copy()) + + + c = -np.pi / 6 + + scene.add(camera, pose=[[ 1, 0, 0, (minx+maxx).cpu().numpy()/2], + + [ 0, np.cos(c), -np.sin(c), 1.5], + + [ 0, np.sin(c), np.cos(c), max(4, minz.cpu().numpy()+(1.5-MINS[1].cpu().numpy())*2, (maxx-minx).cpu().numpy())], + + [ 0, 0, 0, 1] + ]) + + # render scene + r = pyrender.OffscreenRenderer(960, 960) + + color, _ = r.render(scene, flags=RenderFlags.RGBA) + # Image.fromarray(color).save(outdir+'/'+name+'_'+str(i)+'.png') + + vid.append(color) + + r.delete() + + out = np.stack(vid, axis=0) + imageio.mimwrite(f'output/results.gif', out, fps=20) + out_video = mp.VideoFileClip(f'output/results.gif') + out_video.write_videofile("output/results.mp4") + del out, vertices + return f'output/results.mp4' + +def predict(clip_text, method='fast'): + gc.collect() + if torch.cuda.is_available(): + text = clip.tokenize([clip_text], truncate=True).cuda() + else: + text = clip.tokenize([clip_text], truncate=True) + feat_clip_text = clip_model.encode_text(text).float() + index_motion = trans_encoder.sample(feat_clip_text[0:1], False) + pred_pose = net.forward_decoder(index_motion) + pred_xyz = recover_from_ric((pred_pose*std+mean).float(), 22) + output_name = hashlib.md5(clip_text.encode()).hexdigest() + if method == 'fast': + xyz = pred_xyz.reshape(1, -1, 22, 3) + pose_vis = plot_3d.draw_to_batch(xyz.detach().cpu().numpy(), title_batch=None, outname=[f'output/results.gif']) + out_video = mp.VideoFileClip("output/results.gif") + out_video.write_videofile("output/results.mp4") + return f'output/results.mp4' + elif method == 'slow': + output_path = render(pred_xyz.detach().cpu().numpy().squeeze(axis=0), device_id=0, name=output_name) + return output_path + + +# ---- Gradio Layout ----- +text_prompt = gr.Textbox(label="Text prompt", lines=1, interactive=True) +video_out = gr.Video(label="Motion", mirror_webcam=False, interactive=False) +demo = gr.Blocks() +demo.encrypt = False + +with demo: + gr.Markdown(''' +
# Generating Human Motion from Textual Descriptions (T2M-GPT) + This space uses T2M-GPT models based on Vector Quantised-Variational AutoEncoder (VQ-VAE) and Generative Pre-trained Transformer (GPT) for human motion generation from textual descriptions 🤗 +
+ ''') + with gr.Row(): + with gr.Column(): + gr.Markdown(''' +
+ Demo Slow +
a man starts off in an upright position with both arms extended out by his sides, he then brings his arms down to his body and claps his hands together. after this he walks down and to the left where he proceeds to sit on a seat
+
+ ''') + with gr.Column(): + gr.Markdown(''' +
+ Demo Slow 2 +
a person puts their hands together, leans forwards slightly, then swings the arms from right to left
+
+ ''') + with gr.Column(): + gr.Markdown(''' +
+ Demo Slow 3 +
a man is practicing the waltz with a partner +
+
+ ''') + with gr.Row(): + with gr.Column(): + gr.Markdown(''' + ### Generate human motion with **T2M-GPT** + ##### Step 1. Give a text prompt describing a human motion + ##### Step 2. Choose a method to render the output (Fast: skeleton sketch; Slow: SMPL mesh, needs a GPU and takes around 2 minutes) + ##### Step 3. Generate the output and enjoy + ''') + with gr.Column(): + with gr.Row(): + text_prompt.render() + method = gr.Dropdown(["slow", "fast"], label="Method", value="slow") + with gr.Row(): + generate_btn = gr.Button("Generate") + generate_btn.click(predict, [text_prompt, method], [video_out], api_name="generate") + with gr.Row(): + video_out.render() + with gr.Row(): + gr.Markdown(''' + ### You can try the following examples: + ''') + examples = gr.Examples(examples= + [ "a person jogs in place, slowly at first, then increases speed. they then back up and squat down.", + "a man steps forward and does a handstand", + "a man rises from the ground, walks in a circle and sits back down on the ground"], + label="Examples", inputs=[text_prompt]) + +demo.launch(debug=True) diff --git a/pyrender/.coveragerc b/pyrender/.coveragerc new file mode 100644 index 0000000000000000000000000000000000000000..ee31cded3509cbd991a33dd27e2525b93a1a6558 --- /dev/null +++ b/pyrender/.coveragerc @@ -0,0 +1,5 @@ +[report] +exclude_lines = + def __repr__ + def __str__ + @abc.abstractmethod diff --git a/pyrender/.flake8 b/pyrender/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..fec4bcfc3ba774b53a866d839ea15bae6ebdb4a6 --- /dev/null +++ b/pyrender/.flake8 @@ -0,0 +1,8 @@ +[flake8] +ignore = E231,W504,F405,F403 +max-line-length = 79 +select = B,C,E,F,W,T4,B9 +exclude = + docs/source/conf.py, + __pycache__, + examples/* diff --git a/pyrender/.gitignore b/pyrender/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ae59dec631f71a23d4255aaf9c0274a699f4ba25 --- /dev/null +++ b/pyrender/.gitignore @@ -0,0 +1,106 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +docs/**/generated/** + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/pyrender/.pre-commit-config.yaml b/pyrender/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1817eb39bf409aff80c7d2cc79a3bc3856c70dbd --- /dev/null +++ b/pyrender/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://gitlab.com/pycqa/flake8 + rev: 3.7.1 + hooks: + - id: flake8 + exclude: ^setup.py diff --git a/pyrender/.travis.yml b/pyrender/.travis.yml new file mode 100644 index 0000000000000000000000000000000000000000..1ad289ae1513eaf8fda74f8d5ab7840be3ef56cb --- /dev/null +++ b/pyrender/.travis.yml @@ -0,0 +1,43 @@ +language: python +sudo: required +dist: xenial + +python: +- '3.6' +- '3.7' + +before_install: + # Pre-install osmesa + - sudo apt update + - sudo wget https://github.com/mmatl/travis_debs/raw/master/xenial/mesa_18.3.3-0.deb + - sudo dpkg -i ./mesa_18.3.3-0.deb || true + - sudo apt install -f + - git clone https://github.com/mmatl/pyopengl.git + - cd pyopengl + - pip install . + - cd .. + +install: + - pip install . 
+ # - pip install -q pytest pytest-cov coveralls + - pip install pytest pytest-cov coveralls + - pip install ./pyopengl + +script: + - PYOPENGL_PLATFORM=osmesa pytest --cov=pyrender tests + +after_success: +- coveralls || true + +deploy: + provider: pypi + skip_existing: true + user: mmatl + on: + tags: true + branch: master + password: + secure: O4WWMbTYb2eVYIO4mMOVa6/xyhX7mPvJpd96cxfNvJdyuqho8VapOhzqsI5kahMB1hFjWWr61yR4+Ru5hoDYf3XA6BQVk8eCY9+0H7qRfvoxex71lahKAqfHLMoE1xNdiVTgl+QN9hYjOnopLod24rx8I8eXfpHu/mfCpuTYGyLlNcDP5St3bXpXLPB5wg8Jo1YRRv6W/7fKoXyuWjewk9cJAS0KrEgnDnSkdwm6Pb+80B2tcbgdGvpGaByw5frndwKiMUMgVUownepDU5POQq2p29wwn9lCvRucULxjEgO+63jdbZRj5fNutLarFa2nISfYnrd72LOyDfbJubwAzzAIsy2JbFORyeHvCgloiuE9oE7a9oOQt/1QHBoIV0seiawMWn55Yp70wQ7HlJs4xSGJWCGa5+9883QRNsvj420atkb3cgO8P+PXwiwTi78Dq7Z/xHqccsU0b8poqBneQoA+pUGgNnF6V7Z8e9RsCcse2gAWSZWuOK3ua+9xCgH7I7MeL3afykr2aJ+yFCoYJMFrUjJeodMX2RbL0q+3FzIPZeGW3WdhTEAL9TSKRcJBSQTskaQlZx/OcpobxS7t3d2S68CCLG9uMTqOTYws55WZ1etalA75sRk9K2MR7ZGjZW3jdtvMViISc/t6Rrjea1GE8ZHGJC6/IeLIWA2c7nc= + distributions: sdist bdist_wheel +notifications: + email: false diff --git a/pyrender/LICENSE b/pyrender/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4276f7d204e4d85104246df637e0e36adbef14a7 --- /dev/null +++ b/pyrender/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Matthew Matl + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/pyrender/MANIFEST.in b/pyrender/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..097bcca3b4fccdc39ddd63c10f710ad524898e95 --- /dev/null +++ b/pyrender/MANIFEST.in @@ -0,0 +1,5 @@ +# Include the license +include LICENSE +include README.rst +include pyrender/fonts/* +include pyrender/shaders/* diff --git a/pyrender/README.md b/pyrender/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ae88ed1c5e78f247e38291ed83cf4c81230bf976 --- /dev/null +++ b/pyrender/README.md @@ -0,0 +1,92 @@ +# Pyrender + +[![Build Status](https://travis-ci.org/mmatl/pyrender.svg?branch=master)](https://travis-ci.org/mmatl/pyrender) +[![Documentation Status](https://readthedocs.org/projects/pyrender/badge/?version=latest)](https://pyrender.readthedocs.io/en/latest/?badge=latest) +[![Coverage Status](https://coveralls.io/repos/github/mmatl/pyrender/badge.svg?branch=master)](https://coveralls.io/github/mmatl/pyrender?branch=master) +[![PyPI version](https://badge.fury.io/py/pyrender.svg)](https://badge.fury.io/py/pyrender) +[![Downloads](https://pepy.tech/badge/pyrender)](https://pepy.tech/project/pyrender) + +Pyrender is a pure Python (2.7, 3.4, 3.5, 3.6) library for physically-based +rendering and visualization. +It is designed to meet the [glTF 2.0 specification from Khronos](https://www.khronos.org/gltf/). + +Pyrender is lightweight, easy to install, and simple to use. +It comes packaged with both an intuitive scene viewer and a headache-free +offscreen renderer with support for GPU-accelerated rendering on headless +servers, which makes it perfect for machine learning applications. + +Extensive documentation, including a quickstart guide, is provided [here](https://pyrender.readthedocs.io/en/latest/). + +For a minimal working example of GPU-accelerated offscreen rendering using EGL, +check out the [EGL Google CoLab Notebook](https://colab.research.google.com/drive/1pcndwqeY8vker3bLKQNJKr3B-7-SYenE?usp=sharing). + + +
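+ For instance, a minimal headless EGL render looks roughly like this (a sketch, not taken from this repository, assuming a GPU driver with EGL support):
+ 
+ ```python
+ import os
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'  # must be set before pyrender/OpenGL is imported
+ 
+ import numpy as np
+ import trimesh
+ import pyrender
+ 
+ # Tiny scene: one mesh, one camera, one directional light.
+ scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
+ scene.add(pyrender.Mesh.from_trimesh(trimesh.creation.icosahedron()))
+ 
+ pose = np.eye(4)
+ pose[2, 3] = 2.5  # pull the camera back along +z so the mesh is in view
+ scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0), pose=pose)
+ scene.add(pyrender.DirectionalLight(intensity=3.0), pose=pose)
+ 
+ # Render offscreen, then free the GL context.
+ r = pyrender.OffscreenRenderer(viewport_width=640, viewport_height=480)
+ color, depth = r.render(scene)  # (480, 640, 3) color image and (480, 640) depth map
+ r.delete()
+ ```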

+ *(image: GIF of Viewer)* + *(image: Damaged Helmet)*

+ +## Installation +You can install pyrender directly from pip. + +```bash +pip install pyrender +``` + +## Features + +Despite being lightweight, pyrender has lots of features, including: + +* Simple interoperation with the amazing [trimesh](https://github.com/mikedh/trimesh) project, +which enables out-of-the-box support for dozens of mesh types, including OBJ, +STL, DAE, OFF, PLY, and GLB. +* An easy-to-use scene viewer with support for animation, showing face and vertex +normals, toggling lighting conditions, and saving images and GIFs. +* An offscreen rendering module that supports OSMesa and EGL backends. +* Shadow mapping for directional and spot lights. +* Metallic-roughness materials for physically-based rendering, including several +types of texture and normal mapping. +* Transparency. +* Depth and color image generation. + +## Sample Usage + +For sample usage, check out the [quickstart +guide](https://pyrender.readthedocs.io/en/latest/examples/index.html) or one of +the Google CoLab Notebooks: + +* [EGL Google CoLab Notebook](https://colab.research.google.com/drive/1pcndwqeY8vker3bLKQNJKr3B-7-SYenE?usp=sharing) + +## Viewer Keyboard and Mouse Controls + +When using the viewer, the basic controls for moving about the scene are as follows: + +* To rotate the camera about the center of the scene, hold the left mouse button and drag the cursor. +* To rotate the camera about its viewing axis, hold `CTRL` left mouse button and drag the cursor. +* To pan the camera, do one of the following: + * Hold `SHIFT`, then hold the left mouse button and drag the cursor. + * Hold the middle mouse button and drag the cursor. +* To zoom the camera in or out, do one of the following: + * Scroll the mouse wheel. + * Hold the right mouse button and drag the cursor. + +The available keyboard commands are as follows: + +* `a`: Toggles rotational animation mode. +* `c`: Toggles backface culling. +* `f`: Toggles fullscreen mode. +* `h`: Toggles shadow rendering. +* `i`: Toggles axis display mode (no axes, world axis, mesh axes, all axes). +* `l`: Toggles lighting mode (scene lighting, Raymond lighting, or direct lighting). +* `m`: Toggles face normal visualization. +* `n`: Toggles vertex normal visualization. +* `o`: Toggles orthographic camera mode. +* `q`: Quits the viewer. +* `r`: Starts recording a GIF, and pressing again stops recording and opens a file dialog. +* `s`: Opens a file dialog to save the current view as an image. +* `w`: Toggles wireframe mode (scene default, flip wireframes, all wireframe, or all solid). +* `z`: Resets the camera to the default view. + +As a note, displaying shadows significantly slows down rendering, so if you're +experiencing low framerates, just kill shadows or reduce the number of lights in +your scene. diff --git a/pyrender/docs/Makefile b/pyrender/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..b1064a04362a0c4372fae351f99ed3bd9f82ff92 --- /dev/null +++ b/pyrender/docs/Makefile @@ -0,0 +1,23 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". 
+help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +clean: + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + rm -rf ./source/generated/* + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/pyrender/docs/make.bat b/pyrender/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..543c6b13b473ff3c586d5d97ae418d267ee795c4 --- /dev/null +++ b/pyrender/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/pyrender/docs/source/api/index.rst b/pyrender/docs/source/api/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..b6e473149d8f132f176e242c93406fdb84ce0b04 --- /dev/null +++ b/pyrender/docs/source/api/index.rst @@ -0,0 +1,59 @@ +Pyrender API Documentation +========================== + +Constants +--------- +.. automodapi:: pyrender.constants + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + +Cameras +------- +.. automodapi:: pyrender.camera + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + +Lighting +-------- +.. automodapi:: pyrender.light + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + +Objects +------- +.. automodapi:: pyrender + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + :skip: Camera, DirectionalLight, Light, OffscreenRenderer, Node + :skip: OrthographicCamera, PerspectiveCamera, PointLight, RenderFlags + :skip: Renderer, Scene, SpotLight, TextAlign, Viewer, GLTF + +Scenes +------ +.. automodapi:: pyrender + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: + :skip: Camera, DirectionalLight, Light, OffscreenRenderer + :skip: OrthographicCamera, PerspectiveCamera, PointLight, RenderFlags + :skip: Renderer, SpotLight, TextAlign, Viewer, Sampler, Texture, Material + :skip: MetallicRoughnessMaterial, Primitive, Mesh, GLTF + +On-Screen Viewer +---------------- +.. automodapi:: pyrender.viewer + :no-inheritance-diagram: + :no-inherited-members: + :no-main-docstr: + :no-heading: + +Off-Screen Rendering +-------------------- +.. automodapi:: pyrender.offscreen + :no-inheritance-diagram: + :no-main-docstr: + :no-heading: diff --git a/pyrender/docs/source/conf.py b/pyrender/docs/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf194c375e7e789b334a838953adfeaf2eb59b6 --- /dev/null +++ b/pyrender/docs/source/conf.py @@ -0,0 +1,352 @@ +# -*- coding: utf-8 -*- +# +# core documentation build configuration file, created by +# sphinx-quickstart on Sun Oct 16 14:33:48 2016. 
+# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys +import os +from pyrender import __version__ +from sphinx.domains.python import PythonDomain + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.abspath('../../')) + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.coverage', + 'sphinx.ext.githubpages', + 'sphinx.ext.intersphinx', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_automodapi.automodapi', + 'sphinx_automodapi.smart_resolver' +] +numpydoc_class_members_toctree = False +automodapi_toctreedirnm = 'generated' +automodsumm_inherited_members = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The encoding of source files. +#source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'pyrender' +copyright = u'2018, Matthew Matl' +author = u'Matthew Matl' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = __version__ +# The full version, including alpha/beta/rc tags. +release = __version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +#today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = [] + +# The reST default role (used for this markup: `text`) to use for all +# documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. 
+#modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +#keep_warnings = False + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +import sphinx_rtd_theme +html_theme = 'sphinx_rtd_theme' +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +#html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +#html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = None + +# The name of an image file (relative to this directory) to use as a favicon of +# the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Add any extra paths that contain custom files (such as robots.txt or +# .htaccess) here, relative to this directory. These files are copied +# directly to the root of the documentation. +#html_extra_path = [] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_domain_indices = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, links to the reST sources are added to the pages. +#html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +#html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +#html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = None + +# Language to be used for generating the HTML full-text search index. 
+# Sphinx supports the following languages: +# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' +# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' +#html_search_language = 'en' + +# A dictionary with options for the search language support, empty by default. +# Now only 'ja' uses this config value +#html_search_options = {'type': 'default'} + +# The name of a javascript file (relative to the configuration directory) that +# implements a search results scorer. If empty, the default will be used. +#html_search_scorer = 'scorer.js' + +# Output file base name for HTML help builder. +htmlhelp_basename = 'coredoc' + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +#'papersize': 'letterpaper', + +# The font size ('10pt', '11pt' or '12pt'). +#'pointsize': '10pt', + +# Additional stuff for the LaTeX preamble. +#'preamble': '', + +# Latex figure (float) alignment +#'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'pyrender.tex', u'pyrender Documentation', + u'Matthew Matl', 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +#latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +#latex_use_parts = False + +# If true, show page references after internal links. +#latex_show_pagerefs = False + +# If true, show URL addresses after external links. +#latex_show_urls = False + +# Documents to append as an appendix to all manuals. +#latex_appendices = [] + +# If false, no module index is generated. +#latex_domain_indices = True + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'pyrender', u'pyrender Documentation', + [author], 1) +] + +# If true, show URL addresses after external links. +#man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'pyrender', u'pyrender Documentation', + author, 'pyrender', 'One line description of project.', + 'Miscellaneous'), +] + +# Documents to append as an appendix to all manuals. +#texinfo_appendices = [] + +# If false, no module index is generated. +#texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#texinfo_show_urls = 'footnote' + +# If true, do not generate a @detailmenu in the "Top" node's menu. 
+#texinfo_no_detailmenu = False + +intersphinx_mapping = { + 'python' : ('https://docs.python.org/', None), + 'pyrender' : ('https://pyrender.readthedocs.io/en/latest/', None), +} + +# Autosummary fix +autosummary_generate = True + +# Try to suppress multiple-definition warnings by always taking the shorter +# path when two or more paths have the same base module + +class MyPythonDomain(PythonDomain): + + def find_obj(self, env, modname, classname, name, type, searchmode=0): + """Ensures an object always resolves to the desired module + if defined there.""" + orig_matches = PythonDomain.find_obj( + self, env, modname, classname, name, type, searchmode + ) + + if len(orig_matches) <= 1: + return orig_matches + + # If multiple matches, try to take the shortest if all the modules are + # the same + first_match_name_sp = orig_matches[0][0].split('.') + base_name = first_match_name_sp[0] + min_len = len(first_match_name_sp) + best_match = orig_matches[0] + + for match in orig_matches[1:]: + match_name = match[0] + match_name_sp = match_name.split('.') + match_base = match_name_sp[0] + + # If we have mismatched bases, return them all to trigger warnings + if match_base != base_name: + return orig_matches + + # Otherwise, check and see if it's shorter + if len(match_name_sp) < min_len: + min_len = len(match_name_sp) + best_match = match + + return (best_match,) + + +def setup(sphinx): + """Use MyPythonDomain in place of PythonDomain""" + sphinx.override_domain(MyPythonDomain) + diff --git a/pyrender/docs/source/examples/cameras.rst b/pyrender/docs/source/examples/cameras.rst new file mode 100644 index 0000000000000000000000000000000000000000..39186b75b16584d11fd1606b92291c104e0452bd --- /dev/null +++ b/pyrender/docs/source/examples/cameras.rst @@ -0,0 +1,26 @@ +.. _camera_guide: + +Creating Cameras +================ + +Pyrender supports three camera types -- :class:`.PerspectiveCamera` and +:class:`.IntrinsicsCamera` types, +which render scenes as a human would see them, and +:class:`.OrthographicCamera` types, which preserve distances between points. + +Creating cameras is easy -- just specify their basic attributes: + +>>> pc = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.414) +>>> oc = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0) + +For more information, see the Khronos group's documentation here_: + +.. _here: https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#projection-matrices + +When you add cameras to the scene, make sure that you're using OpenGL camera +coordinates to specify their pose. See the illustration below for details. +Basically, the camera z-axis points away from the scene, the x-axis points +right in image space, and the y-axis points up in image space. + +.. image:: /_static/camera_coords.png + diff --git a/pyrender/docs/source/examples/index.rst b/pyrender/docs/source/examples/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..4be536cd62c1cca112228f4e114e783be77a0ab8 --- /dev/null +++ b/pyrender/docs/source/examples/index.rst @@ -0,0 +1,20 @@ +.. _guide: + +User Guide +========== + +This section contains guides on how to use Pyrender to quickly visualize +your 3D data, including a quickstart guide and more detailed descriptions +of each part of the rendering pipeline. + + +.. 
toctree:: + :maxdepth: 2 + + quickstart.rst + models.rst + lighting.rst + cameras.rst + scenes.rst + offscreen.rst + viewer.rst diff --git a/pyrender/docs/source/examples/lighting.rst b/pyrender/docs/source/examples/lighting.rst new file mode 100644 index 0000000000000000000000000000000000000000..f89bee7d15027a0f52711622b053b49cc6e1b410 --- /dev/null +++ b/pyrender/docs/source/examples/lighting.rst @@ -0,0 +1,21 @@ +.. _lighting_guide: + +Creating Lights +=============== + +Pyrender supports three types of punctual light: + +- :class:`.PointLight`: Point-based light sources, such as light bulbs. +- :class:`.SpotLight`: A conical light source, like a flashlight. +- :class:`.DirectionalLight`: A general light that does not attenuate with + distance. + +Creating lights is easy -- just specify their basic attributes: + +>>> pl = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=2.0) +>>> sl = pyrender.SpotLight(color=[1.0, 1.0, 1.0], intensity=2.0, +... innerConeAngle=0.05, outerConeAngle=0.5) +>>> dl = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=2.0) + +For more information about how these lighting models are implemented, +see their class documentation. diff --git a/pyrender/docs/source/examples/models.rst b/pyrender/docs/source/examples/models.rst new file mode 100644 index 0000000000000000000000000000000000000000..84e71c4ff41a8d2e0eb2dc48434caedb757ff954 --- /dev/null +++ b/pyrender/docs/source/examples/models.rst @@ -0,0 +1,143 @@ +.. _model_guide: + +Loading and Configuring Models +============================== +The first step to any rendering application is loading your models. +Pyrender implements the GLTF 2.0 specification, which means that all +models are composed of a hierarchy of objects. + +At the top level, we have a :class:`.Mesh`. The :class:`.Mesh` is +basically a wrapper of any number of :class:`.Primitive` types, +which actually represent geometry that can be drawn to the screen. + +Primitives are composed of a variety of parameters, including +vertex positions, vertex normals, color and texture information, +and triangle indices if smooth rendering is desired. +They can implement point clouds, triangular meshes, or lines +depending on how you configure their data and set their +:attr:`.Primitive.mode` parameter. + +Although you can create primitives yourself if you want to, +it's probably easier to just use the utility functions provided +in the :class:`.Mesh` class. + +Creating Triangular Meshes +-------------------------- + +Simple Construction +~~~~~~~~~~~~~~~~~~~ +Pyrender allows you to create a :class:`.Mesh` containing a +triangular mesh model directly from a :class:`~trimesh.base.Trimesh` object +using the :meth:`.Mesh.from_trimesh` static method. + +>>> import trimesh +>>> import pyrender +>>> import numpy as np +>>> tm = trimesh.load('examples/models/fuze.obj') +>>> m = pyrender.Mesh.from_trimesh(tm) +>>> m.primitives +[<pyrender.primitive.Primitive object at 0x...>] + +You can also create a single :class:`.Mesh` from a list of +:class:`~trimesh.base.Trimesh` objects: + +>>> tms = [trimesh.creation.icosahedron(), trimesh.creation.cylinder()] +>>> m = pyrender.Mesh.from_trimesh(tms) +[<pyrender.primitive.Primitive object at 0x...>, + <pyrender.primitive.Primitive object at 0x...>] + +Vertex Smoothing +~~~~~~~~~~~~~~~~ + +The :meth:`.Mesh.from_trimesh` method has a few additional optional parameters. +If you want to render the mesh without interpolating face normals, which can +be useful for meshes that are supposed to be angular (e.g. a cube), you +can specify ``smooth=False``.
+ +>>> m = pyrender.Mesh.from_trimesh(tm, smooth=False) + +Per-Face or Per-Vertex Coloration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you have an untextured trimesh, you can color it in with per-face or +per-vertex colors: + +>>> tm.visual.vertex_colors = np.random.uniform(size=tm.vertices.shape) +>>> tm.visual.face_colors = np.random.uniform(size=tm.faces.shape) +>>> m = pyrender.Mesh.from_trimesh(tm) + +Instancing +~~~~~~~~~~ + +If you want to render many copies of the same mesh at different poses, +you can statically create a vast array of them in an efficient manner. +Simply specify the ``poses`` parameter to be a list of ``N`` 4x4 homogeneous +transformation matrices that position the meshes relative to their common +base frame: + +>>> tfs = np.tile(np.eye(4), (3,1,1)) +>>> tfs[1,:3,3] = [0.1, 0.0, 0.0] +>>> tfs[2,:3,3] = [0.2, 0.0, 0.0] +>>> tfs +array([[[1. , 0. , 0. , 0. ], + [0. , 1. , 0. , 0. ], + [0. , 0. , 1. , 0. ], + [0. , 0. , 0. , 1. ]], + [[1. , 0. , 0. , 0.1], + [0. , 1. , 0. , 0. ], + [0. , 0. , 1. , 0. ], + [0. , 0. , 0. , 1. ]], + [[1. , 0. , 0. , 0.2], + [0. , 1. , 0. , 0. ], + [0. , 0. , 1. , 0. ], + [0. , 0. , 0. , 1. ]]]) + +>>> m = pyrender.Mesh.from_trimesh(tm, poses=tfs) + +Custom Materials +~~~~~~~~~~~~~~~~ + +You can also specify a custom material for any triangular mesh you create +in the ``material`` parameter of :meth:`.Mesh.from_trimesh`. +The main material supported by Pyrender is the +:class:`.MetallicRoughnessMaterial`. +The metallic-roughness model supports rendering highly-realistic objects across +a wide gamut of materials. + +For more information, see the documentation of the +:class:`.MetallicRoughnessMaterial` constructor or look at the Khronos_ +documentation. + +.. _Khronos: https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#materials + +Creating Point Clouds +--------------------- + +Point Sprites +~~~~~~~~~~~~~ +Pyrender also allows you to create a :class:`.Mesh` containing a +point cloud directly from :class:`numpy.ndarray` instances +using the :meth:`.Mesh.from_points` static method. + +Simply provide a list of points and optional per-point colors and normals. + +>>> pts = tm.vertices.copy() +>>> colors = np.random.uniform(size=pts.shape) +>>> m = pyrender.Mesh.from_points(pts, colors=colors) + +Point clouds created in this way will be rendered as square point sprites. + +.. image:: /_static/points.png + +Point Spheres +~~~~~~~~~~~~~ +If you have a monochromatic point cloud and would like to render it with +spheres, you can render it by instancing a spherical trimesh: + +>>> sm = trimesh.creation.uv_sphere(radius=0.1) +>>> sm.visual.vertex_colors = [1.0, 0.0, 0.0] +>>> tfs = np.tile(np.eye(4), (len(pts), 1, 1)) +>>> tfs[:,:3,3] = pts +>>> m = pyrender.Mesh.from_trimesh(sm, poses=tfs) + +.. image:: /_static/points2.png diff --git a/pyrender/docs/source/examples/offscreen.rst b/pyrender/docs/source/examples/offscreen.rst new file mode 100644 index 0000000000000000000000000000000000000000..291532b6e0c0e512df35a97e3c826cc83015aeca --- /dev/null +++ b/pyrender/docs/source/examples/offscreen.rst @@ -0,0 +1,87 @@ +.. _offscreen_guide: + +Offscreen Rendering +=================== + +.. note:: + If you're using a headless server, you'll need to use either EGL (for + GPU-accelerated rendering) or OSMesa (for CPU-only software rendering). + If you're using OSMesa, be sure that you've installed it properly. See + :ref:`osmesa` for details.
+ +Choosing a Backend +------------------ + +Once you have a scene set up with its geometry, cameras, and lights, +you can render it using the :class:`.OffscreenRenderer`. Pyrender supports +three backends for offscreen rendering: + +- Pyglet, the same engine that runs the viewer. This requires an active + display manager, so you can't run it on a headless server. This is the + default option. +- OSMesa, a software renderer. +- EGL, which allows for GPU-accelerated rendering without a display manager. + +If you want to use OSMesa or EGL, you need to set the ``PYOPENGL_PLATFORM`` +environment variable before importing pyrender or any other OpenGL library. +You can do this at the command line: + +.. code-block:: bash + + PYOPENGL_PLATFORM=osmesa python render.py + +or at the top of your Python script: + +.. code-block:: python + + # Top of main python script + import os + os.environ['PYOPENGL_PLATFORM'] = 'egl' + +The handle for EGL is ``egl``, and the handle for OSMesa is ``osmesa``. + +Running the Renderer +-------------------- + +Once you've set your environment variable appropriately, create your scene and +then configure the :class:`.OffscreenRenderer` object with a window width, +a window height, and a size for point-cloud points: + +>>> r = pyrender.OffscreenRenderer(viewport_width=640, +... viewport_height=480, +... point_size=1.0) + +Then, just call the :meth:`.OffscreenRenderer.render` function: + +>>> color, depth = r.render(scene) + +.. image:: /_static/scene.png + +This will return a ``(w,h,3)`` channel floating-point color image and +a ``(w,h)`` floating-point depth image rendered from the scene's main camera. + +You can customize the rendering process by using flag options from +:class:`.RenderFlags` and bitwise or-ing them together. For example, +the following code renders a color image with an alpha channel +and enables shadow mapping for all directional lights: + +>>> flags = RenderFlags.RGBA | RenderFlags.SHADOWS_DIRECTIONAL +>>> color, depth = r.render(scene, flags=flags) + +Once you're done with the offscreen renderer, you need to close it before you +can run a different renderer or open the viewer for the same scene: + +>>> r.delete() + +Google CoLab Examples +--------------------- + +For a minimal working example of offscreen rendering using OSMesa, +see the `OSMesa Google CoLab notebook`_. + +.. _OSMesa Google CoLab notebook: https://colab.research.google.com/drive/1Z71mHIc-Sqval92nK290vAsHZRUkCjUx + +For a minimal working example of offscreen rendering using EGL, +see the `EGL Google CoLab notebook`_. + +.. _EGL Google CoLab notebook: https://colab.research.google.com/drive/1rTLHk0qxh4dn8KNe-mCnN8HAWdd2_BEh diff --git a/pyrender/docs/source/examples/quickstart.rst b/pyrender/docs/source/examples/quickstart.rst new file mode 100644 index 0000000000000000000000000000000000000000..ac556419e5206c2ccd4bc985feb1a8c7347310af --- /dev/null +++ b/pyrender/docs/source/examples/quickstart.rst @@ -0,0 +1,71 @@ +.. _quickstart_guide: + +Quickstart +========== + + +Minimal Example for 3D Viewer +----------------------------- +Here is a minimal example of loading and viewing a triangular mesh model +in pyrender. + +>>> import trimesh +>>> import pyrender +>>> fuze_trimesh = trimesh.load('examples/models/fuze.obj') +>>> mesh = pyrender.Mesh.from_trimesh(fuze_trimesh) +>>> scene = pyrender.Scene() +>>> scene.add(mesh) +>>> pyrender.Viewer(scene, use_raymond_lighting=True) + +.. 
image:: /_static/fuze.png + + +Minimal Example for Offscreen Rendering +--------------------------------------- +.. note:: + If you're using a headless server, make sure that you followed the guide + for installing OSMesa. See :ref:`osmesa`. + +Here is a minimal example of rendering a mesh model offscreen in pyrender. +The only additional necessities are that you need to add lighting and a camera. + +>>> import numpy as np +>>> import trimesh +>>> import pyrender +>>> import matplotlib.pyplot as plt + +>>> fuze_trimesh = trimesh.load('examples/models/fuze.obj') +>>> mesh = pyrender.Mesh.from_trimesh(fuze_trimesh) +>>> scene = pyrender.Scene() +>>> scene.add(mesh) +>>> camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0) +>>> s = np.sqrt(2)/2 +>>> camera_pose = np.array([ +... [0.0, -s, s, 0.3], +... [1.0, 0.0, 0.0, 0.0], +... [0.0, s, s, 0.35], +... [0.0, 0.0, 0.0, 1.0], +... ]) +>>> scene.add(camera, pose=camera_pose) +>>> light = pyrender.SpotLight(color=np.ones(3), intensity=3.0, +... innerConeAngle=np.pi/16.0, +... outerConeAngle=np.pi/6.0) +>>> scene.add(light, pose=camera_pose) +>>> r = pyrender.OffscreenRenderer(400, 400) +>>> color, depth = r.render(scene) +>>> plt.figure() +>>> plt.subplot(1,2,1) +>>> plt.axis('off') +>>> plt.imshow(color) +>>> plt.subplot(1,2,2) +>>> plt.axis('off') +>>> plt.imshow(depth, cmap=plt.cm.gray_r) +>>> plt.show() + +.. image:: /_static/minexcolor.png + :width: 45% + :align: left +.. image:: /_static/minexdepth.png + :width: 45% + :align: right + diff --git a/pyrender/docs/source/examples/scenes.rst b/pyrender/docs/source/examples/scenes.rst new file mode 100644 index 0000000000000000000000000000000000000000..94c243f8b860b9669ac26105fd2b9906054f4568 --- /dev/null +++ b/pyrender/docs/source/examples/scenes.rst @@ -0,0 +1,78 @@ +.. _scene_guide: + +Creating Scenes +=============== + +Before you render anything, you need to put all of your lights, cameras, +and meshes into a scene. The :class:`.Scene` object keeps track of the relative +poses of these primitives by inserting them into :class:`.Node` objects and +keeping them in a directed acyclic graph. + +Adding Objects +-------------- + +To create a :class:`.Scene`, simply call the constructor. You can optionally +specify an ambient light color and a background color: + +>>> scene = pyrender.Scene(ambient_light=[0.02, 0.02, 0.02], +... bg_color=[1.0, 1.0, 1.0]) + +You can add objects to a scene by first creating a :class:`.Node` object +and adding the object and its pose to the :class:`.Node`. Poses are specified +as 4x4 homogenous transformation matrices that are stored in the node's +:attr:`.Node.matrix` attribute. Note that the :class:`.Node` +constructor requires you to specify whether you're adding a mesh, light, +or camera. + +>>> mesh = pyrender.Mesh.from_trimesh(tm) +>>> light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=2.0) +>>> cam = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.414) +>>> nm = pyrender.Node(mesh=mesh, matrix=np.eye(4)) +>>> nl = pyrender.Node(light=light, matrix=np.eye(4)) +>>> nc = pyrender.Node(camera=cam, matrix=np.eye(4)) +>>> scene.add_node(nm) +>>> scene.add_node(nl) +>>> scene.add_node(nc) + +You can also add objects directly to a scene with the :meth:`.Scene.add` function, +which takes care of creating a :class:`.Node` for you. 
+ +>>> scene.add(mesh, pose=np.eye(4)) +>>> scene.add(light, pose=np.eye(4)) +>>> scene.add(cam, pose=np.eye(4)) + +Nodes can be hierarchical, in which case the node's :attr:`.Node.matrix` +specifies that node's pose relative to its parent frame. You can add nodes to +a scene hierarchically by specifying a parent node in your calls to +:meth:`.Scene.add` or :meth:`.Scene.add_node`: + +>>> scene.add_node(nl, parent_node=nc) +>>> scene.add(cam, parent_node=nm) + +If you add multiple cameras to a scene, you can specify which one to render from +by setting the :attr:`.Scene.main_camera_node` attribute. + +Updating Objects +---------------- + +You can update the poses of existing nodes with the :meth:`.Scene.set_pose` +function. Simply call it with a :class:`.Node` that is already in the scene +and the new pose of that node with respect to its parent as a 4x4 homogenous +transformation matrix: + +>>> scene.set_pose(nl, pose=np.eye(4)) + +If you want to get the local pose of a node, you can just access its +:attr:`.Node.matrix` attribute. However, if you want to the get +the pose of a node *with respect to the world frame*, you can call the +:meth:`.Scene.get_pose` method. + +>>> tf = scene.get_pose(nl) + +Removing Objects +---------------- + +Finally, you can remove a :class:`.Node` and all of its children from the +scene with the :meth:`.Scene.remove_node` function: + +>>> scene.remove_node(nl) diff --git a/pyrender/docs/source/examples/viewer.rst b/pyrender/docs/source/examples/viewer.rst new file mode 100644 index 0000000000000000000000000000000000000000..00a7973b46ec7da33b51b65581af6f25c1b1652f --- /dev/null +++ b/pyrender/docs/source/examples/viewer.rst @@ -0,0 +1,61 @@ +.. _viewer_guide: + +Live Scene Viewer +================= + +Standard Usage +-------------- +In addition to the offscreen renderer, Pyrender comes with a live scene viewer. +In its standard invocation, calling the :class:`.Viewer`'s constructor will +immediately pop a viewing window that you can navigate around in. + +>>> pyrender.Viewer(scene) + +By default, the viewer uses your scene's lighting. If you'd like to start with +some additional lighting that moves around with the camera, you can specify that +with: + +>>> pyrender.Viewer(scene, use_raymond_lighting=True) + +For a full list of the many options that the :class:`.Viewer` supports, check out its +documentation. + +.. image:: /_static/rotation.gif + +Running the Viewer in a Separate Thread +--------------------------------------- +If you'd like to animate your models, you'll want to run the viewer in a +separate thread so that you can update the scene while the viewer is running. +To do this, first pop the viewer in a separate thread by calling its constructor +with the ``run_in_thread`` option set: + +>>> v = pyrender.Viewer(scene, run_in_thread=True) + +Then, you can manipulate the :class:`.Scene` while the viewer is running to +animate things. However, be careful to acquire the viewer's +:attr:`.Viewer.render_lock` before editing the scene to prevent data corruption: + +>>> i = 0 +>>> while True: +... pose = np.eye(4) +... pose[:3,3] = [i, 0, 0] +... v.render_lock.acquire() +... scene.set_pose(mesh_node, pose) +... v.render_lock.release() +... i += 0.01 + +.. image:: /_static/scissors.gif + +You can wait on the viewer to be closed manually: + +>>> while v.is_active: +... pass + +Or you can close it from the main thread forcibly. +Make sure to still loop and block for the viewer to actually exit before using +the scene object again. 

>>> v.close_external()
>>> while v.is_active:
...     pass

diff --git a/pyrender/docs/source/index.rst b/pyrender/docs/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..baf189ede6bb3435cad5b8795e1937ef1a3c2c56
--- /dev/null
+++ b/pyrender/docs/source/index.rst
@@ -0,0 +1,41 @@
.. core documentation master file, created by
   sphinx-quickstart on Sun Oct 16 14:33:48 2016.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Pyrender Documentation
======================
Pyrender is a pure Python (2.7, 3.4, 3.5, 3.6) library for physically-based
rendering and visualization.
It is designed to meet the glTF 2.0 specification_ from Khronos.

.. _specification: https://www.khronos.org/gltf/

Pyrender is lightweight, easy to install, and simple to use.
It comes packaged with both an intuitive scene viewer and a headache-free
offscreen renderer with support for GPU-accelerated rendering on headless
servers, which makes it well suited for machine learning applications.
Check out the :ref:`guide` for a full tutorial, or fork me on
GitHub_.

.. _GitHub: https://github.com/mmatl/pyrender

.. image:: _static/rotation.gif

.. image:: _static/damaged_helmet.png

.. toctree::
   :maxdepth: 2

   install/index.rst
   examples/index.rst
   api/index.rst


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

diff --git a/pyrender/docs/source/install/index.rst b/pyrender/docs/source/install/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..c785f202d877f8bbaf286c21eddca1925973f75e
--- /dev/null
+++ b/pyrender/docs/source/install/index.rst
@@ -0,0 +1,172 @@
Installation Guide
==================

Python Installation
-------------------

This package is available via ``pip``.

.. code-block:: bash

   pip install pyrender

If you're on macOS, you'll need to pre-install my fork of ``pyglet``, as the
version on PyPI hasn't yet included my change that enables OpenGL contexts
on macOS.

.. code-block:: bash

   git clone https://github.com/mmatl/pyglet.git
   cd pyglet
   pip install .

.. _osmesa:

Getting Pyrender Working with OSMesa
------------------------------------
If you want to render scenes offscreen but don't want to have to
install a display manager or deal with the pains of trying to get
OpenGL to work over SSH, you have two options.

The first (and preferred) option is using EGL, which enables you to perform
GPU-accelerated rendering on headless servers.
However, you'll need EGL 1.5 to get modern OpenGL contexts.
This comes packaged with NVIDIA's current drivers, but if you are having
issues getting EGL to work with your hardware, you can try using OSMesa,
a software-based offscreen renderer that is included with any Mesa
install.

If you want to use OSMesa with pyrender, you'll have to perform two additional
installation steps:

- :ref:`installmesa`
- :ref:`installpyopengl`

Then, read the offscreen rendering tutorial. See :ref:`offscreen_guide`.

.. _installmesa:

Installing OSMesa
*****************

As a first step, you'll need to rebuild and re-install Mesa with support
for fast offscreen rendering and OpenGL 3+ contexts.
I'd recommend installing from source, but you can also try my ``.deb``
for Ubuntu 16.04 and up.
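If you're not sure whether a rebuild is actually necessary, you can probe your
existing setup first. The following is a rough sketch, not part of the
official install steps; it assumes ``pyrender`` and the PyOpenGL fork
described below are already installed. If it succeeds, your current Mesa
already provides a usable OSMesa context:

.. code-block:: python

   # Probe for a working OSMesa context by rendering an empty scene.
   import os
   os.environ['PYOPENGL_PLATFORM'] = 'osmesa'  # set before any OpenGL import

   import numpy as np
   import pyrender

   try:
       scene = pyrender.Scene()
       scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0),
                 pose=np.eye(4))
       r = pyrender.OffscreenRenderer(viewport_width=64, viewport_height=64)
       color, depth = r.render(scene)
       r.delete()
       print('OSMesa already works; no rebuild needed')
   except Exception as exc:  # context creation fails if OSMesa is missing/old
       print('OSMesa is not usable yet: {}'.format(exc))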

Installing from a Debian Package
********************************

If you're running Ubuntu 16.04 or newer, you should be able to install the
required version of Mesa from my ``.deb`` file.

.. code-block:: bash

   sudo apt update
   wget https://github.com/mmatl/travis_debs/raw/master/xenial/mesa_18.3.3-0.deb
   sudo dpkg -i ./mesa_18.3.3-0.deb || true
   sudo apt install -f

If this doesn't work, try building from source.

Building From Source
********************

First, install the build dependencies via ``apt`` or your system's package
manager.

.. code-block:: bash

   sudo apt-get install llvm-6.0 freeglut3 freeglut3-dev

Then, download the current release of Mesa from here_.
Unpack the source and go to the source folder:

.. _here: https://archive.mesa3d.org/mesa-18.3.3.tar.gz

.. code-block:: bash

   tar xfv mesa-18.3.3.tar.gz
   cd mesa-18.3.3

Replace ``PREFIX`` with the path where you want to install Mesa.
If you're not worried about overwriting your default Mesa install,
a good choice is ``/usr/local``.

Now, configure the installation by running the following command:

.. code-block:: bash

   ./configure --prefix=PREFIX \
               --enable-opengl --disable-gles1 --disable-gles2 \
               --disable-va --disable-xvmc --disable-vdpau \
               --enable-shared-glapi \
               --disable-texture-float \
               --enable-gallium-llvm --enable-llvm-shared-libs \
               --with-gallium-drivers=swrast,swr \
               --disable-dri --with-dri-drivers= \
               --disable-egl --with-egl-platforms= --disable-gbm \
               --disable-glx \
               --disable-osmesa --enable-gallium-osmesa \
               ac_cv_path_LLVM_CONFIG=llvm-config-6.0

Next, build and install Mesa.

.. code-block:: bash

   make -j8
   make install

Finally, if you didn't install Mesa in the system path,
add the following lines to your ``~/.bashrc`` file after
changing ``MESA_HOME`` to your Mesa installation path (i.e. what you used as
``PREFIX`` during the configure command).

.. code-block:: bash

   MESA_HOME=/path/to/your/mesa/installation
   export LIBRARY_PATH=$LIBRARY_PATH:$MESA_HOME/lib
   export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$MESA_HOME/lib
   export C_INCLUDE_PATH=$C_INCLUDE_PATH:$MESA_HOME/include/
   export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$MESA_HOME/include/

.. _installpyopengl:

Installing a Compatible Fork of PyOpenGL
****************************************

Next, install and use my fork of ``PyOpenGL``.
This fork enables getting modern OpenGL contexts with OSMesa.
My patch has been included in ``PyOpenGL``, but it has not yet been released
on PyPI.

.. code-block:: bash

   git clone https://github.com/mmatl/pyopengl.git
   pip install ./pyopengl


Building Documentation
----------------------

The online documentation for ``pyrender`` is automatically built by
Read the Docs.
Building ``pyrender``'s documentation locally requires a few extra
dependencies, specifically `sphinx`_ and a few plugins.

.. _sphinx: http://www.sphinx-doc.org/en/master/

To install the required dependencies, change directories into the
``pyrender`` source and run

.. code-block:: bash

   $ pip install .[docs]

Then, go to the ``docs`` directory and run ``make`` with the appropriate
target. For example,

.. code-block:: bash

   $ cd docs/
   $ make html

will generate a set of web pages. Any documentation files
generated in this manner can be found in ``docs/build``.
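
Verifying the Installation
--------------------------

Once the steps above are done, it's worth confirming end-to-end that headless
rendering works. Below is a minimal sketch along the lines of the quickstart
example; it assumes an EGL-capable GPU driver (swap the platform handle for
``osmesa`` if you went the software route) and uses a procedurally generated
box so that no model files are needed.

.. code-block:: python

   import os
   os.environ['PYOPENGL_PLATFORM'] = 'egl'  # or 'osmesa'

   import numpy as np
   import trimesh
   import pyrender

   scene = pyrender.Scene(ambient_light=[0.02, 0.02, 0.02])
   scene.add(pyrender.Mesh.from_trimesh(trimesh.creation.box()))
   scene.add(pyrender.DirectionalLight(color=np.ones(3), intensity=3.0))
   # Back the camera up along +z so the unit box at the origin is in view
   cam_pose = trimesh.transformations.translation_matrix([0.0, 0.0, 3.0])
   scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0),
             pose=cam_pose)

   r = pyrender.OffscreenRenderer(viewport_width=256, viewport_height=256)
   color, depth = r.render(scene)
   r.delete()

   assert depth.max() > 0.0, 'the box did not render'
   print('Headless rendering OK:', color.shape)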
diff --git a/pyrender/examples/duck.py b/pyrender/examples/duck.py new file mode 100644 index 0000000000000000000000000000000000000000..9a94bad5bfb30493f7364f2e52cbb4badbccb2c7 --- /dev/null +++ b/pyrender/examples/duck.py @@ -0,0 +1,13 @@ +from pyrender import Mesh, Scene, Viewer +from io import BytesIO +import numpy as np +import trimesh +import requests + +duck_source = "https://github.com/KhronosGroup/glTF-Sample-Models/raw/master/2.0/Duck/glTF-Binary/Duck.glb" + +duck = trimesh.load(BytesIO(requests.get(duck_source).content), file_type='glb') +duckmesh = Mesh.from_trimesh(list(duck.geometry.values())[0]) +scene = Scene(ambient_light=np.array([1.0, 1.0, 1.0, 1.0])) +scene.add(duckmesh) +Viewer(scene) diff --git a/pyrender/examples/example.py b/pyrender/examples/example.py new file mode 100644 index 0000000000000000000000000000000000000000..599a4850a5899cdeb1a76db1c5cf1c91c263cd41 --- /dev/null +++ b/pyrender/examples/example.py @@ -0,0 +1,157 @@ +"""Examples of using pyrender for viewing and offscreen rendering. +""" +import pyglet +pyglet.options['shadow_window'] = False +import os +import numpy as np +import trimesh + +from pyrender import PerspectiveCamera,\ + DirectionalLight, SpotLight, PointLight,\ + MetallicRoughnessMaterial,\ + Primitive, Mesh, Node, Scene,\ + Viewer, OffscreenRenderer, RenderFlags + +#============================================================================== +# Mesh creation +#============================================================================== + +#------------------------------------------------------------------------------ +# Creating textured meshes from trimeshes +#------------------------------------------------------------------------------ + +# Fuze trimesh +fuze_trimesh = trimesh.load('./models/fuze.obj') +fuze_mesh = Mesh.from_trimesh(fuze_trimesh) + +# Drill trimesh +drill_trimesh = trimesh.load('./models/drill.obj') +drill_mesh = Mesh.from_trimesh(drill_trimesh) +drill_pose = np.eye(4) +drill_pose[0,3] = 0.1 +drill_pose[2,3] = -np.min(drill_trimesh.vertices[:,2]) + +# Wood trimesh +wood_trimesh = trimesh.load('./models/wood.obj') +wood_mesh = Mesh.from_trimesh(wood_trimesh) + +# Water bottle trimesh +bottle_gltf = trimesh.load('./models/WaterBottle.glb') +bottle_trimesh = bottle_gltf.geometry[list(bottle_gltf.geometry.keys())[0]] +bottle_mesh = Mesh.from_trimesh(bottle_trimesh) +bottle_pose = np.array([ + [1.0, 0.0, 0.0, 0.1], + [0.0, 0.0, -1.0, -0.16], + [0.0, 1.0, 0.0, 0.13], + [0.0, 0.0, 0.0, 1.0], +]) + +#------------------------------------------------------------------------------ +# Creating meshes with per-vertex colors +#------------------------------------------------------------------------------ +boxv_trimesh = trimesh.creation.box(extents=0.1*np.ones(3)) +boxv_vertex_colors = np.random.uniform(size=(boxv_trimesh.vertices.shape)) +boxv_trimesh.visual.vertex_colors = boxv_vertex_colors +boxv_mesh = Mesh.from_trimesh(boxv_trimesh, smooth=False) + +#------------------------------------------------------------------------------ +# Creating meshes with per-face colors +#------------------------------------------------------------------------------ +boxf_trimesh = trimesh.creation.box(extents=0.1*np.ones(3)) +boxf_face_colors = np.random.uniform(size=boxf_trimesh.faces.shape) +boxf_trimesh.visual.face_colors = boxf_face_colors +boxf_mesh = Mesh.from_trimesh(boxf_trimesh, smooth=False) + +#------------------------------------------------------------------------------ +# Creating meshes from point clouds 
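# (Mesh.from_points, used below, builds a point-cloud mesh from an (n,3) array
#  of positions; the rendered point size is controlled by the point_size
#  option of the Viewer/OffscreenRenderer rather than by the mesh itself.)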
+#------------------------------------------------------------------------------ +points = trimesh.creation.icosphere(radius=0.05).vertices +point_colors = np.random.uniform(size=points.shape) +points_mesh = Mesh.from_points(points, colors=point_colors) + +#============================================================================== +# Light creation +#============================================================================== + +direc_l = DirectionalLight(color=np.ones(3), intensity=1.0) +spot_l = SpotLight(color=np.ones(3), intensity=10.0, + innerConeAngle=np.pi/16, outerConeAngle=np.pi/6) +point_l = PointLight(color=np.ones(3), intensity=10.0) + +#============================================================================== +# Camera creation +#============================================================================== + +cam = PerspectiveCamera(yfov=(np.pi / 3.0)) +cam_pose = np.array([ + [0.0, -np.sqrt(2)/2, np.sqrt(2)/2, 0.5], + [1.0, 0.0, 0.0, 0.0], + [0.0, np.sqrt(2)/2, np.sqrt(2)/2, 0.4], + [0.0, 0.0, 0.0, 1.0] +]) + +#============================================================================== +# Scene creation +#============================================================================== + +scene = Scene(ambient_light=np.array([0.02, 0.02, 0.02, 1.0])) + +#============================================================================== +# Adding objects to the scene +#============================================================================== + +#------------------------------------------------------------------------------ +# By manually creating nodes +#------------------------------------------------------------------------------ +fuze_node = Node(mesh=fuze_mesh, translation=np.array([0.1, 0.15, -np.min(fuze_trimesh.vertices[:,2])])) +scene.add_node(fuze_node) +boxv_node = Node(mesh=boxv_mesh, translation=np.array([-0.1, 0.10, 0.05])) +scene.add_node(boxv_node) +boxf_node = Node(mesh=boxf_mesh, translation=np.array([-0.1, -0.10, 0.05])) +scene.add_node(boxf_node) + +#------------------------------------------------------------------------------ +# By using the add() utility function +#------------------------------------------------------------------------------ +drill_node = scene.add(drill_mesh, pose=drill_pose) +bottle_node = scene.add(bottle_mesh, pose=bottle_pose) +wood_node = scene.add(wood_mesh) +direc_l_node = scene.add(direc_l, pose=cam_pose) +spot_l_node = scene.add(spot_l, pose=cam_pose) + +#============================================================================== +# Using the viewer with a default camera +#============================================================================== + +v = Viewer(scene, shadows=True) + +#============================================================================== +# Using the viewer with a pre-specified camera +#============================================================================== +cam_node = scene.add(cam, pose=cam_pose) +v = Viewer(scene, central_node=drill_node) + +#============================================================================== +# Rendering offscreen from that camera +#============================================================================== + +r = OffscreenRenderer(viewport_width=640*2, viewport_height=480*2) +color, depth = r.render(scene) + +import matplotlib.pyplot as plt +plt.figure() +plt.imshow(color) +plt.show() + +#============================================================================== +# Segmask rendering 
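# (RenderFlags.SEG renders each mesh node in a flat color, with the
#  node-to-color mapping supplied as the third argument to render().)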
+#============================================================================== + +nm = {node: 20*(i + 1) for i, node in enumerate(scene.mesh_nodes)} +seg = r.render(scene, RenderFlags.SEG, nm)[0] +plt.figure() +plt.imshow(seg) +plt.show() + +r.delete() + diff --git a/pyrender/pyrender/__init__.py b/pyrender/pyrender/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3709846823b7c4b71b22da0e24d63d805528a8 --- /dev/null +++ b/pyrender/pyrender/__init__.py @@ -0,0 +1,24 @@ +from .camera import (Camera, PerspectiveCamera, OrthographicCamera, + IntrinsicsCamera) +from .light import Light, PointLight, DirectionalLight, SpotLight +from .sampler import Sampler +from .texture import Texture +from .material import Material, MetallicRoughnessMaterial +from .primitive import Primitive +from .mesh import Mesh +from .node import Node +from .scene import Scene +from .renderer import Renderer +from .viewer import Viewer +from .offscreen import OffscreenRenderer +from .version import __version__ +from .constants import RenderFlags, TextAlign, GLTF + +__all__ = [ + 'Camera', 'PerspectiveCamera', 'OrthographicCamera', 'IntrinsicsCamera', + 'Light', 'PointLight', 'DirectionalLight', 'SpotLight', + 'Sampler', 'Texture', 'Material', 'MetallicRoughnessMaterial', + 'Primitive', 'Mesh', 'Node', 'Scene', 'Renderer', 'Viewer', + 'OffscreenRenderer', '__version__', 'RenderFlags', 'TextAlign', + 'GLTF' +] diff --git a/pyrender/pyrender/camera.py b/pyrender/pyrender/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..e019358039033c3a372c990ebad3151258c3651d --- /dev/null +++ b/pyrender/pyrender/camera.py @@ -0,0 +1,437 @@ +"""Virtual cameras compliant with the glTF 2.0 specification as described at +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-camera + +Author: Matthew Matl +""" +import abc +import numpy as np +import six +import sys + +from .constants import DEFAULT_Z_NEAR, DEFAULT_Z_FAR + + +@six.add_metaclass(abc.ABCMeta) +class Camera(object): + """Abstract base class for all cameras. + + Note + ---- + Camera poses are specified in the OpenGL format, + where the z axis points away from the view direction and the + x and y axes point to the right and up in the image plane, respectively. + + Parameters + ---------- + znear : float + The floating-point distance to the near clipping plane. + zfar : float + The floating-point distance to the far clipping plane. + ``zfar`` must be greater than ``znear``. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + znear=DEFAULT_Z_NEAR, + zfar=DEFAULT_Z_FAR, + name=None): + self.name = name + self.znear = znear + self.zfar = zfar + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def znear(self): + """float : The distance to the near clipping plane. + """ + return self._znear + + @znear.setter + def znear(self, value): + value = float(value) + if value < 0: + raise ValueError('z-near must be >= 0.0') + self._znear = value + + @property + def zfar(self): + """float : The distance to the far clipping plane. 
+ """ + return self._zfar + + @zfar.setter + def zfar(self, value): + value = float(value) + if value <= 0 or value <= self.znear: + raise ValueError('zfar must be >0 and >znear') + self._zfar = value + + @abc.abstractmethod + def get_projection_matrix(self, width=None, height=None): + """Return the OpenGL projection matrix for this camera. + + Parameters + ---------- + width : int + Width of the current viewport, in pixels. + height : int + Height of the current viewport, in pixels. + """ + pass + + +class PerspectiveCamera(Camera): + + """A perspective camera for perspective projection. + + Parameters + ---------- + yfov : float + The floating-point vertical field of view in radians. + znear : float + The floating-point distance to the near clipping plane. + If not specified, defaults to 0.05. + zfar : float, optional + The floating-point distance to the far clipping plane. + ``zfar`` must be greater than ``znear``. + If None, the camera uses an infinite projection matrix. + aspectRatio : float, optional + The floating-point aspect ratio of the field of view. + If not specified, the camera uses the viewport's aspect ratio. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + yfov, + znear=DEFAULT_Z_NEAR, + zfar=None, + aspectRatio=None, + name=None): + super(PerspectiveCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + + self.yfov = yfov + self.aspectRatio = aspectRatio + + @property + def yfov(self): + """float : The vertical field of view in radians. + """ + return self._yfov + + @yfov.setter + def yfov(self, value): + value = float(value) + if value <= 0.0: + raise ValueError('Field of view must be positive') + self._yfov = value + + @property + def zfar(self): + """float : The distance to the far clipping plane. + """ + return self._zfar + + @zfar.setter + def zfar(self, value): + if value is not None: + value = float(value) + if value <= 0 or value <= self.znear: + raise ValueError('zfar must be >0 and >znear') + self._zfar = value + + @property + def aspectRatio(self): + """float : The ratio of the width to the height of the field of view. + """ + return self._aspectRatio + + @aspectRatio.setter + def aspectRatio(self, value): + if value is not None: + value = float(value) + if value <= 0.0: + raise ValueError('Aspect ratio must be positive') + self._aspectRatio = value + + def get_projection_matrix(self, width=None, height=None): + """Return the OpenGL projection matrix for this camera. + + Parameters + ---------- + width : int + Width of the current viewport, in pixels. + height : int + Height of the current viewport, in pixels. + """ + aspect_ratio = self.aspectRatio + if aspect_ratio is None: + if width is None or height is None: + raise ValueError('Aspect ratio of camera must be defined') + aspect_ratio = float(width) / float(height) + + a = aspect_ratio + t = np.tan(self.yfov / 2.0) + n = self.znear + f = self.zfar + + P = np.zeros((4,4)) + P[0][0] = 1.0 / (a * t) + P[1][1] = 1.0 / t + P[3][2] = -1.0 + + if f is None: + P[2][2] = -1.0 + P[2][3] = -2.0 * n + else: + P[2][2] = (f + n) / (n - f) + P[2][3] = (2 * f * n) / (n - f) + + return P + + +class OrthographicCamera(Camera): + """An orthographic camera for orthographic projection. + + Parameters + ---------- + xmag : float + The floating-point horizontal magnification of the view. + ymag : float + The floating-point vertical magnification of the view. + znear : float + The floating-point distance to the near clipping plane. + If not specified, defaults to 0.05. 
+ zfar : float + The floating-point distance to the far clipping plane. + ``zfar`` must be greater than ``znear``. + If not specified, defaults to 100.0. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + xmag, + ymag, + znear=DEFAULT_Z_NEAR, + zfar=DEFAULT_Z_FAR, + name=None): + super(OrthographicCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + + self.xmag = xmag + self.ymag = ymag + + @property + def xmag(self): + """float : The horizontal magnification of the view. + """ + return self._xmag + + @xmag.setter + def xmag(self, value): + value = float(value) + if value <= 0.0: + raise ValueError('X magnification must be positive') + self._xmag = value + + @property + def ymag(self): + """float : The vertical magnification of the view. + """ + return self._ymag + + @ymag.setter + def ymag(self, value): + value = float(value) + if value <= 0.0: + raise ValueError('Y magnification must be positive') + self._ymag = value + + @property + def znear(self): + """float : The distance to the near clipping plane. + """ + return self._znear + + @znear.setter + def znear(self, value): + value = float(value) + if value <= 0: + raise ValueError('z-near must be > 0.0') + self._znear = value + + def get_projection_matrix(self, width=None, height=None): + """Return the OpenGL projection matrix for this camera. + + Parameters + ---------- + width : int + Width of the current viewport, in pixels. + Unused in this function. + height : int + Height of the current viewport, in pixels. + Unused in this function. + """ + xmag = self.xmag + ymag = self.ymag + + # If screen width/height defined, rescale xmag + if width is not None and height is not None: + xmag = width / height * ymag + + n = self.znear + f = self.zfar + P = np.zeros((4,4)) + P[0][0] = 1.0 / xmag + P[1][1] = 1.0 / ymag + P[2][2] = 2.0 / (n - f) + P[2][3] = (f + n) / (n - f) + P[3][3] = 1.0 + return P + + +class IntrinsicsCamera(Camera): + """A perspective camera with custom intrinsics. + + Parameters + ---------- + fx : float + X-axis focal length in pixels. + fy : float + Y-axis focal length in pixels. + cx : float + X-axis optical center in pixels. + cy : float + Y-axis optical center in pixels. + znear : float + The floating-point distance to the near clipping plane. + If not specified, defaults to 0.05. + zfar : float + The floating-point distance to the far clipping plane. + ``zfar`` must be greater than ``znear``. + If not specified, defaults to 100.0. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + fx, + fy, + cx, + cy, + znear=DEFAULT_Z_NEAR, + zfar=DEFAULT_Z_FAR, + name=None): + super(IntrinsicsCamera, self).__init__( + znear=znear, + zfar=zfar, + name=name, + ) + + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + + @property + def fx(self): + """float : X-axis focal length in meters. + """ + return self._fx + + @fx.setter + def fx(self, value): + self._fx = float(value) + + @property + def fy(self): + """float : Y-axis focal length in meters. + """ + return self._fy + + @fy.setter + def fy(self, value): + self._fy = float(value) + + @property + def cx(self): + """float : X-axis optical center in pixels. + """ + return self._cx + + @cx.setter + def cx(self, value): + self._cx = float(value) + + @property + def cy(self): + """float : Y-axis optical center in pixels. 
+ """ + return self._cy + + @cy.setter + def cy(self, value): + self._cy = float(value) + + def get_projection_matrix(self, width, height): + """Return the OpenGL projection matrix for this camera. + + Parameters + ---------- + width : int + Width of the current viewport, in pixels. + height : int + Height of the current viewport, in pixels. + """ + width = float(width) + height = float(height) + + cx, cy = self.cx, self.cy + fx, fy = self.fx, self.fy + if sys.platform == 'darwin': + cx = self.cx * 2.0 + cy = self.cy * 2.0 + fx = self.fx * 2.0 + fy = self.fy * 2.0 + + P = np.zeros((4,4)) + P[0][0] = 2.0 * fx / width + P[1][1] = 2.0 * fy / height + P[0][2] = 1.0 - 2.0 * cx / width + P[1][2] = 2.0 * cy / height - 1.0 + P[3][2] = -1.0 + + n = self.znear + f = self.zfar + if f is None: + P[2][2] = -1.0 + P[2][3] = -2.0 * n + else: + P[2][2] = (f + n) / (n - f) + P[2][3] = (2 * f * n) / (n - f) + + return P + + +__all__ = ['Camera', 'PerspectiveCamera', 'OrthographicCamera', + 'IntrinsicsCamera'] diff --git a/pyrender/pyrender/constants.py b/pyrender/pyrender/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5785b6fdb21910a174252c5af2f05b40ece4a5 --- /dev/null +++ b/pyrender/pyrender/constants.py @@ -0,0 +1,149 @@ +DEFAULT_Z_NEAR = 0.05 # Near clipping plane, in meters +DEFAULT_Z_FAR = 100.0 # Far clipping plane, in meters +DEFAULT_SCENE_SCALE = 2.0 # Default scene scale +MAX_N_LIGHTS = 4 # Maximum number of lights of each type allowed +TARGET_OPEN_GL_MAJOR = 4 # Target OpenGL Major Version +TARGET_OPEN_GL_MINOR = 1 # Target OpenGL Minor Version +MIN_OPEN_GL_MAJOR = 3 # Minimum OpenGL Major Version +MIN_OPEN_GL_MINOR = 3 # Minimum OpenGL Minor Version +FLOAT_SZ = 4 # Byte size of GL float32 +UINT_SZ = 4 # Byte size of GL uint32 +SHADOW_TEX_SZ = 2048 # Width and Height of Shadow Textures +TEXT_PADDING = 20 # Width of padding for rendering text (px) + + +# Flags for render type +class RenderFlags(object): + """Flags for rendering in the scene. + + Combine them with the bitwise or. For example, + + >>> flags = OFFSCREEN | SHADOWS_DIRECTIONAL | VERTEX_NORMALS + + would result in an offscreen render with directional shadows and + vertex normals enabled. + """ + NONE = 0 + """Normal PBR Render.""" + DEPTH_ONLY = 1 + """Only render the depth buffer.""" + OFFSCREEN = 2 + """Render offscreen and return the depth and (optionally) color buffers.""" + FLIP_WIREFRAME = 4 + """Invert the status of wireframe rendering for each mesh.""" + ALL_WIREFRAME = 8 + """Render all meshes as wireframes.""" + ALL_SOLID = 16 + """Render all meshes as solids.""" + SHADOWS_DIRECTIONAL = 32 + """Render shadows for directional lights.""" + SHADOWS_POINT = 64 + """Render shadows for point lights.""" + SHADOWS_SPOT = 128 + """Render shadows for spot lights.""" + SHADOWS_ALL = 32 | 64 | 128 + """Render shadows for all lights.""" + VERTEX_NORMALS = 256 + """Render vertex normals.""" + FACE_NORMALS = 512 + """Render face normals.""" + SKIP_CULL_FACES = 1024 + """Do not cull back faces.""" + RGBA = 2048 + """Render the color buffer with the alpha channel enabled.""" + FLAT = 4096 + """Render the color buffer flat, with no lighting computations.""" + SEG = 8192 + + +class TextAlign: + """Text alignment options for captions. + + Only use one at a time. 
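
    For example, :attr:`.TextAlign.BOTTOM_LEFT` anchors the given (x, y)
    position at the bottom-left corner of the rendered text.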
+ """ + CENTER = 0 + """Center the text by width and height.""" + CENTER_LEFT = 1 + """Center the text by height and left-align it.""" + CENTER_RIGHT = 2 + """Center the text by height and right-align it.""" + BOTTOM_LEFT = 3 + """Put the text in the bottom-left corner.""" + BOTTOM_RIGHT = 4 + """Put the text in the bottom-right corner.""" + BOTTOM_CENTER = 5 + """Center the text by width and fix it to the bottom.""" + TOP_LEFT = 6 + """Put the text in the top-left corner.""" + TOP_RIGHT = 7 + """Put the text in the top-right corner.""" + TOP_CENTER = 8 + """Center the text by width and fix it to the top.""" + + +class GLTF(object): + """Options for GL objects.""" + NEAREST = 9728 + """Nearest neighbor interpolation.""" + LINEAR = 9729 + """Linear interpolation.""" + NEAREST_MIPMAP_NEAREST = 9984 + """Nearest mipmapping.""" + LINEAR_MIPMAP_NEAREST = 9985 + """Linear mipmapping.""" + NEAREST_MIPMAP_LINEAR = 9986 + """Nearest mipmapping.""" + LINEAR_MIPMAP_LINEAR = 9987 + """Linear mipmapping.""" + CLAMP_TO_EDGE = 33071 + """Clamp to the edge of the texture.""" + MIRRORED_REPEAT = 33648 + """Mirror the texture.""" + REPEAT = 10497 + """Repeat the texture.""" + POINTS = 0 + """Render as points.""" + LINES = 1 + """Render as lines.""" + LINE_LOOP = 2 + """Render as a line loop.""" + LINE_STRIP = 3 + """Render as a line strip.""" + TRIANGLES = 4 + """Render as triangles.""" + TRIANGLE_STRIP = 5 + """Render as a triangle strip.""" + TRIANGLE_FAN = 6 + """Render as a triangle fan.""" + + +class BufFlags(object): + POSITION = 0 + NORMAL = 1 + TANGENT = 2 + TEXCOORD_0 = 4 + TEXCOORD_1 = 8 + COLOR_0 = 16 + JOINTS_0 = 32 + WEIGHTS_0 = 64 + + +class TexFlags(object): + NONE = 0 + NORMAL = 1 + OCCLUSION = 2 + EMISSIVE = 4 + BASE_COLOR = 8 + METALLIC_ROUGHNESS = 16 + DIFFUSE = 32 + SPECULAR_GLOSSINESS = 64 + + +class ProgramFlags: + NONE = 0 + USE_MATERIAL = 1 + VERTEX_NORMALS = 2 + FACE_NORMALS = 4 + + +__all__ = ['RenderFlags', 'TextAlign', 'GLTF'] diff --git a/pyrender/pyrender/font.py b/pyrender/pyrender/font.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac530d7b949f50314a0d9cf5d744bedcace0571 --- /dev/null +++ b/pyrender/pyrender/font.py @@ -0,0 +1,272 @@ +"""Font texture loader and processor. + +Author: Matthew Matl +""" +import freetype +import numpy as np +import os + +import OpenGL +from OpenGL.GL import * + +from .constants import TextAlign, FLOAT_SZ +from .texture import Texture +from .sampler import Sampler + + +class FontCache(object): + """A cache for fonts. + """ + + def __init__(self, font_dir=None): + self._font_cache = {} + self.font_dir = font_dir + if self.font_dir is None: + base_dir, _ = os.path.split(os.path.realpath(__file__)) + self.font_dir = os.path.join(base_dir, 'fonts') + + def get_font(self, font_name, font_pt): + # If it's a file, load it directly, else, try to load from font dir. + if os.path.isfile(font_name): + font_filename = font_name + _, font_name = os.path.split(font_name) + font_name, _ = os.path.split(font_name) + else: + font_filename = os.path.join(self.font_dir, font_name) + '.ttf' + + cid = OpenGL.contextdata.getContext() + key = (cid, font_name, int(font_pt)) + + if key not in self._font_cache: + self._font_cache[key] = Font(font_filename, font_pt) + return self._font_cache[key] + + def clear(self): + for key in self._font_cache: + self._font_cache[key].delete() + self._font_cache = {} + + +class Character(object): + """A single character, with its texture and attributes. 
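
    Parameters
    ----------
    texture : :class:`.Texture`
        Texture containing the rasterized glyph.
    size : (2,) int
        Width and height of the glyph bitmap, in pixels.
    bearing : (2,) int
        Offsets from the origin to the left edge and top edge of the glyph
        bitmap, in pixels.
    advance : int
        Horizontal advance to the next glyph origin, in 1/64-pixel units
        (hence the ``advance >> 6`` in the string-layout code).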
+ """ + + def __init__(self, texture, size, bearing, advance): + self.texture = texture + self.size = size + self.bearing = bearing + self.advance = advance + + +class Font(object): + """A font object. + + Parameters + ---------- + font_file : str + The file to load the font from. + font_pt : int + The height of the font in pixels. + """ + + def __init__(self, font_file, font_pt=40): + self.font_file = font_file + self.font_pt = int(font_pt) + self._face = freetype.Face(font_file) + self._face.set_pixel_sizes(0, font_pt) + self._character_map = {} + + for i in range(0, 128): + + # Generate texture + face = self._face + face.load_char(chr(i)) + buf = face.glyph.bitmap.buffer + src = (np.array(buf) / 255.0).astype(np.float32) + src = src.reshape((face.glyph.bitmap.rows, + face.glyph.bitmap.width)) + tex = Texture( + sampler=Sampler( + magFilter=GL_LINEAR, + minFilter=GL_LINEAR, + wrapS=GL_CLAMP_TO_EDGE, + wrapT=GL_CLAMP_TO_EDGE + ), + source=src, + source_channels='R', + ) + character = Character( + texture=tex, + size=np.array([face.glyph.bitmap.width, + face.glyph.bitmap.rows]), + bearing=np.array([face.glyph.bitmap_left, + face.glyph.bitmap_top]), + advance=face.glyph.advance.x + ) + self._character_map[chr(i)] = character + + self._vbo = None + self._vao = None + + @property + def font_file(self): + """str : The file the font was loaded from. + """ + return self._font_file + + @font_file.setter + def font_file(self, value): + self._font_file = value + + @property + def font_pt(self): + """int : The height of the font in pixels. + """ + return self._font_pt + + @font_pt.setter + def font_pt(self, value): + self._font_pt = int(value) + + def _add_to_context(self): + + self._vao = glGenVertexArrays(1) + glBindVertexArray(self._vao) + self._vbo = glGenBuffers(1) + glBindBuffer(GL_ARRAY_BUFFER, self._vbo) + glBufferData(GL_ARRAY_BUFFER, FLOAT_SZ * 6 * 4, None, GL_DYNAMIC_DRAW) + glEnableVertexAttribArray(0) + glVertexAttribPointer( + 0, 4, GL_FLOAT, GL_FALSE, 4 * FLOAT_SZ, ctypes.c_void_p(0) + ) + glBindVertexArray(0) + + glPixelStorei(GL_UNPACK_ALIGNMENT, 1) + for c in self._character_map: + ch = self._character_map[c] + if not ch.texture._in_context(): + ch.texture._add_to_context() + + def _remove_from_context(self): + for c in self._character_map: + ch = self._character_map[c] + ch.texture.delete() + if self._vao is not None: + glDeleteVertexArrays(1, [self._vao]) + glDeleteBuffers(1, [self._vbo]) + self._vao = None + self._vbo = None + + def _in_context(self): + return self._vao is not None + + def _bind(self): + glBindVertexArray(self._vao) + + def _unbind(self): + glBindVertexArray(0) + + def delete(self): + self._unbind() + self._remove_from_context() + + def render_string(self, text, x, y, scale=1.0, + align=TextAlign.BOTTOM_LEFT): + """Render a string to the current view buffer. + + Note + ---- + Assumes correct shader program already bound w/ uniforms set. + + Parameters + ---------- + text : str + The text to render. + x : int + Horizontal pixel location of text. + y : int + Vertical pixel location of text. + scale : int + Scaling factor for text. + align : int + One of the TextAlign options which specifies where the ``x`` + and ``y`` parameters lie on the text. For example, + :attr:`.TextAlign.BOTTOM_LEFT` means that ``x`` and ``y`` indicate + the position of the bottom-left corner of the textbox. 
+ """ + glActiveTexture(GL_TEXTURE0) + glEnable(GL_BLEND) + glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + glDisable(GL_DEPTH_TEST) + glPolygonMode(GL_FRONT_AND_BACK, GL_FILL) + self._bind() + + # Determine width and height of text relative to x, y + width = 0.0 + height = 0.0 + for c in text: + ch = self._character_map[c] + height = max(height, ch.bearing[1] * scale) + width += (ch.advance >> 6) * scale + + # Determine offsets based on alignments + xoff = 0 + yoff = 0 + if align == TextAlign.BOTTOM_RIGHT: + xoff = -width + elif align == TextAlign.BOTTOM_CENTER: + xoff = -width / 2.0 + elif align == TextAlign.TOP_LEFT: + yoff = -height + elif align == TextAlign.TOP_RIGHT: + yoff = -height + xoff = -width + elif align == TextAlign.TOP_CENTER: + yoff = -height + xoff = -width / 2.0 + elif align == TextAlign.CENTER: + xoff = -width / 2.0 + yoff = -height / 2.0 + elif align == TextAlign.CENTER_LEFT: + yoff = -height / 2.0 + elif align == TextAlign.CENTER_RIGHT: + xoff = -width + yoff = -height / 2.0 + + x += xoff + y += yoff + + ch = None + for c in text: + ch = self._character_map[c] + xpos = x + ch.bearing[0] * scale + ypos = y - (ch.size[1] - ch.bearing[1]) * scale + w = ch.size[0] * scale + h = ch.size[1] * scale + + vertices = np.array([ + [xpos, ypos, 0.0, 0.0], + [xpos + w, ypos, 1.0, 0.0], + [xpos + w, ypos + h, 1.0, 1.0], + [xpos + w, ypos + h, 1.0, 1.0], + [xpos, ypos + h, 0.0, 1.0], + [xpos, ypos, 0.0, 0.0], + ], dtype=np.float32) + + ch.texture._bind() + + glBindBuffer(GL_ARRAY_BUFFER, self._vbo) + glBufferData( + GL_ARRAY_BUFFER, FLOAT_SZ * 6 * 4, vertices, GL_DYNAMIC_DRAW + ) + # TODO MAKE THIS MORE EFFICIENT, lgBufferSubData is broken + # glBufferSubData( + # GL_ARRAY_BUFFER, 0, 6 * 4 * FLOAT_SZ, + # np.ascontiguousarray(vertices.flatten) + # ) + glDrawArrays(GL_TRIANGLES, 0, 6) + x += (ch.advance >> 6) * scale + + self._unbind() + if ch: + ch.texture._unbind() diff --git a/pyrender/pyrender/fonts/OpenSans-Bold.ttf b/pyrender/pyrender/fonts/OpenSans-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..fd79d43bea0293ac1b20e8aca1142627983d2c07 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Bold.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-BoldItalic.ttf b/pyrender/pyrender/fonts/OpenSans-BoldItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..9bc800958a421d937fc392e00beaef4eea76dc71 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-BoldItalic.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-ExtraBold.ttf b/pyrender/pyrender/fonts/OpenSans-ExtraBold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..21f6f84a0799946fc4ae02c52b27e61c3762c745 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-ExtraBold.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-ExtraBoldItalic.ttf b/pyrender/pyrender/fonts/OpenSans-ExtraBoldItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..31cb688340eff462dddf47efbb4dfef66cb7fbed Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-ExtraBoldItalic.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-Italic.ttf b/pyrender/pyrender/fonts/OpenSans-Italic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..c90da48ff3b8ad6167236d70c48df4d7b5de3bbb Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Italic.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-Light.ttf b/pyrender/pyrender/fonts/OpenSans-Light.ttf new file mode 100644 index 
0000000000000000000000000000000000000000..0d381897da20345fa63112f19042561f44ee3aa0 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Light.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-LightItalic.ttf b/pyrender/pyrender/fonts/OpenSans-LightItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..68299c4bc6b5b7adfff2c9aee4aed7c1547100ef Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-LightItalic.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-Regular.ttf b/pyrender/pyrender/fonts/OpenSans-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..db433349b7047f72f40072630c1bc110620bf09e Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Regular.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-Semibold.ttf b/pyrender/pyrender/fonts/OpenSans-Semibold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..1a7679e3949fb045f152f456bc4adad31e8b9f55 Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-Semibold.ttf differ diff --git a/pyrender/pyrender/fonts/OpenSans-SemiboldItalic.ttf b/pyrender/pyrender/fonts/OpenSans-SemiboldItalic.ttf new file mode 100644 index 0000000000000000000000000000000000000000..59b6d16b065f6baa6f70ddbd4322a4f44bb9636a Binary files /dev/null and b/pyrender/pyrender/fonts/OpenSans-SemiboldItalic.ttf differ diff --git a/pyrender/pyrender/light.py b/pyrender/pyrender/light.py new file mode 100644 index 0000000000000000000000000000000000000000..333d9e4e553a245c259251a89b69cb46b73b1278 --- /dev/null +++ b/pyrender/pyrender/light.py @@ -0,0 +1,385 @@ +"""Punctual light sources as defined by the glTF 2.0 KHR extension at +https://github.com/KhronosGroup/glTF/tree/master/extensions/2.0/Khronos/KHR_lights_punctual + +Author: Matthew Matl +""" +import abc +import numpy as np +import six + +from OpenGL.GL import * + +from .utils import format_color_vector +from .texture import Texture +from .constants import SHADOW_TEX_SZ +from .camera import OrthographicCamera, PerspectiveCamera + + + +@six.add_metaclass(abc.ABCMeta) +class Light(object): + """Base class for all light objects. + + Parameters + ---------- + color : (3,) float + RGB value for the light's color in linear space. + intensity : float + Brightness of light. The units that this is defined in depend on the + type of light. Point and spot lights use luminous intensity in candela + (lm/sr), while directional lights use illuminance in lux (lm/m2). + name : str, optional + Name of the light. + """ + def __init__(self, + color=None, + intensity=None, + name=None): + + if color is None: + color = np.ones(3) + if intensity is None: + intensity = 1.0 + + self.name = name + self.color = color + self.intensity = intensity + self._shadow_camera = None + self._shadow_texture = None + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def color(self): + """(3,) float : The light's color. + """ + return self._color + + @color.setter + def color(self, value): + self._color = format_color_vector(value, 3) + + @property + def intensity(self): + """float : The light's intensity in candela or lux. + """ + return self._intensity + + @intensity.setter + def intensity(self, value): + self._intensity = float(value) + + @property + def shadow_texture(self): + """:class:`.Texture` : A texture used to hold shadow maps for this light. 
+ """ + return self._shadow_texture + + @shadow_texture.setter + def shadow_texture(self, value): + if self._shadow_texture is not None: + if self._shadow_texture._in_context(): + self._shadow_texture.delete() + self._shadow_texture = value + + @abc.abstractmethod + def _generate_shadow_texture(self, size=None): + """Generate a shadow texture for this light. + + Parameters + ---------- + size : int, optional + Size of texture map. Must be a positive power of two. + """ + pass + + @abc.abstractmethod + def _get_shadow_camera(self, scene_scale): + """Generate and return a shadow mapping camera for this light. + + Parameters + ---------- + scene_scale : float + Length of scene's bounding box diagonal. + + Returns + ------- + camera : :class:`.Camera` + The camera used to render shadowmaps for this light. + """ + pass + + +class DirectionalLight(Light): + """Directional lights are light sources that act as though they are + infinitely far away and emit light in the direction of the local -z axis. + This light type inherits the orientation of the node that it belongs to; + position and scale are ignored except for their effect on the inherited + node orientation. Because it is at an infinite distance, the light is + not attenuated. Its intensity is defined in lumens per metre squared, + or lux (lm/m2). + + Parameters + ---------- + color : (3,) float, optional + RGB value for the light's color in linear space. Defaults to white + (i.e. [1.0, 1.0, 1.0]). + intensity : float, optional + Brightness of light, in lux (lm/m^2). Defaults to 1.0 + name : str, optional + Name of the light. + """ + + def __init__(self, + color=None, + intensity=None, + name=None): + super(DirectionalLight, self).__init__( + color=color, + intensity=intensity, + name=name, + ) + + def _generate_shadow_texture(self, size=None): + """Generate a shadow texture for this light. + + Parameters + ---------- + size : int, optional + Size of texture map. Must be a positive power of two. + """ + if size is None: + size = SHADOW_TEX_SZ + self.shadow_texture = Texture(width=size, height=size, + source_channels='D', data_format=GL_FLOAT) + + def _get_shadow_camera(self, scene_scale): + """Generate and return a shadow mapping camera for this light. + + Parameters + ---------- + scene_scale : float + Length of scene's bounding box diagonal. + + Returns + ------- + camera : :class:`.Camera` + The camera used to render shadowmaps for this light. + """ + return OrthographicCamera( + znear=0.01 * scene_scale, + zfar=10 * scene_scale, + xmag=scene_scale, + ymag=scene_scale + ) + + +class PointLight(Light): + """Point lights emit light in all directions from their position in space; + rotation and scale are ignored except for their effect on the inherited + node position. The brightness of the light attenuates in a physically + correct manner as distance increases from the light's position (i.e. + brightness goes like the inverse square of the distance). Point light + intensity is defined in candela, which is lumens per square radian (lm/sr). + + Parameters + ---------- + color : (3,) float + RGB value for the light's color in linear space. + intensity : float + Brightness of light in candela (lm/sr). + range : float + Cutoff distance at which light's intensity may be considered to + have reached zero. If None, the range is assumed to be infinite. + name : str, optional + Name of the light. 
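
    Examples
    --------
    A small, hypothetical usage sketch:

    >>> light = PointLight(color=[1.0, 0.9, 0.8], intensity=10.0, range=2.0)
    >>> light.intensity
    10.0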
+ """ + + def __init__(self, + color=None, + intensity=None, + range=None, + name=None): + super(PointLight, self).__init__( + color=color, + intensity=intensity, + name=name, + ) + self.range = range + + @property + def range(self): + """float : The cutoff distance for the light. + """ + return self._range + + @range.setter + def range(self, value): + if value is not None: + value = float(value) + if value <= 0: + raise ValueError('Range must be > 0') + self._range = value + self._range = value + + def _generate_shadow_texture(self, size=None): + """Generate a shadow texture for this light. + + Parameters + ---------- + size : int, optional + Size of texture map. Must be a positive power of two. + """ + raise NotImplementedError('Shadows not implemented for point lights') + + def _get_shadow_camera(self, scene_scale): + """Generate and return a shadow mapping camera for this light. + + Parameters + ---------- + scene_scale : float + Length of scene's bounding box diagonal. + + Returns + ------- + camera : :class:`.Camera` + The camera used to render shadowmaps for this light. + """ + raise NotImplementedError('Shadows not implemented for point lights') + + +class SpotLight(Light): + """Spot lights emit light in a cone in the direction of the local -z axis. + The angle and falloff of the cone is defined using two numbers, the + ``innerConeAngle`` and ``outerConeAngle``. + As with point lights, the brightness + also attenuates in a physically correct manner as distance increases from + the light's position (i.e. brightness goes like the inverse square of the + distance). Spot light intensity refers to the brightness inside the + ``innerConeAngle`` (and at the location of the light) and is defined in + candela, which is lumens per square radian (lm/sr). A spot light's position + and orientation are inherited from its node transform. Inherited scale does + not affect cone shape, and is ignored except for its effect on position + and orientation. + + Parameters + ---------- + color : (3,) float + RGB value for the light's color in linear space. + intensity : float + Brightness of light in candela (lm/sr). + range : float + Cutoff distance at which light's intensity may be considered to + have reached zero. If None, the range is assumed to be infinite. + innerConeAngle : float + Angle, in radians, from centre of spotlight where falloff begins. + Must be greater than or equal to ``0`` and less + than ``outerConeAngle``. Defaults to ``0``. + outerConeAngle : float + Angle, in radians, from centre of spotlight where falloff ends. + Must be greater than ``innerConeAngle`` and less than or equal to + ``PI / 2.0``. Defaults to ``PI / 4.0``. + name : str, optional + Name of the light. + """ + + def __init__(self, + color=None, + intensity=None, + range=None, + innerConeAngle=0.0, + outerConeAngle=(np.pi / 4.0), + name=None): + super(SpotLight, self).__init__( + name=name, + color=color, + intensity=intensity, + ) + self.outerConeAngle = outerConeAngle + self.innerConeAngle = innerConeAngle + self.range = range + + @property + def innerConeAngle(self): + """float : The inner cone angle in radians. + """ + return self._innerConeAngle + + @innerConeAngle.setter + def innerConeAngle(self, value): + if value < 0.0 or value > self.outerConeAngle: + raise ValueError('Invalid value for inner cone angle') + self._innerConeAngle = float(value) + + @property + def outerConeAngle(self): + """float : The outer cone angle in radians. 
+ """ + return self._outerConeAngle + + @outerConeAngle.setter + def outerConeAngle(self, value): + if value < 0.0 or value > np.pi / 2.0 + 1e-9: + raise ValueError('Invalid value for outer cone angle') + self._outerConeAngle = float(value) + + @property + def range(self): + """float : The cutoff distance for the light. + """ + return self._range + + @range.setter + def range(self, value): + if value is not None: + value = float(value) + if value <= 0: + raise ValueError('Range must be > 0') + self._range = value + self._range = value + + def _generate_shadow_texture(self, size=None): + """Generate a shadow texture for this light. + + Parameters + ---------- + size : int, optional + Size of texture map. Must be a positive power of two. + """ + if size is None: + size = SHADOW_TEX_SZ + self.shadow_texture = Texture(width=size, height=size, + source_channels='D', data_format=GL_FLOAT) + + def _get_shadow_camera(self, scene_scale): + """Generate and return a shadow mapping camera for this light. + + Parameters + ---------- + scene_scale : float + Length of scene's bounding box diagonal. + + Returns + ------- + camera : :class:`.Camera` + The camera used to render shadowmaps for this light. + """ + return PerspectiveCamera( + znear=0.01 * scene_scale, + zfar=10 * scene_scale, + yfov=np.clip(2 * self.outerConeAngle + np.pi / 16.0, 0.0, np.pi), + aspectRatio=1.0 + ) + + +__all__ = ['Light', 'DirectionalLight', 'SpotLight', 'PointLight'] diff --git a/pyrender/pyrender/material.py b/pyrender/pyrender/material.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce9c2d184ed213c84b015e36bea558cd1efc6b7 --- /dev/null +++ b/pyrender/pyrender/material.py @@ -0,0 +1,707 @@ +"""Material properties, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-material +and +https://github.com/KhronosGroup/glTF/tree/master/extensions/2.0/Khronos/KHR_materials_pbrSpecularGlossiness + +Author: Matthew Matl +""" +import abc +import numpy as np +import six + +from .constants import TexFlags +from .utils import format_color_vector, format_texture_source +from .texture import Texture + + +@six.add_metaclass(abc.ABCMeta) +class Material(object): + """Base for standard glTF 2.0 materials. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + normalTexture : (n,n,3) float or :class:`Texture`, optional + A tangent space normal map. The texture contains RGB components in + linear space. Each texel represents the XYZ components of a normal + vector in tangent space. Red [0 to 255] maps to X [-1 to 1]. Green + [0 to 255] maps to Y [-1 to 1]. Blue [128 to 255] maps to Z + [1/255 to 1]. The normal vectors use OpenGL conventions where +X is + right and +Y is up. +Z points toward the viewer. + occlusionTexture : (n,n,1) float or :class:`Texture`, optional + The occlusion map texture. The occlusion values are sampled from the R + channel. Higher values indicate areas that should receive full indirect + lighting and lower values indicate no indirect lighting. These values + are linear. If other channels are present (GBA), they are ignored for + occlusion calculations. + emissiveTexture : (n,n,3) float or :class:`Texture`, optional + The emissive map controls the color and intensity of the light being + emitted by the material. This texture contains RGB components in sRGB + color space. If a fourth component (A) is present, it is ignored. 
+ emissiveFactor : (3,) float, optional + The RGB components of the emissive color of the material. These values + are linear. If an emissiveTexture is specified, this value is + multiplied with the texel values. + alphaMode : str, optional + The material's alpha rendering mode enumeration specifying the + interpretation of the alpha value of the main factor and texture. + Allowed Values: + + - `"OPAQUE"` The alpha value is ignored and the rendered output is + fully opaque. + - `"MASK"` The rendered output is either fully opaque or fully + transparent depending on the alpha value and the specified alpha + cutoff value. + - `"BLEND"` The alpha value is used to composite the source and + destination areas. The rendered output is combined with the + background using the normal painting operation (i.e. the Porter + and Duff over operator). + + alphaCutoff : float, optional + Specifies the cutoff threshold when in MASK mode. If the alpha value is + greater than or equal to this value then it is rendered as fully + opaque, otherwise, it is rendered as fully transparent. + A value greater than 1.0 will render the entire material as fully + transparent. This value is ignored for other modes. + doubleSided : bool, optional + Specifies whether the material is double sided. When this value is + false, back-face culling is enabled. When this value is true, + back-face culling is disabled and double sided lighting is enabled. + smooth : bool, optional + If True, the material is rendered smoothly by using only one normal + per vertex and face indexing. + wireframe : bool, optional + If True, the material is rendered in wireframe mode. + """ + + def __init__(self, + name=None, + normalTexture=None, + occlusionTexture=None, + emissiveTexture=None, + emissiveFactor=None, + alphaMode=None, + alphaCutoff=None, + doubleSided=False, + smooth=True, + wireframe=False): + + # Set defaults + if alphaMode is None: + alphaMode = 'OPAQUE' + + if alphaCutoff is None: + alphaCutoff = 0.5 + + if emissiveFactor is None: + emissiveFactor = np.zeros(3).astype(np.float32) + + self.name = name + self.normalTexture = normalTexture + self.occlusionTexture = occlusionTexture + self.emissiveTexture = emissiveTexture + self.emissiveFactor = emissiveFactor + self.alphaMode = alphaMode + self.alphaCutoff = alphaCutoff + self.doubleSided = doubleSided + self.smooth = smooth + self.wireframe = wireframe + + self._tex_flags = None + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def normalTexture(self): + """(n,n,3) float or :class:`Texture` : The tangent-space normal map. + """ + return self._normalTexture + + @normalTexture.setter + def normalTexture(self, value): + # TODO TMP + self._normalTexture = self._format_texture(value, 'RGB') + self._tex_flags = None + + @property + def occlusionTexture(self): + """(n,n,1) float or :class:`Texture` : The ambient occlusion map. + """ + return self._occlusionTexture + + @occlusionTexture.setter + def occlusionTexture(self, value): + self._occlusionTexture = self._format_texture(value, 'R') + self._tex_flags = None + + @property + def emissiveTexture(self): + """(n,n,3) float or :class:`Texture` : The emission map. 
+ """ + return self._emissiveTexture + + @emissiveTexture.setter + def emissiveTexture(self, value): + self._emissiveTexture = self._format_texture(value, 'RGB') + self._tex_flags = None + + @property + def emissiveFactor(self): + """(3,) float : Base multiplier for emission colors. + """ + return self._emissiveFactor + + @emissiveFactor.setter + def emissiveFactor(self, value): + if value is None: + value = np.zeros(3) + self._emissiveFactor = format_color_vector(value, 3) + + @property + def alphaMode(self): + """str : The mode for blending. + """ + return self._alphaMode + + @alphaMode.setter + def alphaMode(self, value): + if value not in set(['OPAQUE', 'MASK', 'BLEND']): + raise ValueError('Invalid alpha mode {}'.format(value)) + self._alphaMode = value + + @property + def alphaCutoff(self): + """float : The cutoff threshold in MASK mode. + """ + return self._alphaCutoff + + @alphaCutoff.setter + def alphaCutoff(self, value): + if value < 0 or value > 1: + raise ValueError('Alpha cutoff must be in range [0,1]') + self._alphaCutoff = float(value) + + @property + def doubleSided(self): + """bool : Whether the material is double-sided. + """ + return self._doubleSided + + @doubleSided.setter + def doubleSided(self, value): + if not isinstance(value, bool): + raise TypeError('Double sided must be a boolean value') + self._doubleSided = value + + @property + def smooth(self): + """bool : Whether to render the mesh smoothly by + interpolating vertex normals. + """ + return self._smooth + + @smooth.setter + def smooth(self, value): + if not isinstance(value, bool): + raise TypeError('Double sided must be a boolean value') + self._smooth = value + + @property + def wireframe(self): + """bool : Whether to render the mesh in wireframe mode. + """ + return self._wireframe + + @wireframe.setter + def wireframe(self, value): + if not isinstance(value, bool): + raise TypeError('Wireframe must be a boolean value') + self._wireframe = value + + @property + def is_transparent(self): + """bool : If True, the object is partially transparent. + """ + return self._compute_transparency() + + @property + def tex_flags(self): + """int : Texture availability flags. + """ + if self._tex_flags is None: + self._tex_flags = self._compute_tex_flags() + return self._tex_flags + + @property + def textures(self): + """list of :class:`Texture` : The textures associated with this + material. + """ + return self._compute_textures() + + def _compute_transparency(self): + return False + + def _compute_tex_flags(self): + tex_flags = TexFlags.NONE + if self.normalTexture is not None: + tex_flags |= TexFlags.NORMAL + if self.occlusionTexture is not None: + tex_flags |= TexFlags.OCCLUSION + if self.emissiveTexture is not None: + tex_flags |= TexFlags.EMISSIVE + return tex_flags + + def _compute_textures(self): + all_textures = [ + self.normalTexture, self.occlusionTexture, self.emissiveTexture + ] + textures = set([t for t in all_textures if t is not None]) + return textures + + def _format_texture(self, texture, target_channels='RGB'): + """Format a texture as a float32 np array. + """ + if isinstance(texture, Texture) or texture is None: + return texture + else: + source = format_texture_source(texture, target_channels) + return Texture(source=source, source_channels=target_channels) + + +class MetallicRoughnessMaterial(Material): + """A material based on the metallic-roughness material model from + Physically-Based Rendering (PBR) methodology. 
+ + Parameters + ---------- + name : str, optional + The user-defined name of this object. + normalTexture : (n,n,3) float or :class:`Texture`, optional + A tangent space normal map. The texture contains RGB components in + linear space. Each texel represents the XYZ components of a normal + vector in tangent space. Red [0 to 255] maps to X [-1 to 1]. Green + [0 to 255] maps to Y [-1 to 1]. Blue [128 to 255] maps to Z + [1/255 to 1]. The normal vectors use OpenGL conventions where +X is + right and +Y is up. +Z points toward the viewer. + occlusionTexture : (n,n,1) float or :class:`Texture`, optional + The occlusion map texture. The occlusion values are sampled from the R + channel. Higher values indicate areas that should receive full indirect + lighting and lower values indicate no indirect lighting. These values + are linear. If other channels are present (GBA), they are ignored for + occlusion calculations. + emissiveTexture : (n,n,3) float or :class:`Texture`, optional + The emissive map controls the color and intensity of the light being + emitted by the material. This texture contains RGB components in sRGB + color space. If a fourth component (A) is present, it is ignored. + emissiveFactor : (3,) float, optional + The RGB components of the emissive color of the material. These values + are linear. If an emissiveTexture is specified, this value is + multiplied with the texel values. + alphaMode : str, optional + The material's alpha rendering mode enumeration specifying the + interpretation of the alpha value of the main factor and texture. + Allowed Values: + + - `"OPAQUE"` The alpha value is ignored and the rendered output is + fully opaque. + - `"MASK"` The rendered output is either fully opaque or fully + transparent depending on the alpha value and the specified alpha + cutoff value. + - `"BLEND"` The alpha value is used to composite the source and + destination areas. The rendered output is combined with the + background using the normal painting operation (i.e. the Porter + and Duff over operator). + + alphaCutoff : float, optional + Specifies the cutoff threshold when in MASK mode. If the alpha value is + greater than or equal to this value then it is rendered as fully + opaque, otherwise, it is rendered as fully transparent. + A value greater than 1.0 will render the entire material as fully + transparent. This value is ignored for other modes. + doubleSided : bool, optional + Specifies whether the material is double sided. When this value is + false, back-face culling is enabled. When this value is true, + back-face culling is disabled and double sided lighting is enabled. + smooth : bool, optional + If True, the material is rendered smoothly by using only one normal + per vertex and face indexing. + wireframe : bool, optional + If True, the material is rendered in wireframe mode. + baseColorFactor : (4,) float, optional + The RGBA components of the base color of the material. The fourth + component (A) is the alpha coverage of the material. The alphaMode + property specifies how alpha is interpreted. These values are linear. + If a baseColorTexture is specified, this value is multiplied with the + texel values. + baseColorTexture : (n,n,4) float or :class:`Texture`, optional + The base color texture. This texture contains RGB(A) components in sRGB + color space. The first three components (RGB) specify the base color of + the material. If the fourth component (A) is present, it represents the + alpha coverage of the material. Otherwise, an alpha of 1.0 is assumed. 
+ The alphaMode property specifies how alpha is interpreted. + The stored texels must not be premultiplied. + metallicFactor : float + The metalness of the material. A value of 1.0 means the material is a + metal. A value of 0.0 means the material is a dielectric. Values in + between are for blending between metals and dielectrics such as dirty + metallic surfaces. This value is linear. If a metallicRoughnessTexture + is specified, this value is multiplied with the metallic texel values. + roughnessFactor : float + The roughness of the material. A value of 1.0 means the material is + completely rough. A value of 0.0 means the material is completely + smooth. This value is linear. If a metallicRoughnessTexture is + specified, this value is multiplied with the roughness texel values. + metallicRoughnessTexture : (n,n,2) float or :class:`Texture`, optional + The metallic-roughness texture. The metalness values are sampled from + the B channel. The roughness values are sampled from the G channel. + These values are linear. If other channels are present (R or A), they + are ignored for metallic-roughness calculations. + """ + + def __init__(self, + name=None, + normalTexture=None, + occlusionTexture=None, + emissiveTexture=None, + emissiveFactor=None, + alphaMode=None, + alphaCutoff=None, + doubleSided=False, + smooth=True, + wireframe=False, + baseColorFactor=None, + baseColorTexture=None, + metallicFactor=1.0, + roughnessFactor=1.0, + metallicRoughnessTexture=None): + super(MetallicRoughnessMaterial, self).__init__( + name=name, + normalTexture=normalTexture, + occlusionTexture=occlusionTexture, + emissiveTexture=emissiveTexture, + emissiveFactor=emissiveFactor, + alphaMode=alphaMode, + alphaCutoff=alphaCutoff, + doubleSided=doubleSided, + smooth=smooth, + wireframe=wireframe + ) + + # Set defaults + if baseColorFactor is None: + baseColorFactor = np.ones(4).astype(np.float32) + + self.baseColorFactor = baseColorFactor + self.baseColorTexture = baseColorTexture + self.metallicFactor = metallicFactor + self.roughnessFactor = roughnessFactor + self.metallicRoughnessTexture = metallicRoughnessTexture + + @property + def baseColorFactor(self): + """(4,) float or :class:`Texture` : The RGBA base color multiplier. + """ + return self._baseColorFactor + + @baseColorFactor.setter + def baseColorFactor(self, value): + if value is None: + value = np.ones(4) + self._baseColorFactor = format_color_vector(value, 4) + + @property + def baseColorTexture(self): + """(n,n,4) float or :class:`Texture` : The diffuse texture. + """ + return self._baseColorTexture + + @baseColorTexture.setter + def baseColorTexture(self, value): + self._baseColorTexture = self._format_texture(value, 'RGBA') + self._tex_flags = None + + @property + def metallicFactor(self): + """float : The metalness of the material. + """ + return self._metallicFactor + + @metallicFactor.setter + def metallicFactor(self, value): + if value is None: + value = 1.0 + if value < 0 or value > 1: + raise ValueError('Metallic factor must be in range [0,1]') + self._metallicFactor = float(value) + + @property + def roughnessFactor(self): + """float : The roughness of the material. 
+ """ + return self.RoughnessFactor + + @roughnessFactor.setter + def roughnessFactor(self, value): + if value is None: + value = 1.0 + if value < 0 or value > 1: + raise ValueError('Roughness factor must be in range [0,1]') + self.RoughnessFactor = float(value) + + @property + def metallicRoughnessTexture(self): + """(n,n,2) float or :class:`Texture` : The metallic-roughness texture. + """ + return self._metallicRoughnessTexture + + @metallicRoughnessTexture.setter + def metallicRoughnessTexture(self, value): + self._metallicRoughnessTexture = self._format_texture(value, 'GB') + self._tex_flags = None + + def _compute_tex_flags(self): + tex_flags = super(MetallicRoughnessMaterial, self)._compute_tex_flags() + if self.baseColorTexture is not None: + tex_flags |= TexFlags.BASE_COLOR + if self.metallicRoughnessTexture is not None: + tex_flags |= TexFlags.METALLIC_ROUGHNESS + return tex_flags + + def _compute_transparency(self): + if self.alphaMode == 'OPAQUE': + return False + cutoff = self.alphaCutoff + if self.alphaMode == 'BLEND': + cutoff = 1.0 + if self.baseColorFactor[3] < cutoff: + return True + if (self.baseColorTexture is not None and + self.baseColorTexture.is_transparent(cutoff)): + return True + return False + + def _compute_textures(self): + textures = super(MetallicRoughnessMaterial, self)._compute_textures() + all_textures = [self.baseColorTexture, self.metallicRoughnessTexture] + all_textures = {t for t in all_textures if t is not None} + textures |= all_textures + return textures + + +class SpecularGlossinessMaterial(Material): + """A material based on the specular-glossiness material model from + Physically-Based Rendering (PBR) methodology. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + normalTexture : (n,n,3) float or :class:`Texture`, optional + A tangent space normal map. The texture contains RGB components in + linear space. Each texel represents the XYZ components of a normal + vector in tangent space. Red [0 to 255] maps to X [-1 to 1]. Green + [0 to 255] maps to Y [-1 to 1]. Blue [128 to 255] maps to Z + [1/255 to 1]. The normal vectors use OpenGL conventions where +X is + right and +Y is up. +Z points toward the viewer. + occlusionTexture : (n,n,1) float or :class:`Texture`, optional + The occlusion map texture. The occlusion values are sampled from the R + channel. Higher values indicate areas that should receive full indirect + lighting and lower values indicate no indirect lighting. These values + are linear. If other channels are present (GBA), they are ignored for + occlusion calculations. + emissiveTexture : (n,n,3) float or :class:`Texture`, optional + The emissive map controls the color and intensity of the light being + emitted by the material. This texture contains RGB components in sRGB + color space. If a fourth component (A) is present, it is ignored. + emissiveFactor : (3,) float, optional + The RGB components of the emissive color of the material. These values + are linear. If an emissiveTexture is specified, this value is + multiplied with the texel values. + alphaMode : str, optional + The material's alpha rendering mode enumeration specifying the + interpretation of the alpha value of the main factor and texture. + Allowed Values: + + - `"OPAQUE"` The alpha value is ignored and the rendered output is + fully opaque. + - `"MASK"` The rendered output is either fully opaque or fully + transparent depending on the alpha value and the specified alpha + cutoff value. 
+ - `"BLEND"` The alpha value is used to composite the source and + destination areas. The rendered output is combined with the + background using the normal painting operation (i.e. the Porter + and Duff over operator). + + alphaCutoff : float, optional + Specifies the cutoff threshold when in MASK mode. If the alpha value is + greater than or equal to this value then it is rendered as fully + opaque, otherwise, it is rendered as fully transparent. + A value greater than 1.0 will render the entire material as fully + transparent. This value is ignored for other modes. + doubleSided : bool, optional + Specifies whether the material is double sided. When this value is + false, back-face culling is enabled. When this value is true, + back-face culling is disabled and double sided lighting is enabled. + smooth : bool, optional + If True, the material is rendered smoothly by using only one normal + per vertex and face indexing. + wireframe : bool, optional + If True, the material is rendered in wireframe mode. + diffuseFactor : (4,) float + The RGBA components of the reflected diffuse color of the material. + Metals have a diffuse value of [0.0, 0.0, 0.0]. The fourth component + (A) is the opacity of the material. The values are linear. + diffuseTexture : (n,n,4) float or :class:`Texture`, optional + The diffuse texture. This texture contains RGB(A) components of the + reflected diffuse color of the material in sRGB color space. If the + fourth component (A) is present, it represents the alpha coverage of + the material. Otherwise, an alpha of 1.0 is assumed. + The alphaMode property specifies how alpha is interpreted. + The stored texels must not be premultiplied. + specularFactor : (3,) float + The specular RGB color of the material. This value is linear. + glossinessFactor : float + The glossiness or smoothness of the material. A value of 1.0 means the + material has full glossiness or is perfectly smooth. A value of 0.0 + means the material has no glossiness or is perfectly rough. This value + is linear. + specularGlossinessTexture : (n,n,4) or :class:`Texture`, optional + The specular-glossiness texture is a RGBA texture, containing the + specular color (RGB) in sRGB space and the glossiness value (A) in + linear space. + """ + + def __init__(self, + name=None, + normalTexture=None, + occlusionTexture=None, + emissiveTexture=None, + emissiveFactor=None, + alphaMode=None, + alphaCutoff=None, + doubleSided=False, + smooth=True, + wireframe=False, + diffuseFactor=None, + diffuseTexture=None, + specularFactor=None, + glossinessFactor=1.0, + specularGlossinessTexture=None): + super(SpecularGlossinessMaterial, self).__init__( + name=name, + normalTexture=normalTexture, + occlusionTexture=occlusionTexture, + emissiveTexture=emissiveTexture, + emissiveFactor=emissiveFactor, + alphaMode=alphaMode, + alphaCutoff=alphaCutoff, + doubleSided=doubleSided, + smooth=smooth, + wireframe=wireframe + ) + + # Set defaults + if diffuseFactor is None: + diffuseFactor = np.ones(4).astype(np.float32) + if specularFactor is None: + specularFactor = np.ones(3).astype(np.float32) + + self.diffuseFactor = diffuseFactor + self.diffuseTexture = diffuseTexture + self.specularFactor = specularFactor + self.glossinessFactor = glossinessFactor + self.specularGlossinessTexture = specularGlossinessTexture + + @property + def diffuseFactor(self): + """(4,) float : The diffuse base color. 
+ """ + return self._diffuseFactor + + @diffuseFactor.setter + def diffuseFactor(self, value): + self._diffuseFactor = format_color_vector(value, 4) + + @property + def diffuseTexture(self): + """(n,n,4) float or :class:`Texture` : The diffuse map. + """ + return self._diffuseTexture + + @diffuseTexture.setter + def diffuseTexture(self, value): + self._diffuseTexture = self._format_texture(value, 'RGBA') + self._tex_flags = None + + @property + def specularFactor(self): + """(3,) float : The specular color of the material. + """ + return self._specularFactor + + @specularFactor.setter + def specularFactor(self, value): + self._specularFactor = format_color_vector(value, 3) + + @property + def glossinessFactor(self): + """float : The glossiness of the material. + """ + return self.glossinessFactor + + @glossinessFactor.setter + def glossinessFactor(self, value): + if value < 0 or value > 1: + raise ValueError('glossiness factor must be in range [0,1]') + self._glossinessFactor = float(value) + + @property + def specularGlossinessTexture(self): + """(n,n,4) or :class:`Texture` : The specular-glossiness texture. + """ + return self._specularGlossinessTexture + + @specularGlossinessTexture.setter + def specularGlossinessTexture(self, value): + self._specularGlossinessTexture = self._format_texture(value, 'GB') + self._tex_flags = None + + def _compute_tex_flags(self): + flags = super(SpecularGlossinessMaterial, self)._compute_tex_flags() + if self.diffuseTexture is not None: + flags |= TexFlags.DIFFUSE + if self.specularGlossinessTexture is not None: + flags |= TexFlags.SPECULAR_GLOSSINESS + return flags + + def _compute_transparency(self): + if self.alphaMode == 'OPAQUE': + return False + cutoff = self.alphaCutoff + if self.alphaMode == 'BLEND': + cutoff = 1.0 + if self.diffuseFactor[3] < cutoff: + return True + if (self.diffuseTexture is not None and + self.diffuseTexture.is_transparent(cutoff)): + return True + return False + + def _compute_textures(self): + textures = super(SpecularGlossinessMaterial, self)._compute_textures() + all_textures = [self.diffuseTexture, self.specularGlossinessTexture] + all_textures = {t for t in all_textures if t is not None} + textures |= all_textures + return textures diff --git a/pyrender/pyrender/mesh.py b/pyrender/pyrender/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..36833ea3dfa6c095a18fc745ff34cf106e83c95d --- /dev/null +++ b/pyrender/pyrender/mesh.py @@ -0,0 +1,328 @@ +"""Meshes, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-mesh + +Author: Matthew Matl +""" +import copy + +import numpy as np +import trimesh + +from .primitive import Primitive +from .constants import GLTF +from .material import MetallicRoughnessMaterial + + +class Mesh(object): + """A set of primitives to be rendered. + + Parameters + ---------- + name : str + The user-defined name of this object. + primitives : list of :class:`Primitive` + The primitives associated with this mesh. + weights : (k,) float + Array of weights to be applied to the Morph Targets. + is_visible : bool + If False, the mesh will not be rendered. + """ + + def __init__(self, primitives, name=None, weights=None, is_visible=True): + self.primitives = primitives + self.name = name + self.weights = weights + self.is_visible = is_visible + + self._bounds = None + + @property + def name(self): + """str : The user-defined name of this object. 
+ """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def primitives(self): + """list of :class:`Primitive` : The primitives associated + with this mesh. + """ + return self._primitives + + @primitives.setter + def primitives(self, value): + self._primitives = value + + @property + def weights(self): + """(k,) float : Weights to be applied to morph targets. + """ + return self._weights + + @weights.setter + def weights(self, value): + self._weights = value + + @property + def is_visible(self): + """bool : Whether the mesh is visible. + """ + return self._is_visible + + @is_visible.setter + def is_visible(self, value): + self._is_visible = value + + @property + def bounds(self): + """(2,3) float : The axis-aligned bounds of the mesh. + """ + if self._bounds is None: + bounds = np.array([[np.infty, np.infty, np.infty], + [-np.infty, -np.infty, -np.infty]]) + for p in self.primitives: + bounds[0] = np.minimum(bounds[0], p.bounds[0]) + bounds[1] = np.maximum(bounds[1], p.bounds[1]) + self._bounds = bounds + return self._bounds + + @property + def centroid(self): + """(3,) float : The centroid of the mesh's axis-aligned bounding box + (AABB). + """ + return np.mean(self.bounds, axis=0) + + @property + def extents(self): + """(3,) float : The lengths of the axes of the mesh's AABB. + """ + return np.diff(self.bounds, axis=0).reshape(-1) + + @property + def scale(self): + """(3,) float : The length of the diagonal of the mesh's AABB. + """ + return np.linalg.norm(self.extents) + + @property + def is_transparent(self): + """bool : If True, the mesh is partially-transparent. + """ + for p in self.primitives: + if p.is_transparent: + return True + return False + + @staticmethod + def from_points(points, colors=None, normals=None, + is_visible=True, poses=None): + """Create a Mesh from a set of points. + + Parameters + ---------- + points : (n,3) float + The point positions. + colors : (n,3) or (n,4) float, optional + RGB or RGBA colors for each point. + normals : (n,3) float, optionals + The normal vectors for each point. + is_visible : bool + If False, the points will not be rendered. + poses : (x,4,4) + Array of 4x4 transformation matrices for instancing this object. + + Returns + ------- + mesh : :class:`Mesh` + The created mesh. + """ + primitive = Primitive( + positions=points, + normals=normals, + color_0=colors, + mode=GLTF.POINTS, + poses=poses + ) + mesh = Mesh(primitives=[primitive], is_visible=is_visible) + return mesh + + @staticmethod + def from_trimesh(mesh, material=None, is_visible=True, + poses=None, wireframe=False, smooth=True): + """Create a Mesh from a :class:`~trimesh.base.Trimesh`. + + Parameters + ---------- + mesh : :class:`~trimesh.base.Trimesh` or list of them + A triangular mesh or a list of meshes. + material : :class:`Material` + The material of the object. Overrides any mesh material. + If not specified and the mesh has no material, a default material + will be used. + is_visible : bool + If False, the mesh will not be rendered. + poses : (n,4,4) float + Array of 4x4 transformation matrices for instancing this object. + wireframe : bool + If `True`, the mesh will be rendered as a wireframe object + smooth : bool + If `True`, the mesh will be rendered with interpolated vertex + normals. Otherwise, the mesh edges will stay sharp. + + Returns + ------- + mesh : :class:`Mesh` + The created mesh. 
+ """ + + if isinstance(mesh, (list, tuple, set, np.ndarray)): + meshes = list(mesh) + elif isinstance(mesh, trimesh.Trimesh): + meshes = [mesh] + else: + raise TypeError('Expected a Trimesh or a list, got a {}' + .format(type(mesh))) + + primitives = [] + for m in meshes: + positions = None + normals = None + indices = None + + # Compute positions, normals, and indices + if smooth: + positions = m.vertices.copy() + normals = m.vertex_normals.copy() + indices = m.faces.copy() + else: + positions = m.vertices[m.faces].reshape((3 * len(m.faces), 3)) + normals = np.repeat(m.face_normals, 3, axis=0) + + # Compute colors, texture coords, and material properties + color_0, texcoord_0, primitive_material = Mesh._get_trimesh_props(m, smooth=smooth, material=material) + + # Override if material is given. + if material is not None: + #primitive_material = copy.copy(material) + primitive_material = copy.deepcopy(material) # TODO + + if primitive_material is None: + # Replace material with default if needed + primitive_material = MetallicRoughnessMaterial( + alphaMode='BLEND', + baseColorFactor=[0.3, 0.3, 0.3, 1.0], + metallicFactor=0.2, + roughnessFactor=0.8 + ) + + primitive_material.wireframe = wireframe + + # Create the primitive + primitives.append(Primitive( + positions=positions, + normals=normals, + texcoord_0=texcoord_0, + color_0=color_0, + indices=indices, + material=primitive_material, + mode=GLTF.TRIANGLES, + poses=poses + )) + + return Mesh(primitives=primitives, is_visible=is_visible) + + @staticmethod + def _get_trimesh_props(mesh, smooth=False, material=None): + """Gets the vertex colors, texture coordinates, and material properties + from a :class:`~trimesh.base.Trimesh`. + """ + colors = None + texcoords = None + + # If the trimesh visual is undefined, return none for both + if not mesh.visual.defined: + return colors, texcoords, material + + # Process vertex colors + if material is None: + if mesh.visual.kind == 'vertex': + vc = mesh.visual.vertex_colors.copy() + if smooth: + colors = vc + else: + colors = vc[mesh.faces].reshape( + (3 * len(mesh.faces), vc.shape[1]) + ) + material = MetallicRoughnessMaterial( + alphaMode='BLEND', + baseColorFactor=[1.0, 1.0, 1.0, 1.0], + metallicFactor=0.2, + roughnessFactor=0.8 + ) + # Process face colors + elif mesh.visual.kind == 'face': + if smooth: + raise ValueError('Cannot use face colors with a smooth mesh') + else: + colors = np.repeat(mesh.visual.face_colors, 3, axis=0) + + material = MetallicRoughnessMaterial( + alphaMode='BLEND', + baseColorFactor=[1.0, 1.0, 1.0, 1.0], + metallicFactor=0.2, + roughnessFactor=0.8 + ) + + # Process texture colors + if mesh.visual.kind == 'texture': + # Configure UV coordinates + if mesh.visual.uv is not None and len(mesh.visual.uv) != 0: + uv = mesh.visual.uv.copy() + if smooth: + texcoords = uv + else: + texcoords = uv[mesh.faces].reshape( + (3 * len(mesh.faces), uv.shape[1]) + ) + + if material is None: + # Configure mesh material + mat = mesh.visual.material + + if isinstance(mat, trimesh.visual.texture.PBRMaterial): + material = MetallicRoughnessMaterial( + normalTexture=mat.normalTexture, + occlusionTexture=mat.occlusionTexture, + emissiveTexture=mat.emissiveTexture, + emissiveFactor=mat.emissiveFactor, + alphaMode='BLEND', + baseColorFactor=mat.baseColorFactor, + baseColorTexture=mat.baseColorTexture, + metallicFactor=mat.metallicFactor, + roughnessFactor=mat.roughnessFactor, + metallicRoughnessTexture=mat.metallicRoughnessTexture, + doubleSided=mat.doubleSided, + alphaCutoff=mat.alphaCutoff + ) + 
elif isinstance(mat, trimesh.visual.texture.SimpleMaterial): + glossiness = mat.kwargs.get('Ns', 1.0) + if isinstance(glossiness, list): + glossiness = float(glossiness[0]) + roughness = (2 / (glossiness + 2)) ** (1.0 / 4.0) + material = MetallicRoughnessMaterial( + alphaMode='BLEND', + roughnessFactor=roughness, + baseColorFactor=mat.diffuse, + baseColorTexture=mat.image, + ) + elif isinstance(mat, MetallicRoughnessMaterial): + material = mat + + return colors, texcoords, material diff --git a/pyrender/pyrender/node.py b/pyrender/pyrender/node.py new file mode 100644 index 0000000000000000000000000000000000000000..1f37f7856cc732a37dc58253022a7c331489493e --- /dev/null +++ b/pyrender/pyrender/node.py @@ -0,0 +1,263 @@ +"""Nodes, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-node + +Author: Matthew Matl +""" +import numpy as np + +import trimesh.transformations as transformations + +from .camera import Camera +from .mesh import Mesh +from .light import Light + + +class Node(object): + """A node in the node hierarchy. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + camera : :class:`Camera`, optional + The camera in this node. + children : list of :class:`Node` + The children of this node. + skin : int, optional + The index of the skin referenced by this node. + matrix : (4,4) float, optional + A floating-point 4x4 transformation matrix. + mesh : :class:`Mesh`, optional + The mesh in this node. + rotation : (4,) float, optional + The node's unit quaternion in the order (x, y, z, w), where + w is the scalar. + scale : (3,) float, optional + The node's non-uniform scale, given as the scaling factors along the x, + y, and z axes. + translation : (3,) float, optional + The node's translation along the x, y, and z axes. + weights : (n,) float + The weights of the instantiated Morph Target. Number of elements must + match number of Morph Targets of used mesh. + light : :class:`Light`, optional + The light in this node. + """ + + def __init__(self, + name=None, + camera=None, + children=None, + skin=None, + matrix=None, + mesh=None, + rotation=None, + scale=None, + translation=None, + weights=None, + light=None): + # Set defaults + if children is None: + children = [] + + self._matrix = None + self._scale = None + self._rotation = None + self._translation = None + if matrix is None: + if rotation is None: + rotation = np.array([0.0, 0.0, 0.0, 1.0]) + if translation is None: + translation = np.zeros(3) + if scale is None: + scale = np.ones(3) + self.rotation = rotation + self.translation = translation + self.scale = scale + else: + self.matrix = matrix + + self.name = name + self.camera = camera + self.children = children + self.skin = skin + self.mesh = mesh + self.weights = weights + self.light = light + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def camera(self): + """:class:`Camera` : The camera in this node. + """ + return self._camera + + @camera.setter + def camera(self, value): + if value is not None and not isinstance(value, Camera): + raise TypeError('Value must be a camera') + self._camera = value + + @property + def children(self): + """list of :class:`Node` : The children of this node. 
+ """ + return self._children + + @children.setter + def children(self, value): + self._children = value + + @property + def skin(self): + """int : The skin index for this node. + """ + return self._skin + + @skin.setter + def skin(self, value): + self._skin = value + + @property + def mesh(self): + """:class:`Mesh` : The mesh in this node. + """ + return self._mesh + + @mesh.setter + def mesh(self, value): + if value is not None and not isinstance(value, Mesh): + raise TypeError('Value must be a mesh') + self._mesh = value + + @property + def light(self): + """:class:`Light` : The light in this node. + """ + return self._light + + @light.setter + def light(self, value): + if value is not None and not isinstance(value, Light): + raise TypeError('Value must be a light') + self._light = value + + @property + def rotation(self): + """(4,) float : The xyzw quaternion for this node. + """ + return self._rotation + + @rotation.setter + def rotation(self, value): + value = np.asanyarray(value) + if value.shape != (4,): + raise ValueError('Quaternion must be a (4,) vector') + if np.abs(np.linalg.norm(value) - 1.0) > 1e-3: + raise ValueError('Quaternion must have norm == 1.0') + self._rotation = value + self._matrix = None + + @property + def translation(self): + """(3,) float : The translation for this node. + """ + return self._translation + + @translation.setter + def translation(self, value): + value = np.asanyarray(value) + if value.shape != (3,): + raise ValueError('Translation must be a (3,) vector') + self._translation = value + self._matrix = None + + @property + def scale(self): + """(3,) float : The scale for this node. + """ + return self._scale + + @scale.setter + def scale(self, value): + value = np.asanyarray(value) + if value.shape != (3,): + raise ValueError('Scale must be a (3,) vector') + self._scale = value + self._matrix = None + + @property + def matrix(self): + """(4,4) float : The homogenous transform matrix for this node. + + Note that this matrix's elements are not settable, + it's just a copy of the internal matrix. You can set the whole + matrix, but not an individual element. 
+ """ + if self._matrix is None: + self._matrix = self._m_from_tqs( + self.translation, self.rotation, self.scale + ) + return self._matrix.copy() + + @matrix.setter + def matrix(self, value): + value = np.asanyarray(value) + if value.shape != (4,4): + raise ValueError('Matrix must be a 4x4 numpy ndarray') + if not np.allclose(value[3,:], np.array([0.0, 0.0, 0.0, 1.0])): + raise ValueError('Bottom row of matrix must be [0,0,0,1]') + self.rotation = Node._q_from_m(value) + self.scale = Node._s_from_m(value) + self.translation = Node._t_from_m(value) + self._matrix = value + + @staticmethod + def _t_from_m(m): + return m[:3,3] + + @staticmethod + def _r_from_m(m): + U = m[:3,:3] + norms = np.linalg.norm(U.T, axis=1) + return U / norms + + @staticmethod + def _q_from_m(m): + M = np.eye(4) + M[:3,:3] = Node._r_from_m(m) + q_wxyz = transformations.quaternion_from_matrix(M) + return np.roll(q_wxyz, -1) + + @staticmethod + def _s_from_m(m): + return np.linalg.norm(m[:3,:3].T, axis=1) + + @staticmethod + def _r_from_q(q): + q_wxyz = np.roll(q, 1) + return transformations.quaternion_matrix(q_wxyz)[:3,:3] + + @staticmethod + def _m_from_tqs(t, q, s): + S = np.eye(4) + S[:3,:3] = np.diag(s) + + R = np.eye(4) + R[:3,:3] = Node._r_from_q(q) + + T = np.eye(4) + T[:3,3] = t + + return T.dot(R.dot(S)) diff --git a/pyrender/pyrender/offscreen.py b/pyrender/pyrender/offscreen.py new file mode 100644 index 0000000000000000000000000000000000000000..340142983006cdc6f51b6d114e9b2b294aa4a919 --- /dev/null +++ b/pyrender/pyrender/offscreen.py @@ -0,0 +1,160 @@ +"""Wrapper for offscreen rendering. + +Author: Matthew Matl +""" +import os + +from .renderer import Renderer +from .constants import RenderFlags + + +class OffscreenRenderer(object): + """A wrapper for offscreen rendering. + + Parameters + ---------- + viewport_width : int + The width of the main viewport, in pixels. + viewport_height : int + The height of the main viewport, in pixels. + point_size : float + The size of screen-space points in pixels. + """ + + def __init__(self, viewport_width, viewport_height, point_size=1.0): + self.viewport_width = viewport_width + self.viewport_height = viewport_height + self.point_size = point_size + + self._platform = None + self._renderer = None + self._create() + + @property + def viewport_width(self): + """int : The width of the main viewport, in pixels. + """ + return self._viewport_width + + @viewport_width.setter + def viewport_width(self, value): + self._viewport_width = int(value) + + @property + def viewport_height(self): + """int : The height of the main viewport, in pixels. + """ + return self._viewport_height + + @viewport_height.setter + def viewport_height(self, value): + self._viewport_height = int(value) + + @property + def point_size(self): + """float : The pixel size of points in point clouds. + """ + return self._point_size + + @point_size.setter + def point_size(self, value): + self._point_size = float(value) + + def render(self, scene, flags=RenderFlags.NONE, seg_node_map=None): + """Render a scene with the given set of flags. + + Parameters + ---------- + scene : :class:`Scene` + A scene to render. + flags : int + A bitwise or of one or more flags from :class:`.RenderFlags`. + seg_node_map : dict + A map from :class:`.Node` objects to (3,) colors for each. + If specified along with flags set to :attr:`.RenderFlags.SEG`, + the color image will be a segmentation image. 
+ + Returns + ------- + color_im : (h, w, 3) uint8 or (h, w, 4) uint8 + The color buffer in RGB format, or in RGBA format if + :attr:`.RenderFlags.RGBA` is set. + Not returned if flags includes :attr:`.RenderFlags.DEPTH_ONLY`. + depth_im : (h, w) float32 + The depth buffer in linear units. + """ + self._platform.make_current() + # If platform does not support dynamically-resizing framebuffers, + # destroy it and restart it + if (self._platform.viewport_height != self.viewport_height or + self._platform.viewport_width != self.viewport_width): + if not self._platform.supports_framebuffers(): + self.delete() + self._create() + + self._platform.make_current() + self._renderer.viewport_width = self.viewport_width + self._renderer.viewport_height = self.viewport_height + self._renderer.point_size = self.point_size + + if self._platform.supports_framebuffers(): + flags |= RenderFlags.OFFSCREEN + retval = self._renderer.render(scene, flags, seg_node_map) + else: + self._renderer.render(scene, flags, seg_node_map) + depth = self._renderer.read_depth_buf() + if flags & RenderFlags.DEPTH_ONLY: + retval = depth + else: + color = self._renderer.read_color_buf() + retval = color, depth + + # Make the platform not current + self._platform.make_uncurrent() + return retval + + def delete(self): + """Free all OpenGL resources. + """ + self._platform.make_current() + self._renderer.delete() + self._platform.delete_context() + del self._renderer + del self._platform + self._renderer = None + self._platform = None + import gc + gc.collect() + + def _create(self): + if 'PYOPENGL_PLATFORM' not in os.environ: + from pyrender.platforms.pyglet_platform import PygletPlatform + self._platform = PygletPlatform(self.viewport_width, + self.viewport_height) + elif os.environ['PYOPENGL_PLATFORM'] == 'egl': + from pyrender.platforms import egl + device_id = int(os.environ.get('EGL_DEVICE_ID', '0')) + egl_device = egl.get_device_by_index(device_id) + self._platform = egl.EGLPlatform(self.viewport_width, + self.viewport_height, + device=egl_device) + elif os.environ['PYOPENGL_PLATFORM'] == 'osmesa': + from pyrender.platforms.osmesa import OSMesaPlatform + self._platform = OSMesaPlatform(self.viewport_width, + self.viewport_height) + else: + raise ValueError('Unsupported PyOpenGL platform: {}'.format( + os.environ['PYOPENGL_PLATFORM'] + )) + self._platform.init_context() + self._platform.make_current() + self._renderer = Renderer(self.viewport_width, self.viewport_height) + + def __del__(self): + try: + self.delete() + except Exception: + pass + + +__all__ = ['OffscreenRenderer'] diff --git a/pyrender/pyrender/platforms/__init__.py b/pyrender/pyrender/platforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7837fd5fdeccab5e48c85e41d20b238ea7396599 --- /dev/null +++ b/pyrender/pyrender/platforms/__init__.py @@ -0,0 +1,6 @@ +"""Platforms for generating offscreen OpenGL contexts for rendering. + +Author: Matthew Matl +""" + +from .base import Platform diff --git a/pyrender/pyrender/platforms/base.py b/pyrender/pyrender/platforms/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ecda906145e239737901809aa59db8d3e231c6 --- /dev/null +++ b/pyrender/pyrender/platforms/base.py @@ -0,0 +1,76 @@ +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class Platform(object): + """Base class for all OpenGL platforms. + + Parameters + ---------- + viewport_width : int + The width of the main viewport, in pixels. 
+ viewport_height : int + The height of the main viewport, in pixels + """ + + def __init__(self, viewport_width, viewport_height): + self.viewport_width = viewport_width + self.viewport_height = viewport_height + + @property + def viewport_width(self): + """int : The width of the main viewport, in pixels. + """ + return self._viewport_width + + @viewport_width.setter + def viewport_width(self, value): + self._viewport_width = value + + @property + def viewport_height(self): + """int : The height of the main viewport, in pixels. + """ + return self._viewport_height + + @viewport_height.setter + def viewport_height(self, value): + self._viewport_height = value + + @abc.abstractmethod + def init_context(self): + """Create an OpenGL context. + """ + pass + + @abc.abstractmethod + def make_current(self): + """Make the OpenGL context current. + """ + pass + + @abc.abstractmethod + def make_uncurrent(self): + """Make the OpenGL context uncurrent. + """ + pass + + @abc.abstractmethod + def delete_context(self): + """Delete the OpenGL context. + """ + pass + + @abc.abstractmethod + def supports_framebuffers(self): + """Returns True if the method supports framebuffer rendering. + """ + pass + + def __del__(self): + try: + self.delete_context() + except Exception: + pass diff --git a/pyrender/pyrender/platforms/egl.py b/pyrender/pyrender/platforms/egl.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2478d29c9a538c53ad83fa31f8e2277cd897c8 --- /dev/null +++ b/pyrender/pyrender/platforms/egl.py @@ -0,0 +1,219 @@ +import ctypes +import os + +import OpenGL.platform + +from .base import Platform + +EGL_PLATFORM_DEVICE_EXT = 0x313F +EGL_DRM_DEVICE_FILE_EXT = 0x3233 + + +def _ensure_egl_loaded(): + plugin = OpenGL.platform.PlatformPlugin.by_name('egl') + if plugin is None: + raise RuntimeError("EGL platform plugin is not available.") + + plugin_class = plugin.load() + plugin.loaded = True + # create instance of this platform implementation + plugin = plugin_class() + + plugin.install(vars(OpenGL.platform)) + + +_ensure_egl_loaded() +from OpenGL import EGL as egl + + +def _get_egl_func(func_name, res_type, *arg_types): + address = egl.eglGetProcAddress(func_name) + if address is None: + return None + + proto = ctypes.CFUNCTYPE(res_type) + proto.argtypes = arg_types + func = proto(address) + return func + + +def _get_egl_struct(struct_name): + from OpenGL._opaque import opaque_pointer_cls + return opaque_pointer_cls(struct_name) + + +# These are not defined in PyOpenGL by default. +_EGLDeviceEXT = _get_egl_struct('EGLDeviceEXT') +_eglGetPlatformDisplayEXT = _get_egl_func('eglGetPlatformDisplayEXT', egl.EGLDisplay) +_eglQueryDevicesEXT = _get_egl_func('eglQueryDevicesEXT', egl.EGLBoolean) +_eglQueryDeviceStringEXT = _get_egl_func('eglQueryDeviceStringEXT', ctypes.c_char_p) + + +def query_devices(): + if _eglQueryDevicesEXT is None: + raise RuntimeError("EGL query extension is not loaded or is not supported.") + + num_devices = egl.EGLint() + success = _eglQueryDevicesEXT(0, None, ctypes.pointer(num_devices)) + if not success or num_devices.value < 1: + return [] + + devices = (_EGLDeviceEXT * num_devices.value)() # array of size num_devices + success = _eglQueryDevicesEXT(num_devices.value, devices, ctypes.pointer(num_devices)) + if not success or num_devices.value < 1: + return [] + + return [EGLDevice(devices[i]) for i in range(num_devices.value)] + + +def get_default_device(): + # Fall back to not using query extension. 
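+    # If the EGL device-query extension is unavailable, EGLDevice(None) is
+    # returned; its get_display() then uses eglGetDisplay(EGL_DEFAULT_DISPLAY)
+    # rather than targeting a specific GPU. On multi-GPU hosts a specific
+    # device can be requested instead, e.g. (illustrative values):
+    #
+    #     os.environ['PYOPENGL_PLATFORM'] = 'egl'
+    #     os.environ['EGL_DEVICE_ID'] = '1'  # read by OffscreenRenderer._create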
+    if _eglQueryDevicesEXT is None:
+        return EGLDevice(None)
+
+    return query_devices()[0]
+
+
+def get_device_by_index(device_id):
+    if _eglQueryDevicesEXT is None and device_id == 0:
+        return get_default_device()
+
+    devices = query_devices()
+    if device_id >= len(devices):
+        raise ValueError('Invalid device ID ({} >= {})'.format(
+            device_id, len(devices)))
+    return devices[device_id]
+
+
+class EGLDevice:
+
+    def __init__(self, display=None):
+        self._display = display
+
+    def get_display(self):
+        if self._display is None:
+            return egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
+
+        return _eglGetPlatformDisplayEXT(EGL_PLATFORM_DEVICE_EXT, self._display, None)
+
+    @property
+    def name(self):
+        if self._display is None:
+            return 'default'
+
+        name = _eglQueryDeviceStringEXT(self._display, EGL_DRM_DEVICE_FILE_EXT)
+        if name is None:
+            return None
+
+        return name.decode('ascii')
+
+    def __repr__(self):
+        return "<EGLDevice(name={})>".format(self.name)
+
+
+class EGLPlatform(Platform):
+    """Renders using EGL.
+    """
+
+    def __init__(self, viewport_width, viewport_height, device: EGLDevice = None):
+        super(EGLPlatform, self).__init__(viewport_width, viewport_height)
+        if device is None:
+            device = get_default_device()
+
+        self._egl_device = device
+        self._egl_display = None
+        self._egl_context = None
+
+    def init_context(self):
+        _ensure_egl_loaded()
+
+        from OpenGL.EGL import (
+            EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_BLUE_SIZE,
+            EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_DEPTH_SIZE,
+            EGL_COLOR_BUFFER_TYPE, EGL_RGB_BUFFER,
+            EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, EGL_CONFORMANT,
+            EGL_NONE, EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT,
+            EGL_OPENGL_API, EGL_CONTEXT_MAJOR_VERSION,
+            EGL_CONTEXT_MINOR_VERSION,
+            EGL_CONTEXT_OPENGL_PROFILE_MASK,
+            EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
+            eglGetDisplay, eglInitialize, eglChooseConfig,
+            eglBindAPI, eglCreateContext, EGLConfig
+        )
+        from OpenGL import arrays
+
+        config_attributes = arrays.GLintArray.asArray([
+            EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
+            EGL_BLUE_SIZE, 8,
+            EGL_RED_SIZE, 8,
+            EGL_GREEN_SIZE, 8,
+            EGL_DEPTH_SIZE, 24,
+            EGL_COLOR_BUFFER_TYPE, EGL_RGB_BUFFER,
+            EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
+            EGL_CONFORMANT, EGL_OPENGL_BIT,
+            EGL_NONE
+        ])
+        context_attributes = arrays.GLintArray.asArray([
+            EGL_CONTEXT_MAJOR_VERSION, 4,
+            EGL_CONTEXT_MINOR_VERSION, 1,
+            EGL_CONTEXT_OPENGL_PROFILE_MASK,
+            EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
+            EGL_NONE
+        ])
+        major, minor = ctypes.c_long(), ctypes.c_long()
+        num_configs = ctypes.c_long()
+        configs = (EGLConfig * 1)()
+
+        # Cache DISPLAY if necessary and get an off-screen EGL display
+        orig_dpy = None
+        if 'DISPLAY' in os.environ:
+            orig_dpy = os.environ['DISPLAY']
+            del os.environ['DISPLAY']
+
+        self._egl_display = self._egl_device.get_display()
+        if orig_dpy is not None:
+            os.environ['DISPLAY'] = orig_dpy
+
+        # Initialize EGL
+        assert eglInitialize(self._egl_display, major, minor)
+        assert eglChooseConfig(
+            self._egl_display, config_attributes, configs, 1, num_configs
+        )
+
+        # Bind EGL to the OpenGL API
+        assert eglBindAPI(EGL_OPENGL_API)
+
+        # Create an EGL context
+        self._egl_context = eglCreateContext(
+            self._egl_display, configs[0],
+            EGL_NO_CONTEXT, context_attributes
+        )
+
+        # Make it current
+        self.make_current()
+
+    def make_current(self):
+        from OpenGL.EGL import eglMakeCurrent, EGL_NO_SURFACE
+        assert eglMakeCurrent(
+            self._egl_display, EGL_NO_SURFACE, EGL_NO_SURFACE,
+            self._egl_context
+        )
+
+    def make_uncurrent(self):
+        """Make the OpenGL context uncurrent.
+ """ + pass + + def delete_context(self): + from OpenGL.EGL import eglDestroyContext, eglTerminate + if self._egl_display is not None: + if self._egl_context is not None: + eglDestroyContext(self._egl_display, self._egl_context) + self._egl_context = None + eglTerminate(self._egl_display) + self._egl_display = None + + def supports_framebuffers(self): + return True + + +__all__ = ['EGLPlatform'] diff --git a/pyrender/pyrender/platforms/osmesa.py b/pyrender/pyrender/platforms/osmesa.py new file mode 100644 index 0000000000000000000000000000000000000000..deaa5ff44031a107883913ae9a18fc425d650f3d --- /dev/null +++ b/pyrender/pyrender/platforms/osmesa.py @@ -0,0 +1,59 @@ +from .base import Platform + + +__all__ = ['OSMesaPlatform'] + + +class OSMesaPlatform(Platform): + """Renders into a software buffer using OSMesa. Requires special versions + of OSMesa to be installed, plus PyOpenGL upgrade. + """ + + def __init__(self, viewport_width, viewport_height): + super(OSMesaPlatform, self).__init__(viewport_width, viewport_height) + self._context = None + self._buffer = None + + def init_context(self): + from OpenGL import arrays + from OpenGL.osmesa import ( + OSMesaCreateContextAttribs, OSMESA_FORMAT, + OSMESA_RGBA, OSMESA_PROFILE, OSMESA_CORE_PROFILE, + OSMESA_CONTEXT_MAJOR_VERSION, OSMESA_CONTEXT_MINOR_VERSION, + OSMESA_DEPTH_BITS + ) + + attrs = arrays.GLintArray.asArray([ + OSMESA_FORMAT, OSMESA_RGBA, + OSMESA_DEPTH_BITS, 24, + OSMESA_PROFILE, OSMESA_CORE_PROFILE, + OSMESA_CONTEXT_MAJOR_VERSION, 3, + OSMESA_CONTEXT_MINOR_VERSION, 3, + 0 + ]) + self._context = OSMesaCreateContextAttribs(attrs, None) + self._buffer = arrays.GLubyteArray.zeros( + (self.viewport_height, self.viewport_width, 4) + ) + + def make_current(self): + from OpenGL import GL as gl + from OpenGL.osmesa import OSMesaMakeCurrent + assert(OSMesaMakeCurrent( + self._context, self._buffer, gl.GL_UNSIGNED_BYTE, + self.viewport_width, self.viewport_height + )) + + def make_uncurrent(self): + """Make the OpenGL context uncurrent. + """ + pass + + def delete_context(self): + from OpenGL.osmesa import OSMesaDestroyContext + OSMesaDestroyContext(self._context) + self._context = None + self._buffer = None + + def supports_framebuffers(self): + return False diff --git a/pyrender/pyrender/platforms/pyglet_platform.py b/pyrender/pyrender/platforms/pyglet_platform.py new file mode 100644 index 0000000000000000000000000000000000000000..a70cf7b659bc85a92f6c9c8ebcc360662a068507 --- /dev/null +++ b/pyrender/pyrender/platforms/pyglet_platform.py @@ -0,0 +1,90 @@ +from pyrender.constants import (TARGET_OPEN_GL_MAJOR, TARGET_OPEN_GL_MINOR, + MIN_OPEN_GL_MAJOR, MIN_OPEN_GL_MINOR) +from .base import Platform + +import OpenGL + + +__all__ = ['PygletPlatform'] + + +class PygletPlatform(Platform): + """Renders on-screen using a 1x1 hidden Pyglet window for getting + an OpenGL context. 
+ """ + + def __init__(self, viewport_width, viewport_height): + super(PygletPlatform, self).__init__(viewport_width, viewport_height) + self._window = None + + def init_context(self): + import pyglet + pyglet.options['shadow_window'] = False + + try: + pyglet.lib.x11.xlib.XInitThreads() + except Exception: + pass + + self._window = None + confs = [pyglet.gl.Config(sample_buffers=1, samples=4, + depth_size=24, + double_buffer=True, + major_version=TARGET_OPEN_GL_MAJOR, + minor_version=TARGET_OPEN_GL_MINOR), + pyglet.gl.Config(depth_size=24, + double_buffer=True, + major_version=TARGET_OPEN_GL_MAJOR, + minor_version=TARGET_OPEN_GL_MINOR), + pyglet.gl.Config(sample_buffers=1, samples=4, + depth_size=24, + double_buffer=True, + major_version=MIN_OPEN_GL_MAJOR, + minor_version=MIN_OPEN_GL_MINOR), + pyglet.gl.Config(depth_size=24, + double_buffer=True, + major_version=MIN_OPEN_GL_MAJOR, + minor_version=MIN_OPEN_GL_MINOR)] + for conf in confs: + try: + self._window = pyglet.window.Window(config=conf, visible=False, + resizable=False, + width=1, height=1) + break + except pyglet.window.NoSuchConfigException as e: + pass + + if not self._window: + raise ValueError( + 'Failed to initialize Pyglet window with an OpenGL >= 3+ ' + 'context. If you\'re logged in via SSH, ensure that you\'re ' + 'running your script with vglrun (i.e. VirtualGL). The ' + 'internal error message was "{}"'.format(e) + ) + + def make_current(self): + if self._window: + self._window.switch_to() + + def make_uncurrent(self): + try: + import pyglet + pyglet.gl.xlib.glx.glXMakeContextCurrent(self._window.context.x_display, 0, 0, None) + except Exception: + pass + + def delete_context(self): + if self._window is not None: + self.make_current() + cid = OpenGL.contextdata.getContext() + try: + self._window.context.destroy() + self._window.close() + except Exception: + pass + self._window = None + OpenGL.contextdata.cleanupContext(cid) + del cid + + def supports_framebuffers(self): + return True diff --git a/pyrender/pyrender/primitive.py b/pyrender/pyrender/primitive.py new file mode 100644 index 0000000000000000000000000000000000000000..7f83f46f532b126a4573e715dd03d079fef755ca --- /dev/null +++ b/pyrender/pyrender/primitive.py @@ -0,0 +1,489 @@ +"""Primitives, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-primitive + +Author: Matthew Matl +""" +import numpy as np + +from OpenGL.GL import * + +from .material import Material, MetallicRoughnessMaterial +from .constants import FLOAT_SZ, UINT_SZ, BufFlags, GLTF +from .utils import format_color_array + + +class Primitive(object): + """A primitive object which can be rendered. + + Parameters + ---------- + positions : (n, 3) float + XYZ vertex positions. + normals : (n, 3) float + Normalized XYZ vertex normals. + tangents : (n, 4) float + XYZW vertex tangents where the w component is a sign value + (either +1 or -1) indicating the handedness of the tangent basis. + texcoord_0 : (n, 2) float + The first set of UV texture coordinates. + texcoord_1 : (n, 2) float + The second set of UV texture coordinates. + color_0 : (n, 4) float + RGBA vertex colors. + joints_0 : (n, 4) float + Joint information. + weights_0 : (n, 4) float + Weight information for morphing. + indices : (m, 3) int + Face indices for triangle meshes or fans. + material : :class:`Material` + The material to apply to this primitive when rendering. 
+ mode : int + The type of primitives to render, one of the following: + + - ``0``: POINTS + - ``1``: LINES + - ``2``: LINE_LOOP + - ``3``: LINE_STRIP + - ``4``: TRIANGLES + - ``5``: TRIANGLES_STRIP + - ``6``: TRIANGLES_FAN + targets : (k,) int + Morph target indices. + poses : (x,4,4), float + Array of 4x4 transformation matrices for instancing this object. + """ + + def __init__(self, + positions, + normals=None, + tangents=None, + texcoord_0=None, + texcoord_1=None, + color_0=None, + joints_0=None, + weights_0=None, + indices=None, + material=None, + mode=None, + targets=None, + poses=None): + + if mode is None: + mode = GLTF.TRIANGLES + + self.positions = positions + self.normals = normals + self.tangents = tangents + self.texcoord_0 = texcoord_0 + self.texcoord_1 = texcoord_1 + self.color_0 = color_0 + self.joints_0 = joints_0 + self.weights_0 = weights_0 + self.indices = indices + self.material = material + self.mode = mode + self.targets = targets + self.poses = poses + + self._bounds = None + self._vaid = None + self._buffers = [] + self._is_transparent = None + self._buf_flags = None + + @property + def positions(self): + """(n,3) float : XYZ vertex positions. + """ + return self._positions + + @positions.setter + def positions(self, value): + value = np.asanyarray(value, dtype=np.float32) + self._positions = np.ascontiguousarray(value) + self._bounds = None + + @property + def normals(self): + """(n,3) float : Normalized XYZ vertex normals. + """ + return self._normals + + @normals.setter + def normals(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if value.shape != self.positions.shape: + raise ValueError('Incorrect normals shape') + self._normals = value + + @property + def tangents(self): + """(n,4) float : XYZW vertex tangents. + """ + return self._tangents + + @tangents.setter + def tangents(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if value.shape != (self.positions.shape[0], 4): + raise ValueError('Incorrect tangent shape') + self._tangents = value + + @property + def texcoord_0(self): + """(n,2) float : The first set of UV texture coordinates. + """ + return self._texcoord_0 + + @texcoord_0.setter + def texcoord_0(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if (value.ndim != 2 or value.shape[0] != self.positions.shape[0] or + value.shape[1] < 2): + raise ValueError('Incorrect texture coordinate shape') + if value.shape[1] > 2: + value = value[:,:2] + self._texcoord_0 = value + + @property + def texcoord_1(self): + """(n,2) float : The second set of UV texture coordinates. + """ + return self._texcoord_1 + + @texcoord_1.setter + def texcoord_1(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if (value.ndim != 2 or value.shape[0] != self.positions.shape[0] or + value.shape[1] != 2): + raise ValueError('Incorrect texture coordinate shape') + self._texcoord_1 = value + + @property + def color_0(self): + """(n,4) float : RGBA vertex colors. 
+ """ + return self._color_0 + + @color_0.setter + def color_0(self, value): + if value is not None: + value = np.ascontiguousarray( + format_color_array(value, shape=(len(self.positions), 4)) + ) + self._is_transparent = None + self._color_0 = value + + @property + def joints_0(self): + """(n,4) float : Joint information. + """ + return self._joints_0 + + @joints_0.setter + def joints_0(self, value): + self._joints_0 = value + + @property + def weights_0(self): + """(n,4) float : Weight information for morphing. + """ + return self._weights_0 + + @weights_0.setter + def weights_0(self, value): + self._weights_0 = value + + @property + def indices(self): + """(m,3) int : Face indices for triangle meshes or fans. + """ + return self._indices + + @indices.setter + def indices(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + self._indices = value + + @property + def material(self): + """:class:`Material` : The material for this primitive. + """ + return self._material + + @material.setter + def material(self, value): + # Create default material + if value is None: + value = MetallicRoughnessMaterial() + else: + if not isinstance(value, Material): + raise TypeError('Object material must be of type Material') + self._material = value + + @property + def mode(self): + """int : The type of primitive to render. + """ + return self._mode + + @mode.setter + def mode(self, value): + value = int(value) + if value < GLTF.POINTS or value > GLTF.TRIANGLE_FAN: + raise ValueError('Invalid mode') + self._mode = value + + @property + def targets(self): + """(k,) int : Morph target indices. + """ + return self._targets + + @targets.setter + def targets(self, value): + self._targets = value + + @property + def poses(self): + """(x,4,4) float : Homogenous transforms for instancing this primitive. + """ + return self._poses + + @poses.setter + def poses(self, value): + if value is not None: + value = np.asanyarray(value, dtype=np.float32) + value = np.ascontiguousarray(value) + if value.ndim == 2: + value = value[np.newaxis,:,:] + if value.shape[1] != 4 or value.shape[2] != 4: + raise ValueError('Pose matrices must be of shape (n,4,4), ' + 'got {}'.format(value.shape)) + self._poses = value + self._bounds = None + + @property + def bounds(self): + if self._bounds is None: + self._bounds = self._compute_bounds() + return self._bounds + + @property + def centroid(self): + """(3,) float : The centroid of the primitive's AABB. + """ + return np.mean(self.bounds, axis=0) + + @property + def extents(self): + """(3,) float : The lengths of the axes of the primitive's AABB. + """ + return np.diff(self.bounds, axis=0).reshape(-1) + + @property + def scale(self): + """(3,) float : The length of the diagonal of the primitive's AABB. + """ + return np.linalg.norm(self.extents) + + @property + def buf_flags(self): + """int : The flags for the render buffer. + """ + if self._buf_flags is None: + self._buf_flags = self._compute_buf_flags() + return self._buf_flags + + def delete(self): + self._unbind() + self._remove_from_context() + + @property + def is_transparent(self): + """bool : If True, the mesh is partially-transparent. 
+ """ + return self._compute_transparency() + + def _add_to_context(self): + if self._vaid is not None: + raise ValueError('Mesh is already bound to a context') + + # Generate and bind VAO + self._vaid = glGenVertexArrays(1) + glBindVertexArray(self._vaid) + + ####################################################################### + # Fill vertex buffer + ####################################################################### + + # Generate and bind vertex buffer + vertexbuffer = glGenBuffers(1) + self._buffers.append(vertexbuffer) + glBindBuffer(GL_ARRAY_BUFFER, vertexbuffer) + + # positions + vertex_data = self.positions + attr_sizes = [3] + + # Normals + if self.normals is not None: + vertex_data = np.hstack((vertex_data, self.normals)) + attr_sizes.append(3) + + # Tangents + if self.tangents is not None: + vertex_data = np.hstack((vertex_data, self.tangents)) + attr_sizes.append(4) + + # Texture Coordinates + if self.texcoord_0 is not None: + vertex_data = np.hstack((vertex_data, self.texcoord_0)) + attr_sizes.append(2) + if self.texcoord_1 is not None: + vertex_data = np.hstack((vertex_data, self.texcoord_1)) + attr_sizes.append(2) + + # Color + if self.color_0 is not None: + vertex_data = np.hstack((vertex_data, self.color_0)) + attr_sizes.append(4) + + # TODO JOINTS AND WEIGHTS + # PASS + + # Copy data to buffer + vertex_data = np.ascontiguousarray( + vertex_data.flatten().astype(np.float32) + ) + glBufferData( + GL_ARRAY_BUFFER, FLOAT_SZ * len(vertex_data), + vertex_data, GL_STATIC_DRAW + ) + total_sz = sum(attr_sizes) + offset = 0 + for i, sz in enumerate(attr_sizes): + glVertexAttribPointer( + i, sz, GL_FLOAT, GL_FALSE, FLOAT_SZ * total_sz, + ctypes.c_void_p(FLOAT_SZ * offset) + ) + glEnableVertexAttribArray(i) + offset += sz + + ####################################################################### + # Fill model matrix buffer + ####################################################################### + + if self.poses is not None: + pose_data = np.ascontiguousarray( + np.transpose(self.poses, [0,2,1]).flatten().astype(np.float32) + ) + else: + pose_data = np.ascontiguousarray( + np.eye(4).flatten().astype(np.float32) + ) + + modelbuffer = glGenBuffers(1) + self._buffers.append(modelbuffer) + glBindBuffer(GL_ARRAY_BUFFER, modelbuffer) + glBufferData( + GL_ARRAY_BUFFER, FLOAT_SZ * len(pose_data), + pose_data, GL_STATIC_DRAW + ) + + for i in range(0, 4): + idx = i + len(attr_sizes) + glEnableVertexAttribArray(idx) + glVertexAttribPointer( + idx, 4, GL_FLOAT, GL_FALSE, FLOAT_SZ * 4 * 4, + ctypes.c_void_p(4 * FLOAT_SZ * i) + ) + glVertexAttribDivisor(idx, 1) + + ####################################################################### + # Fill element buffer + ####################################################################### + if self.indices is not None: + elementbuffer = glGenBuffers(1) + self._buffers.append(elementbuffer) + glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, elementbuffer) + glBufferData(GL_ELEMENT_ARRAY_BUFFER, UINT_SZ * self.indices.size, + self.indices.flatten().astype(np.uint32), + GL_STATIC_DRAW) + + glBindVertexArray(0) + + def _remove_from_context(self): + if self._vaid is not None: + glDeleteVertexArrays(1, [self._vaid]) + glDeleteBuffers(len(self._buffers), self._buffers) + self._vaid = None + self._buffers = [] + + def _in_context(self): + return self._vaid is not None + + def _bind(self): + if self._vaid is None: + raise ValueError('Cannot bind a Mesh that has not been added ' + 'to a context') + glBindVertexArray(self._vaid) + + def _unbind(self): + 
glBindVertexArray(0)
+
+    def _compute_bounds(self):
+        """Compute the bounds of this object.
+        """
+        # Compute bounds of this object
+        bounds = np.array([np.min(self.positions, axis=0),
+                           np.max(self.positions, axis=0)])
+
+        # If instanced, compute translations for approximate bounds
+        if self.poses is not None:
+            bounds += np.array([np.min(self.poses[:,:3,3], axis=0),
+                                np.max(self.poses[:,:3,3], axis=0)])
+        return bounds
+
+    def _compute_transparency(self):
+        """Compute whether or not this object is transparent.
+        """
+        if self.material.is_transparent:
+            return True
+        if self._is_transparent is None:
+            self._is_transparent = False
+            if self.color_0 is not None:
+                if np.any(self._color_0[:,3] != 1.0):
+                    self._is_transparent = True
+        return self._is_transparent
+
+    def _compute_buf_flags(self):
+        buf_flags = BufFlags.POSITION
+
+        if self.normals is not None:
+            buf_flags |= BufFlags.NORMAL
+        if self.tangents is not None:
+            buf_flags |= BufFlags.TANGENT
+        if self.texcoord_0 is not None:
+            buf_flags |= BufFlags.TEXCOORD_0
+        if self.texcoord_1 is not None:
+            buf_flags |= BufFlags.TEXCOORD_1
+        if self.color_0 is not None:
+            buf_flags |= BufFlags.COLOR_0
+        if self.joints_0 is not None:
+            buf_flags |= BufFlags.JOINTS_0
+        if self.weights_0 is not None:
+            buf_flags |= BufFlags.WEIGHTS_0
+
+        return buf_flags
diff --git a/pyrender/pyrender/renderer.py b/pyrender/pyrender/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ae14c5cdb1785226a52ae6b71b08f01de069962
--- /dev/null
+++ b/pyrender/pyrender/renderer.py
@@ -0,0 +1,1339 @@
+"""PBR renderer for Python.
+
+Author: Matthew Matl
+"""
+import sys
+
+import numpy as np
+import PIL
+
+from .constants import (RenderFlags, TextAlign, GLTF, BufFlags, TexFlags,
+                        ProgramFlags, DEFAULT_Z_FAR, DEFAULT_Z_NEAR,
+                        SHADOW_TEX_SZ, MAX_N_LIGHTS)
+from .shader_program import ShaderProgramCache
+from .material import MetallicRoughnessMaterial, SpecularGlossinessMaterial
+from .light import PointLight, SpotLight, DirectionalLight
+from .font import FontCache
+from .utils import format_color_vector
+
+from OpenGL.GL import *
+
+
+class Renderer(object):
+    """Class for handling all rendering operations on a scene.
+
+    Note
+    ----
+    This renderer relies on the existence of an OpenGL context and
+    does not create one on its own.
+
+    Parameters
+    ----------
+    viewport_width : int
+        Width of the viewport in pixels.
+    viewport_height : int
+        Height of the viewport in pixels.
+    point_size : float, optional
+        Size of points in pixels. Defaults to 1.0.
+    """
+
+    def __init__(self, viewport_width, viewport_height, point_size=1.0):
+        self.dpscale = 1
+        # Scaling needed on retina displays
+        if sys.platform == 'darwin':
+            self.dpscale = 2
+
+        self.viewport_width = viewport_width
+        self.viewport_height = viewport_height
+        self.point_size = point_size
+
+        # Optional framebuffer for offscreen renders
+        self._main_fb = None
+        self._main_cb = None
+        self._main_db = None
+        self._main_fb_ms = None
+        self._main_cb_ms = None
+        self._main_db_ms = None
+        self._main_fb_dims = (None, None)
+        self._shadow_fb = None
+        self._latest_znear = DEFAULT_Z_NEAR
+        self._latest_zfar = DEFAULT_Z_FAR
+
+        # Shader Program Cache
+        self._program_cache = ShaderProgramCache()
+        self._font_cache = FontCache()
+        self._meshes = set()
+        self._mesh_textures = set()
+        self._shadow_textures = set()
+        self._texture_alloc_idx = 0
+
+    @property
+    def viewport_width(self):
+        """int : The width of the main viewport, in pixels.
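+
+        Note that on retina displays (``dpscale == 2``) this is twice the
+        width passed to the constructor.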
+ """ + return self._viewport_width + + @viewport_width.setter + def viewport_width(self, value): + self._viewport_width = self.dpscale * value + + @property + def viewport_height(self): + """int : The height of the main viewport, in pixels. + """ + return self._viewport_height + + @viewport_height.setter + def viewport_height(self, value): + self._viewport_height = self.dpscale * value + + @property + def point_size(self): + """float : The size of screen-space points, in pixels. + """ + return self._point_size + + @point_size.setter + def point_size(self, value): + self._point_size = float(value) + + def render(self, scene, flags, seg_node_map=None): + """Render a scene with the given set of flags. + + Parameters + ---------- + scene : :class:`Scene` + A scene to render. + flags : int + A specification from :class:`.RenderFlags`. + seg_node_map : dict + A map from :class:`.Node` objects to (3,) colors for each. + If specified along with flags set to :attr:`.RenderFlags.SEG`, + the color image will be a segmentation image. + + Returns + ------- + color_im : (h, w, 3) uint8 or (h, w, 4) uint8 + If :attr:`RenderFlags.OFFSCREEN` is set, the color buffer. This is + normally an RGB buffer, but if :attr:`.RenderFlags.RGBA` is set, + the buffer will be a full RGBA buffer. + depth_im : (h, w) float32 + If :attr:`RenderFlags.OFFSCREEN` is set, the depth buffer + in linear units. + """ + # Update context with meshes and textures + self._update_context(scene, flags) + + # Render necessary shadow maps + if not bool(flags & RenderFlags.DEPTH_ONLY or flags & RenderFlags.SEG): + for ln in scene.light_nodes: + take_pass = False + if (isinstance(ln.light, DirectionalLight) and + bool(flags & RenderFlags.SHADOWS_DIRECTIONAL)): + take_pass = True + elif (isinstance(ln.light, SpotLight) and + bool(flags & RenderFlags.SHADOWS_SPOT)): + take_pass = True + elif (isinstance(ln.light, PointLight) and + bool(flags & RenderFlags.SHADOWS_POINT)): + take_pass = True + if take_pass: + self._shadow_mapping_pass(scene, ln, flags) + + # Make forward pass + retval = self._forward_pass(scene, flags, seg_node_map=seg_node_map) + + # If necessary, make normals pass + if flags & (RenderFlags.VERTEX_NORMALS | RenderFlags.FACE_NORMALS): + self._normals_pass(scene, flags) + + # Update camera settings for retrieving depth buffers + self._latest_znear = scene.main_camera_node.camera.znear + self._latest_zfar = scene.main_camera_node.camera.zfar + + return retval + + def render_text(self, text, x, y, font_name='OpenSans-Regular', + font_pt=40, color=None, scale=1.0, + align=TextAlign.BOTTOM_LEFT): + """Render text into the current viewport. + + Note + ---- + This cannot be done into an offscreen buffer. + + Parameters + ---------- + text : str + The text to render. + x : int + Horizontal pixel location of text. + y : int + Vertical pixel location of text. + font_name : str + Name of font, from the ``pyrender/fonts`` folder, or + a path to a ``.ttf`` file. + font_pt : int + Height of the text, in font points. + color : (4,) float + The color of the text. Default is black. + scale : int + Scaling factor for text. + align : int + One of the :class:`TextAlign` options which specifies where the + ``x`` and ``y`` parameters lie on the text. For example, + :attr:`TextAlign.BOTTOM_LEFT` means that ``x`` and ``y`` indicate + the position of the bottom-left corner of the textbox. 
+ """ + x *= self.dpscale + y *= self.dpscale + font_pt *= self.dpscale + + if color is None: + color = np.array([0.0, 0.0, 0.0, 1.0]) + else: + color = format_color_vector(color, 4) + + # Set up viewport for render + self._configure_forward_pass_viewport(0) + + # Load font + font = self._font_cache.get_font(font_name, font_pt) + if not font._in_context(): + font._add_to_context() + + # Load program + program = self._get_text_program() + program._bind() + + # Set uniforms + p = np.eye(4) + p[0,0] = 2.0 / self.viewport_width + p[0,3] = -1.0 + p[1,1] = 2.0 / self.viewport_height + p[1,3] = -1.0 + program.set_uniform('projection', p) + program.set_uniform('text_color', color) + + # Draw text + font.render_string(text, x, y, scale, align) + + def read_color_buf(self): + """Read and return the current viewport's color buffer. + + Alpha cannot be computed for an on-screen buffer. + + Returns + ------- + color_im : (h, w, 3) uint8 + The color buffer in RGB byte format. + """ + # Extract color image from frame buffer + width, height = self.viewport_width, self.viewport_height + glBindFramebuffer(GL_READ_FRAMEBUFFER, 0) + glReadBuffer(GL_FRONT) + color_buf = glReadPixels(0, 0, width, height, GL_RGB, GL_UNSIGNED_BYTE) + + # Re-format them into numpy arrays + color_im = np.frombuffer(color_buf, dtype=np.uint8) + color_im = color_im.reshape((height, width, 3)) + color_im = np.flip(color_im, axis=0) + + # Resize for macos if needed + if sys.platform == 'darwin': + color_im = self._resize_image(color_im, True) + + return color_im + + def read_depth_buf(self): + """Read and return the current viewport's color buffer. + + Returns + ------- + depth_im : (h, w) float32 + The depth buffer in linear units. + """ + width, height = self.viewport_width, self.viewport_height + glBindFramebuffer(GL_READ_FRAMEBUFFER, 0) + glReadBuffer(GL_FRONT) + depth_buf = glReadPixels( + 0, 0, width, height, GL_DEPTH_COMPONENT, GL_FLOAT + ) + + depth_im = np.frombuffer(depth_buf, dtype=np.float32) + depth_im = depth_im.reshape((height, width)) + depth_im = np.flip(depth_im, axis=0) + + inf_inds = (depth_im == 1.0) + depth_im = 2.0 * depth_im - 1.0 + z_near, z_far = self._latest_znear, self._latest_zfar + noninf = np.logical_not(inf_inds) + if z_far is None: + depth_im[noninf] = 2 * z_near / (1.0 - depth_im[noninf]) + else: + depth_im[noninf] = ((2.0 * z_near * z_far) / + (z_far + z_near - depth_im[noninf] * + (z_far - z_near))) + depth_im[inf_inds] = 0.0 + + # Resize for macos if needed + if sys.platform == 'darwin': + depth_im = self._resize_image(depth_im) + + return depth_im + + def delete(self): + """Free all allocated OpenGL resources. 
+ """ + # Free shaders + self._program_cache.clear() + + # Free fonts + self._font_cache.clear() + + # Free meshes + for mesh in self._meshes: + for p in mesh.primitives: + p.delete() + + # Free textures + for mesh_texture in self._mesh_textures: + mesh_texture.delete() + + for shadow_texture in self._shadow_textures: + shadow_texture.delete() + + self._meshes = set() + self._mesh_textures = set() + self._shadow_textures = set() + self._texture_alloc_idx = 0 + + self._delete_main_framebuffer() + self._delete_shadow_framebuffer() + + def __del__(self): + try: + self.delete() + except Exception: + pass + + ########################################################################### + # Rendering passes + ########################################################################### + + def _forward_pass(self, scene, flags, seg_node_map=None): + # Set up viewport for render + self._configure_forward_pass_viewport(flags) + + # Clear it + if bool(flags & RenderFlags.SEG): + glClearColor(0.0, 0.0, 0.0, 1.0) + if seg_node_map is None: + seg_node_map = {} + else: + glClearColor(*scene.bg_color) + + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + + if not bool(flags & RenderFlags.SEG): + glEnable(GL_MULTISAMPLE) + else: + glDisable(GL_MULTISAMPLE) + + # Set up camera matrices + V, P = self._get_camera_matrices(scene) + + program = None + # Now, render each object in sorted order + for node in self._sorted_mesh_nodes(scene): + mesh = node.mesh + + # Skip the mesh if it's not visible + if not mesh.is_visible: + continue + + # If SEG, set color + if bool(flags & RenderFlags.SEG): + if node not in seg_node_map: + continue + color = seg_node_map[node] + if not isinstance(color, (list, tuple, np.ndarray)): + color = np.repeat(color, 3) + else: + color = np.asanyarray(color) + color = color / 255.0 + + for primitive in mesh.primitives: + + # First, get and bind the appropriate program + program = self._get_primitive_program( + primitive, flags, ProgramFlags.USE_MATERIAL + ) + program._bind() + + # Set the camera uniforms + program.set_uniform('V', V) + program.set_uniform('P', P) + program.set_uniform( + 'cam_pos', scene.get_pose(scene.main_camera_node)[:3,3] + ) + if bool(flags & RenderFlags.SEG): + program.set_uniform('color', color) + + # Next, bind the lighting + if not (flags & RenderFlags.DEPTH_ONLY or flags & RenderFlags.FLAT or + flags & RenderFlags.SEG): + self._bind_lighting(scene, program, node, flags) + + # Finally, bind and draw the primitive + self._bind_and_draw_primitive( + primitive=primitive, + pose=scene.get_pose(node), + program=program, + flags=flags + ) + self._reset_active_textures() + + # Unbind the shader and flush the output + if program is not None: + program._unbind() + glFlush() + + # If doing offscreen render, copy result from framebuffer and return + if flags & RenderFlags.OFFSCREEN: + return self._read_main_framebuffer(scene, flags) + else: + return + + def _shadow_mapping_pass(self, scene, light_node, flags): + light = light_node.light + + # Set up viewport for render + self._configure_shadow_mapping_viewport(light, flags) + + # Set up camera matrices + V, P = self._get_light_cam_matrices(scene, light_node, flags) + + # Now, render each object in sorted order + for node in self._sorted_mesh_nodes(scene): + mesh = node.mesh + + # Skip the mesh if it's not visible + if not mesh.is_visible: + continue + + for primitive in mesh.primitives: + + # First, get and bind the appropriate program + program = self._get_primitive_program( + primitive, flags, ProgramFlags.NONE + ) + 
program._bind() + + # Set the camera uniforms + program.set_uniform('V', V) + program.set_uniform('P', P) + program.set_uniform( + 'cam_pos', scene.get_pose(scene.main_camera_node)[:3,3] + ) + + # Finally, bind and draw the primitive + self._bind_and_draw_primitive( + primitive=primitive, + pose=scene.get_pose(node), + program=program, + flags=RenderFlags.DEPTH_ONLY + ) + self._reset_active_textures() + + # Unbind the shader and flush the output + if program is not None: + program._unbind() + glFlush() + + def _normals_pass(self, scene, flags): + # Set up viewport for render + self._configure_forward_pass_viewport(flags) + program = None + + # Set up camera matrices + V, P = self._get_camera_matrices(scene) + + # Now, render each object in sorted order + for node in self._sorted_mesh_nodes(scene): + mesh = node.mesh + + # Skip the mesh if it's not visible + if not mesh.is_visible: + continue + + for primitive in mesh.primitives: + + # Skip objects that don't have normals + if not primitive.buf_flags & BufFlags.NORMAL: + continue + + # First, get and bind the appropriate program + pf = ProgramFlags.NONE + if flags & RenderFlags.VERTEX_NORMALS: + pf = pf | ProgramFlags.VERTEX_NORMALS + if flags & RenderFlags.FACE_NORMALS: + pf = pf | ProgramFlags.FACE_NORMALS + program = self._get_primitive_program(primitive, flags, pf) + program._bind() + + # Set the camera uniforms + program.set_uniform('V', V) + program.set_uniform('P', P) + program.set_uniform('normal_magnitude', 0.05 * primitive.scale) + program.set_uniform( + 'normal_color', np.array([0.1, 0.1, 1.0, 1.0]) + ) + + # Finally, bind and draw the primitive + self._bind_and_draw_primitive( + primitive=primitive, + pose=scene.get_pose(node), + program=program, + flags=RenderFlags.DEPTH_ONLY + ) + self._reset_active_textures() + + # Unbind the shader and flush the output + if program is not None: + program._unbind() + glFlush() + + ########################################################################### + # Handlers for binding uniforms and drawing primitives + ########################################################################### + + def _bind_and_draw_primitive(self, primitive, pose, program, flags): + # Set model pose matrix + program.set_uniform('M', pose) + + # Bind mesh buffers + primitive._bind() + + # Bind mesh material + if not (flags & RenderFlags.DEPTH_ONLY or flags & RenderFlags.SEG): + material = primitive.material + + # Bind textures + tf = material.tex_flags + if tf & TexFlags.NORMAL: + self._bind_texture(material.normalTexture, + 'material.normal_texture', program) + if tf & TexFlags.OCCLUSION: + self._bind_texture(material.occlusionTexture, + 'material.occlusion_texture', program) + if tf & TexFlags.EMISSIVE: + self._bind_texture(material.emissiveTexture, + 'material.emissive_texture', program) + if tf & TexFlags.BASE_COLOR: + self._bind_texture(material.baseColorTexture, + 'material.base_color_texture', program) + if tf & TexFlags.METALLIC_ROUGHNESS: + self._bind_texture(material.metallicRoughnessTexture, + 'material.metallic_roughness_texture', + program) + if tf & TexFlags.DIFFUSE: + self._bind_texture(material.diffuseTexture, + 'material.diffuse_texture', program) + if tf & TexFlags.SPECULAR_GLOSSINESS: + self._bind_texture(material.specularGlossinessTexture, + 'material.specular_glossiness_texture', + program) + + # Bind other uniforms + b = 'material.{}' + program.set_uniform(b.format('emissive_factor'), + material.emissiveFactor) + if isinstance(material, MetallicRoughnessMaterial): + 
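+                # Metallic-roughness workflow: upload the scalar PBR
+                # factors alongside any bound textures.
+                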
program.set_uniform(b.format('base_color_factor'), + material.baseColorFactor) + program.set_uniform(b.format('metallic_factor'), + material.metallicFactor) + program.set_uniform(b.format('roughness_factor'), + material.roughnessFactor) + elif isinstance(material, SpecularGlossinessMaterial): + program.set_uniform(b.format('diffuse_factor'), + material.diffuseFactor) + program.set_uniform(b.format('specular_factor'), + material.specularFactor) + program.set_uniform(b.format('glossiness_factor'), + material.glossinessFactor) + + # Set blending options + if material.alphaMode == 'BLEND': + glEnable(GL_BLEND) + glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + else: + glEnable(GL_BLEND) + glBlendFunc(GL_ONE, GL_ZERO) + + # Set wireframe mode + wf = material.wireframe + if flags & RenderFlags.FLIP_WIREFRAME: + wf = not wf + if (flags & RenderFlags.ALL_WIREFRAME) or wf: + glPolygonMode(GL_FRONT_AND_BACK, GL_LINE) + else: + glPolygonMode(GL_FRONT_AND_BACK, GL_FILL) + + # Set culling mode + if material.doubleSided or flags & RenderFlags.SKIP_CULL_FACES: + glDisable(GL_CULL_FACE) + else: + glEnable(GL_CULL_FACE) + glCullFace(GL_BACK) + else: + glEnable(GL_CULL_FACE) + glEnable(GL_BLEND) + glCullFace(GL_BACK) + glBlendFunc(GL_ONE, GL_ZERO) + glPolygonMode(GL_FRONT_AND_BACK, GL_FILL) + + # Set point size if needed + glDisable(GL_PROGRAM_POINT_SIZE) + if primitive.mode == GLTF.POINTS: + glEnable(GL_PROGRAM_POINT_SIZE) + glPointSize(self.point_size) + + # Render mesh + n_instances = 1 + if primitive.poses is not None: + n_instances = len(primitive.poses) + + if primitive.indices is not None: + glDrawElementsInstanced( + primitive.mode, primitive.indices.size, GL_UNSIGNED_INT, + ctypes.c_void_p(0), n_instances + ) + else: + glDrawArraysInstanced( + primitive.mode, 0, len(primitive.positions), n_instances + ) + + # Unbind mesh buffers + primitive._unbind() + + def _bind_lighting(self, scene, program, node, flags): + """Bind all lighting uniform values for a scene. 
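+
+        If the scene holds more lights of a given type than the shader
+        supports, the lights nearest the rendered node take priority.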
+ """ + max_n_lights = self._compute_max_n_lights(flags) + + n_d = min(len(scene.directional_light_nodes), max_n_lights[0]) + n_s = min(len(scene.spot_light_nodes), max_n_lights[1]) + n_p = min(len(scene.point_light_nodes), max_n_lights[2]) + program.set_uniform('ambient_light', scene.ambient_light) + program.set_uniform('n_directional_lights', n_d) + program.set_uniform('n_spot_lights', n_s) + program.set_uniform('n_point_lights', n_p) + plc = 0 + slc = 0 + dlc = 0 + + light_nodes = scene.light_nodes + if (len(scene.directional_light_nodes) > max_n_lights[0] or + len(scene.spot_light_nodes) > max_n_lights[1] or + len(scene.point_light_nodes) > max_n_lights[2]): + light_nodes = self._sorted_nodes_by_distance( + scene, scene.light_nodes, node + ) + + for n in light_nodes: + light = n.light + pose = scene.get_pose(n) + position = pose[:3,3] + direction = -pose[:3,2] + + if isinstance(light, PointLight): + if plc == max_n_lights[2]: + continue + b = 'point_lights[{}].'.format(plc) + plc += 1 + shadow = bool(flags & RenderFlags.SHADOWS_POINT) + program.set_uniform(b + 'position', position) + elif isinstance(light, SpotLight): + if slc == max_n_lights[1]: + continue + b = 'spot_lights[{}].'.format(slc) + slc += 1 + shadow = bool(flags & RenderFlags.SHADOWS_SPOT) + las = 1.0 / max(0.001, np.cos(light.innerConeAngle) - + np.cos(light.outerConeAngle)) + lao = -np.cos(light.outerConeAngle) * las + program.set_uniform(b + 'direction', direction) + program.set_uniform(b + 'position', position) + program.set_uniform(b + 'light_angle_scale', las) + program.set_uniform(b + 'light_angle_offset', lao) + else: + if dlc == max_n_lights[0]: + continue + b = 'directional_lights[{}].'.format(dlc) + dlc += 1 + shadow = bool(flags & RenderFlags.SHADOWS_DIRECTIONAL) + program.set_uniform(b + 'direction', direction) + + program.set_uniform(b + 'color', light.color) + program.set_uniform(b + 'intensity', light.intensity) + # if light.range is not None: + # program.set_uniform(b + 'range', light.range) + # else: + # program.set_uniform(b + 'range', 0) + + if shadow: + self._bind_texture(light.shadow_texture, + b + 'shadow_map', program) + if not isinstance(light, PointLight): + V, P = self._get_light_cam_matrices(scene, n, flags) + program.set_uniform(b + 'light_matrix', P.dot(V)) + else: + raise NotImplementedError( + 'Point light shadows not implemented' + ) + + def _sorted_mesh_nodes(self, scene): + cam_loc = scene.get_pose(scene.main_camera_node)[:3,3] + solid_nodes = [] + trans_nodes = [] + for node in scene.mesh_nodes: + mesh = node.mesh + if mesh.is_transparent: + trans_nodes.append(node) + else: + solid_nodes.append(node) + + # TODO BETTER SORTING METHOD + trans_nodes.sort( + key=lambda n: -np.linalg.norm(scene.get_pose(n)[:3,3] - cam_loc) + ) + solid_nodes.sort( + key=lambda n: -np.linalg.norm(scene.get_pose(n)[:3,3] - cam_loc) + ) + + return solid_nodes + trans_nodes + + def _sorted_nodes_by_distance(self, scene, nodes, compare_node): + nodes = list(nodes) + compare_posn = scene.get_pose(compare_node)[:3,3] + nodes.sort(key=lambda n: np.linalg.norm( + scene.get_pose(n)[:3,3] - compare_posn) + ) + return nodes + + ########################################################################### + # Context Management + ########################################################################### + + def _update_context(self, scene, flags): + + # Update meshes + scene_meshes = scene.meshes + + # Add new meshes to context + for mesh in scene_meshes - self._meshes: + for p in mesh.primitives: + p._add_to_context() + + 
# Remove old meshes from context + for mesh in self._meshes - scene_meshes: + for p in mesh.primitives: + p.delete() + + self._meshes = scene_meshes.copy() + + # Update mesh textures + mesh_textures = set() + for m in scene_meshes: + for p in m.primitives: + mesh_textures |= p.material.textures + + # Add new textures to context + for texture in mesh_textures - self._mesh_textures: + texture._add_to_context() + + # Remove old textures from context + for texture in self._mesh_textures - mesh_textures: + texture.delete() + + self._mesh_textures = mesh_textures.copy() + + shadow_textures = set() + for l in scene.lights: + # Create if needed + active = False + if (isinstance(l, DirectionalLight) and + flags & RenderFlags.SHADOWS_DIRECTIONAL): + active = True + elif (isinstance(l, PointLight) and + flags & RenderFlags.SHADOWS_POINT): + active = True + elif isinstance(l, SpotLight) and flags & RenderFlags.SHADOWS_SPOT: + active = True + + if active and l.shadow_texture is None: + l._generate_shadow_texture() + if l.shadow_texture is not None: + shadow_textures.add(l.shadow_texture) + + # Add new textures to context + for texture in shadow_textures - self._shadow_textures: + texture._add_to_context() + + # Remove old textures from context + for texture in self._shadow_textures - shadow_textures: + texture.delete() + + self._shadow_textures = shadow_textures.copy() + + ########################################################################### + # Texture Management + ########################################################################### + + def _bind_texture(self, texture, uniform_name, program): + """Bind a texture to an active texture unit and return + the texture unit index that was used. + """ + tex_id = self._get_next_active_texture() + glActiveTexture(GL_TEXTURE0 + tex_id) + texture._bind() + program.set_uniform(uniform_name, tex_id) + + def _get_next_active_texture(self): + val = self._texture_alloc_idx + self._texture_alloc_idx += 1 + return val + + def _reset_active_textures(self): + self._texture_alloc_idx = 0 + + ########################################################################### + # Camera Matrix Management + ########################################################################### + + def _get_camera_matrices(self, scene): + main_camera_node = scene.main_camera_node + if main_camera_node is None: + raise ValueError('Cannot render scene without a camera') + P = main_camera_node.camera.get_projection_matrix( + width=self.viewport_width, height=self.viewport_height + ) + pose = scene.get_pose(main_camera_node) + V = np.linalg.inv(pose) # V maps from world to camera + return V, P + + def _get_light_cam_matrices(self, scene, light_node, flags): + light = light_node.light + pose = scene.get_pose(light_node).copy() + s = scene.scale + camera = light._get_shadow_camera(s) + P = camera.get_projection_matrix() + if isinstance(light, DirectionalLight): + direction = -pose[:3,2] + c = scene.centroid + loc = c - direction * s + pose[:3,3] = loc + V = np.linalg.inv(pose) # V maps from world to camera + return V, P + + ########################################################################### + # Shader Program Management + ########################################################################### + + def _get_text_program(self): + program = self._program_cache.get_program( + vertex_shader='text.vert', + fragment_shader='text.frag' + ) + + if not program._in_context(): + program._add_to_context() + + return program + + def _compute_max_n_lights(self, flags): + max_n_lights = 
[MAX_N_LIGHTS, MAX_N_LIGHTS, MAX_N_LIGHTS] + n_tex_units = glGetIntegerv(GL_MAX_TEXTURE_IMAGE_UNITS) + + # Reserved texture units: 6 + # Normal Map + # Occlusion Map + # Emissive Map + # Base Color or Diffuse Map + # MR or SG Map + # Environment cubemap + + n_reserved_textures = 6 + n_available_textures = n_tex_units - n_reserved_textures + + # Distribute textures evenly among lights with shadows, with + # a preference for directional lights + n_shadow_types = 0 + if flags & RenderFlags.SHADOWS_DIRECTIONAL: + n_shadow_types += 1 + if flags & RenderFlags.SHADOWS_SPOT: + n_shadow_types += 1 + if flags & RenderFlags.SHADOWS_POINT: + n_shadow_types += 1 + + if n_shadow_types > 0: + tex_per_light = n_available_textures // n_shadow_types + + if flags & RenderFlags.SHADOWS_DIRECTIONAL: + max_n_lights[0] = ( + tex_per_light + + (n_available_textures - tex_per_light * n_shadow_types) + ) + if flags & RenderFlags.SHADOWS_SPOT: + max_n_lights[1] = tex_per_light + if flags & RenderFlags.SHADOWS_POINT: + max_n_lights[2] = tex_per_light + + return max_n_lights + + def _get_primitive_program(self, primitive, flags, program_flags): + vertex_shader = None + fragment_shader = None + geometry_shader = None + defines = {} + + if (bool(program_flags & ProgramFlags.USE_MATERIAL) and + not flags & RenderFlags.DEPTH_ONLY and + not flags & RenderFlags.FLAT and + not flags & RenderFlags.SEG): + vertex_shader = 'mesh.vert' + fragment_shader = 'mesh.frag' + elif bool(program_flags & (ProgramFlags.VERTEX_NORMALS | + ProgramFlags.FACE_NORMALS)): + vertex_shader = 'vertex_normals.vert' + if primitive.mode == GLTF.POINTS: + geometry_shader = 'vertex_normals_pc.geom' + else: + geometry_shader = 'vertex_normals.geom' + fragment_shader = 'vertex_normals.frag' + elif flags & RenderFlags.FLAT: + vertex_shader = 'flat.vert' + fragment_shader = 'flat.frag' + elif flags & RenderFlags.SEG: + vertex_shader = 'segmentation.vert' + fragment_shader = 'segmentation.frag' + else: + vertex_shader = 'mesh_depth.vert' + fragment_shader = 'mesh_depth.frag' + + # Set up vertex buffer DEFINES + bf = primitive.buf_flags + buf_idx = 1 + if bf & BufFlags.NORMAL: + defines['NORMAL_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.TANGENT: + defines['TANGENT_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.TEXCOORD_0: + defines['TEXCOORD_0_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.TEXCOORD_1: + defines['TEXCOORD_1_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.COLOR_0: + defines['COLOR_0_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.JOINTS_0: + defines['JOINTS_0_LOC'] = buf_idx + buf_idx += 1 + if bf & BufFlags.WEIGHTS_0: + defines['WEIGHTS_0_LOC'] = buf_idx + buf_idx += 1 + defines['INST_M_LOC'] = buf_idx + + # Set up shadow mapping defines + if flags & RenderFlags.SHADOWS_DIRECTIONAL: + defines['DIRECTIONAL_LIGHT_SHADOWS'] = 1 + if flags & RenderFlags.SHADOWS_SPOT: + defines['SPOT_LIGHT_SHADOWS'] = 1 + if flags & RenderFlags.SHADOWS_POINT: + defines['POINT_LIGHT_SHADOWS'] = 1 + max_n_lights = self._compute_max_n_lights(flags) + defines['MAX_DIRECTIONAL_LIGHTS'] = max_n_lights[0] + defines['MAX_SPOT_LIGHTS'] = max_n_lights[1] + defines['MAX_POINT_LIGHTS'] = max_n_lights[2] + + # Set up vertex normal defines + if program_flags & ProgramFlags.VERTEX_NORMALS: + defines['VERTEX_NORMALS'] = 1 + if program_flags & ProgramFlags.FACE_NORMALS: + defines['FACE_NORMALS'] = 1 + + # Set up material texture defines + if bool(program_flags & ProgramFlags.USE_MATERIAL): + tf = primitive.material.tex_flags + if tf & TexFlags.NORMAL: + 
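+                # Each texture present on the material toggles a #define so
+                # the fragment shader compiles in only the samplers it needs.
+                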
defines['HAS_NORMAL_TEX'] = 1 + if tf & TexFlags.OCCLUSION: + defines['HAS_OCCLUSION_TEX'] = 1 + if tf & TexFlags.EMISSIVE: + defines['HAS_EMISSIVE_TEX'] = 1 + if tf & TexFlags.BASE_COLOR: + defines['HAS_BASE_COLOR_TEX'] = 1 + if tf & TexFlags.METALLIC_ROUGHNESS: + defines['HAS_METALLIC_ROUGHNESS_TEX'] = 1 + if tf & TexFlags.DIFFUSE: + defines['HAS_DIFFUSE_TEX'] = 1 + if tf & TexFlags.SPECULAR_GLOSSINESS: + defines['HAS_SPECULAR_GLOSSINESS_TEX'] = 1 + if isinstance(primitive.material, MetallicRoughnessMaterial): + defines['USE_METALLIC_MATERIAL'] = 1 + elif isinstance(primitive.material, SpecularGlossinessMaterial): + defines['USE_GLOSSY_MATERIAL'] = 1 + + program = self._program_cache.get_program( + vertex_shader=vertex_shader, + fragment_shader=fragment_shader, + geometry_shader=geometry_shader, + defines=defines + ) + + if not program._in_context(): + program._add_to_context() + + return program + + ########################################################################### + # Viewport Management + ########################################################################### + + def _configure_forward_pass_viewport(self, flags): + + # If using offscreen render, bind main framebuffer + if flags & RenderFlags.OFFSCREEN: + self._configure_main_framebuffer() + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self._main_fb_ms) + else: + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, 0) + + glViewport(0, 0, self.viewport_width, self.viewport_height) + glEnable(GL_DEPTH_TEST) + glDepthMask(GL_TRUE) + glDepthFunc(GL_LESS) + glDepthRange(0.0, 1.0) + + def _configure_shadow_mapping_viewport(self, light, flags): + self._configure_shadow_framebuffer() + glBindFramebuffer(GL_FRAMEBUFFER, self._shadow_fb) + light.shadow_texture._bind() + light.shadow_texture._bind_as_depth_attachment() + glActiveTexture(GL_TEXTURE0) + light.shadow_texture._bind() + glDrawBuffer(GL_NONE) + glReadBuffer(GL_NONE) + + glClear(GL_DEPTH_BUFFER_BIT) + glViewport(0, 0, SHADOW_TEX_SZ, SHADOW_TEX_SZ) + glEnable(GL_DEPTH_TEST) + glDepthMask(GL_TRUE) + glDepthFunc(GL_LESS) + glDepthRange(0.0, 1.0) + glDisable(GL_CULL_FACE) + glDisable(GL_BLEND) + + ########################################################################### + # Framebuffer Management + ########################################################################### + + def _configure_shadow_framebuffer(self): + if self._shadow_fb is None: + self._shadow_fb = glGenFramebuffers(1) + + def _delete_shadow_framebuffer(self): + if self._shadow_fb is not None: + glDeleteFramebuffers(1, [self._shadow_fb]) + + def _configure_main_framebuffer(self): + # If mismatch with prior framebuffer, delete it + if (self._main_fb is not None and + self.viewport_width != self._main_fb_dims[0] or + self.viewport_height != self._main_fb_dims[1]): + self._delete_main_framebuffer() + + # If framebuffer doesn't exist, create it + if self._main_fb is None: + # Generate standard buffer + self._main_cb, self._main_db = glGenRenderbuffers(2) + + glBindRenderbuffer(GL_RENDERBUFFER, self._main_cb) + glRenderbufferStorage( + GL_RENDERBUFFER, GL_RGBA, + self.viewport_width, self.viewport_height + ) + + glBindRenderbuffer(GL_RENDERBUFFER, self._main_db) + glRenderbufferStorage( + GL_RENDERBUFFER, GL_DEPTH_COMPONENT24, + self.viewport_width, self.viewport_height + ) + + self._main_fb = glGenFramebuffers(1) + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self._main_fb) + glFramebufferRenderbuffer( + GL_DRAW_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + GL_RENDERBUFFER, self._main_cb + ) + glFramebufferRenderbuffer( + GL_DRAW_FRAMEBUFFER, 
GL_DEPTH_ATTACHMENT,
+                GL_RENDERBUFFER, self._main_db
+            )
+
+            # Generate multisample buffer
+            self._main_cb_ms, self._main_db_ms = glGenRenderbuffers(2)
+            glBindRenderbuffer(GL_RENDERBUFFER, self._main_cb_ms)
+            # Clamp the sample count so that requesting 4x MSAA can never
+            # exceed the implementation's GL_MAX_SAMPLES limit
+            num_samples = min(glGetIntegerv(GL_MAX_SAMPLES), 4)
+            glRenderbufferStorageMultisample(
+                GL_RENDERBUFFER, num_samples, GL_RGBA,
+                self.viewport_width, self.viewport_height
+            )
+            glBindRenderbuffer(GL_RENDERBUFFER, self._main_db_ms)
+            glRenderbufferStorageMultisample(
+                GL_RENDERBUFFER, num_samples, GL_DEPTH_COMPONENT24,
+                self.viewport_width, self.viewport_height
+            )
+
+            self._main_fb_ms = glGenFramebuffers(1)
+            glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self._main_fb_ms)
+            glFramebufferRenderbuffer(
+                GL_DRAW_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
+                GL_RENDERBUFFER, self._main_cb_ms
+            )
+            glFramebufferRenderbuffer(
+                GL_DRAW_FRAMEBUFFER, GL_DEPTH_ATTACHMENT,
+                GL_RENDERBUFFER, self._main_db_ms
+            )
+
+            self._main_fb_dims = (self.viewport_width, self.viewport_height)
+
+    def _delete_main_framebuffer(self):
+        if self._main_fb is not None:
+            glDeleteFramebuffers(2, [self._main_fb, self._main_fb_ms])
+        if self._main_cb is not None:
+            glDeleteRenderbuffers(2, [self._main_cb, self._main_cb_ms])
+        if self._main_db is not None:
+            glDeleteRenderbuffers(2, [self._main_db, self._main_db_ms])
+
+        self._main_fb = None
+        self._main_cb = None
+        self._main_db = None
+        self._main_fb_ms = None
+        self._main_cb_ms = None
+        self._main_db_ms = None
+        self._main_fb_dims = (None, None)
+
+    def _read_main_framebuffer(self, scene, flags):
+        width, height = self._main_fb_dims[0], self._main_fb_dims[1]
+
+        # Bind framebuffer and blit buffers
+        glBindFramebuffer(GL_READ_FRAMEBUFFER, self._main_fb_ms)
+        glBindFramebuffer(GL_DRAW_FRAMEBUFFER, self._main_fb)
+        glBlitFramebuffer(
+            0, 0, width, height, 0, 0, width, height,
+            GL_COLOR_BUFFER_BIT, GL_LINEAR
+        )
+        glBlitFramebuffer(
+            0, 0, width, height, 0, 0, width, height,
+            GL_DEPTH_BUFFER_BIT, GL_NEAREST
+        )
+        glBindFramebuffer(GL_READ_FRAMEBUFFER, self._main_fb)
+
+        # Read depth
+        depth_buf = glReadPixels(
+            0, 0, width, height, GL_DEPTH_COMPONENT, GL_FLOAT
+        )
+        depth_im = np.frombuffer(depth_buf, dtype=np.float32)
+        depth_im = depth_im.reshape((height, width))
+        depth_im = np.flip(depth_im, axis=0)
+        inf_inds = (depth_im == 1.0)
+        depth_im = 2.0 * depth_im - 1.0
+        z_near = scene.main_camera_node.camera.znear
+        z_far = scene.main_camera_node.camera.zfar
+        noninf = np.logical_not(inf_inds)
+        if z_far is None:
+            depth_im[noninf] = 2 * z_near / (1.0 - depth_im[noninf])
+        else:
+            depth_im[noninf] = ((2.0 * z_near * z_far) /
+                                (z_far + z_near - depth_im[noninf] *
+                                (z_far - z_near)))
+        depth_im[inf_inds] = 0.0
+
+        # Resize for macos if needed
+        if sys.platform == 'darwin':
+            depth_im = self._resize_image(depth_im)
+
+        if flags & RenderFlags.DEPTH_ONLY:
+            return depth_im
+
+        # Read color
+        if flags & RenderFlags.RGBA:
+            color_buf = glReadPixels(
+                0, 0, width, height, GL_RGBA, GL_UNSIGNED_BYTE
+            )
+            color_im = np.frombuffer(color_buf, dtype=np.uint8)
+            color_im = color_im.reshape((height, width, 4))
+        else:
+            color_buf = glReadPixels(
+                0, 0, width, height, GL_RGB, 
GL_UNSIGNED_BYTE + ) + color_im = np.frombuffer(color_buf, dtype=np.uint8) + color_im = color_im.reshape((height, width, 3)) + color_im = np.flip(color_im, axis=0) + + # Resize for macos if needed + if sys.platform == 'darwin': + color_im = self._resize_image(color_im, True) + + return color_im, depth_im + + def _resize_image(self, value, antialias=False): + """If needed, rescale the render for MacOS.""" + img = PIL.Image.fromarray(value) + resample = PIL.Image.NEAREST + if antialias: + resample = PIL.Image.BILINEAR + size = (self.viewport_width // self.dpscale, + self.viewport_height // self.dpscale) + img = img.resize(size, resample=resample) + return np.array(img) + + ########################################################################### + # Shadowmap Debugging + ########################################################################### + + def _forward_pass_no_reset(self, scene, flags): + # Set up camera matrices + V, P = self._get_camera_matrices(scene) + + # Now, render each object in sorted order + for node in self._sorted_mesh_nodes(scene): + mesh = node.mesh + + # Skip the mesh if it's not visible + if not mesh.is_visible: + continue + + for primitive in mesh.primitives: + + # First, get and bind the appropriate program + program = self._get_primitive_program( + primitive, flags, ProgramFlags.USE_MATERIAL + ) + program._bind() + + # Set the camera uniforms + program.set_uniform('V', V) + program.set_uniform('P', P) + program.set_uniform( + 'cam_pos', scene.get_pose(scene.main_camera_node)[:3,3] + ) + + # Next, bind the lighting + if not flags & RenderFlags.DEPTH_ONLY and not flags & RenderFlags.FLAT: + self._bind_lighting(scene, program, node, flags) + + # Finally, bind and draw the primitive + self._bind_and_draw_primitive( + primitive=primitive, + pose=scene.get_pose(node), + program=program, + flags=flags + ) + self._reset_active_textures() + + # Unbind the shader and flush the output + if program is not None: + program._unbind() + glFlush() + + def _render_light_shadowmaps(self, scene, light_nodes, flags, tile=False): + glBindFramebuffer(GL_DRAW_FRAMEBUFFER, 0) + glClearColor(*scene.bg_color) + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + glEnable(GL_DEPTH_TEST) + glDepthMask(GL_TRUE) + glDepthFunc(GL_LESS) + glDepthRange(0.0, 1.0) + + w = self.viewport_width + h = self.viewport_height + + num_nodes = len(light_nodes) + viewport_dims = { + (0, 2): [0, h // 2, w // 2, h], + (1, 2): [w // 2, h // 2, w, h], + (0, 3): [0, h // 2, w // 2, h], + (1, 3): [w // 2, h // 2, w, h], + (2, 3): [0, 0, w // 2, h // 2], + (0, 4): [0, h // 2, w // 2, h], + (1, 4): [w // 2, h // 2, w, h], + (2, 4): [0, 0, w // 2, h // 2], + (3, 4): [w // 2, 0, w, h // 2] + } + + if tile: + for i, ln in enumerate(light_nodes): + light = ln.light + + if light.shadow_texture is None: + raise ValueError('Light does not have a shadow texture') + + glViewport(*viewport_dims[(i, num_nodes + 1)]) + + program = self._get_debug_quad_program() + program._bind() + self._bind_texture(light.shadow_texture, 'depthMap', program) + self._render_debug_quad() + self._reset_active_textures() + glFlush() + i += 1 + glViewport(*viewport_dims[(i, num_nodes + 1)]) + self._forward_pass_no_reset(scene, flags) + else: + for i, ln in enumerate(light_nodes): + light = ln.light + + if light.shadow_texture is None: + raise ValueError('Light does not have a shadow texture') + + glViewport(0, 0, self.viewport_width, self.viewport_height) + + program = self._get_debug_quad_program() + program._bind() + 
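+                # Draw the raw shadow map through the debug-quad shader so
+                # it can be inspected directly on screen.
+                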
self._bind_texture(light.shadow_texture, 'depthMap', program) + self._render_debug_quad() + self._reset_active_textures() + glFlush() + return + + def _get_debug_quad_program(self): + program = self._program_cache.get_program( + vertex_shader='debug_quad.vert', + fragment_shader='debug_quad.frag' + ) + if not program._in_context(): + program._add_to_context() + return program + + def _render_debug_quad(self): + x = glGenVertexArrays(1) + glBindVertexArray(x) + glDrawArrays(GL_TRIANGLES, 0, 6) + glBindVertexArray(0) + glDeleteVertexArrays(1, [x]) diff --git a/pyrender/pyrender/sampler.py b/pyrender/pyrender/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..e4784d068f808a40a56c8e748d83175f7f4e6233 --- /dev/null +++ b/pyrender/pyrender/sampler.py @@ -0,0 +1,102 @@ +"""Samplers, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-sampler + +Author: Matthew Matl +""" +from .constants import GLTF + + +class Sampler(object): + """Texture sampler properties for filtering and wrapping modes. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + magFilter : int, optional + Magnification filter. Valid values: + - :attr:`.GLTF.NEAREST` + - :attr:`.GLTF.LINEAR` + minFilter : int, optional + Minification filter. Valid values: + - :attr:`.GLTF.NEAREST` + - :attr:`.GLTF.LINEAR` + - :attr:`.GLTF.NEAREST_MIPMAP_NEAREST` + - :attr:`.GLTF.LINEAR_MIPMAP_NEAREST` + - :attr:`.GLTF.NEAREST_MIPMAP_LINEAR` + - :attr:`.GLTF.LINEAR_MIPMAP_LINEAR` + wrapS : int, optional + S (U) wrapping mode. Valid values: + - :attr:`.GLTF.CLAMP_TO_EDGE` + - :attr:`.GLTF.MIRRORED_REPEAT` + - :attr:`.GLTF.REPEAT` + wrapT : int, optional + T (V) wrapping mode. Valid values: + - :attr:`.GLTF.CLAMP_TO_EDGE` + - :attr:`.GLTF.MIRRORED_REPEAT` + - :attr:`.GLTF.REPEAT` + """ + + def __init__(self, + name=None, + magFilter=None, + minFilter=None, + wrapS=GLTF.REPEAT, + wrapT=GLTF.REPEAT): + self.name = name + self.magFilter = magFilter + self.minFilter = minFilter + self.wrapS = wrapS + self.wrapT = wrapT + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def magFilter(self): + """int : Magnification filter type. + """ + return self._magFilter + + @magFilter.setter + def magFilter(self, value): + self._magFilter = value + + @property + def minFilter(self): + """int : Minification filter type. + """ + return self._minFilter + + @minFilter.setter + def minFilter(self, value): + self._minFilter = value + + @property + def wrapS(self): + """int : S (U) wrapping mode. + """ + return self._wrapS + + @wrapS.setter + def wrapS(self, value): + self._wrapS = value + + @property + def wrapT(self): + """int : T (V) wrapping mode. 
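+        Defaults to :attr:`.GLTF.REPEAT`.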
+ """ + return self._wrapT + + @wrapT.setter + def wrapT(self, value): + self._wrapT = value diff --git a/pyrender/pyrender/scene.py b/pyrender/pyrender/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe057ec66f52f2dd9c1363aacf72a7c6cec4e6c --- /dev/null +++ b/pyrender/pyrender/scene.py @@ -0,0 +1,585 @@ +"""Scenes, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-scene + +Author: Matthew Matl +""" +import numpy as np +import networkx as nx +import trimesh + +from .mesh import Mesh +from .camera import Camera +from .light import Light, PointLight, DirectionalLight, SpotLight +from .node import Node +from .utils import format_color_vector + + +class Scene(object): + """A hierarchical scene graph. + + Parameters + ---------- + nodes : list of :class:`Node` + The set of all nodes in the scene. + bg_color : (4,) float, optional + Background color of scene. + ambient_light : (3,) float, optional + Color of ambient light. Defaults to no ambient light. + name : str, optional + The user-defined name of this object. + """ + + def __init__(self, + nodes=None, + bg_color=None, + ambient_light=None, + name=None): + + if bg_color is None: + bg_color = np.ones(4) + else: + bg_color = format_color_vector(bg_color, 4) + + if ambient_light is None: + ambient_light = np.zeros(3) + + if nodes is None: + nodes = set() + self._nodes = set() # Will be added at the end of this function + + self.bg_color = bg_color + self.ambient_light = ambient_light + self.name = name + + self._name_to_nodes = {} + self._obj_to_nodes = {} + self._obj_name_to_nodes = {} + self._mesh_nodes = set() + self._point_light_nodes = set() + self._spot_light_nodes = set() + self._directional_light_nodes = set() + self._camera_nodes = set() + self._main_camera_node = None + self._bounds = None + + # Transform tree + self._digraph = nx.DiGraph() + self._digraph.add_node('world') + self._path_cache = {} + + # Find root nodes and add them + if len(nodes) > 0: + node_parent_map = {n: None for n in nodes} + for node in nodes: + for child in node.children: + if node_parent_map[child] is not None: + raise ValueError('Nodes may not have more than ' + 'one parent') + node_parent_map[child] = node + for node in node_parent_map: + if node_parent_map[node] is None: + self.add_node(node) + + @property + def name(self): + """str : The user-defined name of this object. + """ + return self._name + + @name.setter + def name(self, value): + if value is not None: + value = str(value) + self._name = value + + @property + def nodes(self): + """set of :class:`Node` : Set of nodes in the scene. + """ + return self._nodes + + @property + def bg_color(self): + """(3,) float : The scene background color. + """ + return self._bg_color + + @bg_color.setter + def bg_color(self, value): + if value is None: + value = np.ones(4) + else: + value = format_color_vector(value, 4) + self._bg_color = value + + @property + def ambient_light(self): + """(3,) float : The ambient light in the scene. + """ + return self._ambient_light + + @ambient_light.setter + def ambient_light(self, value): + if value is None: + value = np.zeros(3) + else: + value = format_color_vector(value, 3) + self._ambient_light = value + + @property + def meshes(self): + """set of :class:`Mesh` : The meshes in the scene. + """ + return set([n.mesh for n in self.mesh_nodes]) + + @property + def mesh_nodes(self): + """set of :class:`Node` : The nodes containing meshes. 
+ """ + return self._mesh_nodes + + @property + def lights(self): + """set of :class:`Light` : The lights in the scene. + """ + return self.point_lights | self.spot_lights | self.directional_lights + + @property + def light_nodes(self): + """set of :class:`Node` : The nodes containing lights. + """ + return (self.point_light_nodes | self.spot_light_nodes | + self.directional_light_nodes) + + @property + def point_lights(self): + """set of :class:`PointLight` : The point lights in the scene. + """ + return set([n.light for n in self.point_light_nodes]) + + @property + def point_light_nodes(self): + """set of :class:`Node` : The nodes containing point lights. + """ + return self._point_light_nodes + + @property + def spot_lights(self): + """set of :class:`SpotLight` : The spot lights in the scene. + """ + return set([n.light for n in self.spot_light_nodes]) + + @property + def spot_light_nodes(self): + """set of :class:`Node` : The nodes containing spot lights. + """ + return self._spot_light_nodes + + @property + def directional_lights(self): + """set of :class:`DirectionalLight` : The directional lights in + the scene. + """ + return set([n.light for n in self.directional_light_nodes]) + + @property + def directional_light_nodes(self): + """set of :class:`Node` : The nodes containing directional lights. + """ + return self._directional_light_nodes + + @property + def cameras(self): + """set of :class:`Camera` : The cameras in the scene. + """ + return set([n.camera for n in self.camera_nodes]) + + @property + def camera_nodes(self): + """set of :class:`Node` : The nodes containing cameras in the scene. + """ + return self._camera_nodes + + @property + def main_camera_node(self): + """set of :class:`Node` : The node containing the main camera in the + scene. + """ + return self._main_camera_node + + @main_camera_node.setter + def main_camera_node(self, value): + if value not in self.nodes: + raise ValueError('New main camera node must already be in scene') + self._main_camera_node = value + + @property + def bounds(self): + """(2,3) float : The axis-aligned bounds of the scene. + """ + if self._bounds is None: + # Compute corners + corners = [] + for mesh_node in self.mesh_nodes: + mesh = mesh_node.mesh + pose = self.get_pose(mesh_node) + corners_local = trimesh.bounds.corners(mesh.bounds) + corners_world = pose[:3,:3].dot(corners_local.T).T + pose[:3,3] + corners.append(corners_world) + if len(corners) == 0: + self._bounds = np.zeros((2,3)) + else: + corners = np.vstack(corners) + self._bounds = np.array([np.min(corners, axis=0), + np.max(corners, axis=0)]) + return self._bounds + + @property + def centroid(self): + """(3,) float : The centroid of the scene's axis-aligned bounding box + (AABB). + """ + return np.mean(self.bounds, axis=0) + + @property + def extents(self): + """(3,) float : The lengths of the axes of the scene's AABB. + """ + return np.diff(self.bounds, axis=0).reshape(-1) + + @property + def scale(self): + """(3,) float : The length of the diagonal of the scene's AABB. + """ + return np.linalg.norm(self.extents) + + def add(self, obj, name=None, pose=None, + parent_node=None, parent_name=None): + """Add an object (mesh, light, or camera) to the scene. + + Parameters + ---------- + obj : :class:`Mesh`, :class:`Light`, or :class:`Camera` + The object to add to the scene. + name : str + A name for the new node to be created. + pose : (4,4) float + The local pose of this node relative to its parent node. + parent_node : :class:`Node` + The parent of this Node. 
If None, the new node is a root node. + parent_name : str + The name of the parent node, can be specified instead of + `parent_node`. + + Returns + ------- + node : :class:`Node` + The newly-created and inserted node. + """ + if isinstance(obj, Mesh): + node = Node(name=name, matrix=pose, mesh=obj) + elif isinstance(obj, Light): + node = Node(name=name, matrix=pose, light=obj) + elif isinstance(obj, Camera): + node = Node(name=name, matrix=pose, camera=obj) + else: + raise TypeError('Unrecognized object type') + + if parent_node is None and parent_name is not None: + parent_nodes = self.get_nodes(name=parent_name) + if len(parent_nodes) == 0: + raise ValueError('No parent node with name {} found' + .format(parent_name)) + elif len(parent_nodes) > 1: + raise ValueError('More than one parent node with name {} found' + .format(parent_name)) + parent_node = list(parent_nodes)[0] + + self.add_node(node, parent_node=parent_node) + + return node + + def get_nodes(self, node=None, name=None, obj=None, obj_name=None): + """Search for existing nodes. Only nodes matching all specified + parameters is returned, or None if no such node exists. + + Parameters + ---------- + node : :class:`Node`, optional + If present, returns this node if it is in the scene. + name : str + A name for the Node. + obj : :class:`Mesh`, :class:`Light`, or :class:`Camera` + An object that is attached to the node. + obj_name : str + The name of an object that is attached to the node. + + Returns + ------- + nodes : set of :class:`.Node` + The nodes that match all query terms. + """ + if node is not None: + if node in self.nodes: + return set([node]) + else: + return set() + nodes = set(self.nodes) + if name is not None: + matches = set() + if name in self._name_to_nodes: + matches = self._name_to_nodes[name] + nodes = nodes & matches + if obj is not None: + matches = set() + if obj in self._obj_to_nodes: + matches = self._obj_to_nodes[obj] + nodes = nodes & matches + if obj_name is not None: + matches = set() + if obj_name in self._obj_name_to_nodes: + matches = self._obj_name_to_nodes[obj_name] + nodes = nodes & matches + + return nodes + + def add_node(self, node, parent_node=None): + """Add a Node to the scene. + + Parameters + ---------- + node : :class:`Node` + The node to be added. + parent_node : :class:`Node` + The parent of this Node. If None, the new node is a root node. 
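+
+        Example (hypothetical, assuming a ``scene`` and a ``mesh`` created
+        elsewhere)::
+
+            node = Node(name='cube', mesh=mesh)
+            scene.add_node(node)  # parent_node=None makes it a root node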
+ """ + if node in self.nodes: + raise ValueError('Node already in scene') + self.nodes.add(node) + + # Add node to sets + if node.name is not None: + if node.name not in self._name_to_nodes: + self._name_to_nodes[node.name] = set() + self._name_to_nodes[node.name].add(node) + for obj in [node.mesh, node.camera, node.light]: + if obj is not None: + if obj not in self._obj_to_nodes: + self._obj_to_nodes[obj] = set() + self._obj_to_nodes[obj].add(node) + if obj.name is not None: + if obj.name not in self._obj_name_to_nodes: + self._obj_name_to_nodes[obj.name] = set() + self._obj_name_to_nodes[obj.name].add(node) + if node.mesh is not None: + self._mesh_nodes.add(node) + if node.light is not None: + if isinstance(node.light, PointLight): + self._point_light_nodes.add(node) + if isinstance(node.light, SpotLight): + self._spot_light_nodes.add(node) + if isinstance(node.light, DirectionalLight): + self._directional_light_nodes.add(node) + if node.camera is not None: + self._camera_nodes.add(node) + if self._main_camera_node is None: + self._main_camera_node = node + + if parent_node is None: + parent_node = 'world' + elif parent_node not in self.nodes: + raise ValueError('Parent node must already be in scene') + elif node not in parent_node.children: + parent_node.children.append(node) + + # Create node in graph + self._digraph.add_node(node) + self._digraph.add_edge(node, parent_node) + + # Iterate over children + for child in node.children: + self.add_node(child, node) + + self._path_cache = {} + self._bounds = None + + def has_node(self, node): + """Check if a node is already in the scene. + + Parameters + ---------- + node : :class:`Node` + The node to be checked. + + Returns + ------- + has_node : bool + True if the node is already in the scene and false otherwise. + """ + return node in self.nodes + + def remove_node(self, node): + """Remove a node and all its children from the scene. + + Parameters + ---------- + node : :class:`Node` + The node to be removed. + """ + # Disconnect self from parent who is staying in the graph + parent = list(self._digraph.neighbors(node))[0] + self._remove_node(node) + if isinstance(parent, Node): + parent.children.remove(node) + self._path_cache = {} + self._bounds = None + + def get_pose(self, node): + """Get the world-frame pose of a node in the scene. + + Parameters + ---------- + node : :class:`Node` + The node to find the pose of. + + Returns + ------- + pose : (4,4) float + The transform matrix for this node. + """ + if node not in self.nodes: + raise ValueError('Node must already be in scene') + if node in self._path_cache: + path = self._path_cache[node] + else: + # Get path from from_frame to to_frame + path = nx.shortest_path(self._digraph, node, 'world') + self._path_cache[node] = path + + # Traverse from from_node to to_node + pose = np.eye(4) + for n in path[:-1]: + pose = np.dot(n.matrix, pose) + + return pose + + def set_pose(self, node, pose): + """Set the local-frame pose of a node in the scene. + + Parameters + ---------- + node : :class:`Node` + The node to set the pose of. + pose : (4,4) float + The pose to set the node to. + """ + if node not in self.nodes: + raise ValueError('Node must already be in scene') + node._matrix = pose + if node.mesh is not None: + self._bounds = None + + def clear(self): + """Clear out all nodes to form an empty scene. 
+ """ + self._nodes = set() + + self._name_to_nodes = {} + self._obj_to_nodes = {} + self._obj_name_to_nodes = {} + self._mesh_nodes = set() + self._point_light_nodes = set() + self._spot_light_nodes = set() + self._directional_light_nodes = set() + self._camera_nodes = set() + self._main_camera_node = None + self._bounds = None + + # Transform tree + self._digraph = nx.DiGraph() + self._digraph.add_node('world') + self._path_cache = {} + + def _remove_node(self, node): + """Remove a node and all its children from the scene. + + Parameters + ---------- + node : :class:`Node` + The node to be removed. + """ + + # Remove self from nodes + self.nodes.remove(node) + + # Remove children + for child in node.children: + self._remove_node(child) + + # Remove self from the graph + self._digraph.remove_node(node) + + # Remove from maps + if node.name in self._name_to_nodes: + self._name_to_nodes[node.name].remove(node) + if len(self._name_to_nodes[node.name]) == 0: + self._name_to_nodes.pop(node.name) + for obj in [node.mesh, node.camera, node.light]: + if obj is None: + continue + self._obj_to_nodes[obj].remove(node) + if len(self._obj_to_nodes[obj]) == 0: + self._obj_to_nodes.pop(obj) + if obj.name is not None: + self._obj_name_to_nodes[obj.name].remove(node) + if len(self._obj_name_to_nodes[obj.name]) == 0: + self._obj_name_to_nodes.pop(obj.name) + if node.mesh is not None: + self._mesh_nodes.remove(node) + if node.light is not None: + if isinstance(node.light, PointLight): + self._point_light_nodes.remove(node) + if isinstance(node.light, SpotLight): + self._spot_light_nodes.remove(node) + if isinstance(node.light, DirectionalLight): + self._directional_light_nodes.remove(node) + if node.camera is not None: + self._camera_nodes.remove(node) + if self._main_camera_node == node: + if len(self._camera_nodes) > 0: + self._main_camera_node = next(iter(self._camera_nodes)) + else: + self._main_camera_node = None + + @staticmethod + def from_trimesh_scene(trimesh_scene, + bg_color=None, ambient_light=None): + """Create a :class:`.Scene` from a :class:`trimesh.scene.scene.Scene`. + + Parameters + ---------- + trimesh_scene : :class:`trimesh.scene.scene.Scene` + Scene with :class:~`trimesh.base.Trimesh` objects. + bg_color : (4,) float + Background color for the created scene. + ambient_light : (3,) float or None + Ambient light in the scene. + + Returns + ------- + scene_pr : :class:`Scene` + A scene containing the same geometry as the trimesh scene. + """ + # convert trimesh geometries to pyrender geometries + geometries = {name: Mesh.from_trimesh(geom) + for name, geom in trimesh_scene.geometry.items()} + + # create the pyrender scene object + scene_pr = Scene(bg_color=bg_color, ambient_light=ambient_light) + + # add every node with geometry to the pyrender scene + for node in trimesh_scene.graph.nodes_geometry: + pose, geom_name = trimesh_scene.graph[node] + scene_pr.add(geometries[geom_name], pose=pose) + + return scene_pr diff --git a/pyrender/pyrender/shader_program.py b/pyrender/pyrender/shader_program.py new file mode 100644 index 0000000000000000000000000000000000000000..c1803f280c98033abe0769771a9ad8ecfec942e3 --- /dev/null +++ b/pyrender/pyrender/shader_program.py @@ -0,0 +1,283 @@ +"""OpenGL shader program wrapper. +""" +import numpy as np +import os +import re + +import OpenGL +from OpenGL.GL import * +from OpenGL.GL import shaders as gl_shader_utils + + +class ShaderProgramCache(object): + """A cache for shader programs. 
+ """ + + def __init__(self, shader_dir=None): + self._program_cache = {} + self.shader_dir = shader_dir + if self.shader_dir is None: + base_dir, _ = os.path.split(os.path.realpath(__file__)) + self.shader_dir = os.path.join(base_dir, 'shaders') + + def get_program(self, vertex_shader, fragment_shader, + geometry_shader=None, defines=None): + """Get a program via a list of shader files to include in the program. + + Parameters + ---------- + vertex_shader : str + The vertex shader filename. + fragment_shader : str + The fragment shader filename. + geometry_shader : str + The geometry shader filename. + defines : dict + Defines and their values for the shader. + + Returns + ------- + program : :class:`.ShaderProgram` + The program. + """ + shader_names = [] + if defines is None: + defines = {} + shader_filenames = [ + x for x in [vertex_shader, fragment_shader, geometry_shader] + if x is not None + ] + for fn in shader_filenames: + if fn is None: + continue + _, name = os.path.split(fn) + shader_names.append(name) + cid = OpenGL.contextdata.getContext() + key = tuple([cid] + sorted( + [(s,1) for s in shader_names] + [(d, defines[d]) for d in defines] + )) + + if key not in self._program_cache: + shader_filenames = [ + os.path.join(self.shader_dir, fn) for fn in shader_filenames + ] + if len(shader_filenames) == 2: + shader_filenames.append(None) + vs, fs, gs = shader_filenames + self._program_cache[key] = ShaderProgram( + vertex_shader=vs, fragment_shader=fs, + geometry_shader=gs, defines=defines + ) + return self._program_cache[key] + + def clear(self): + for key in self._program_cache: + self._program_cache[key].delete() + self._program_cache = {} + + +class ShaderProgram(object): + """A thin wrapper about OpenGL shader programs that supports easy creation, + binding, and uniform-setting. + + Parameters + ---------- + vertex_shader : str + The vertex shader filename. + fragment_shader : str + The fragment shader filename. + geometry_shader : str + The geometry shader filename. + defines : dict + Defines and their values for the shader. 
+ """ + + def __init__(self, vertex_shader, fragment_shader, + geometry_shader=None, defines=None): + + self.vertex_shader = vertex_shader + self.fragment_shader = fragment_shader + self.geometry_shader = geometry_shader + + self.defines = defines + if self.defines is None: + self.defines = {} + + self._program_id = None + self._vao_id = None # PYOPENGL BUG + + # DEBUG + # self._unif_map = {} + + def _add_to_context(self): + if self._program_id is not None: + raise ValueError('Shader program already in context') + shader_ids = [] + + # Load vert shader + shader_ids.append(gl_shader_utils.compileShader( + self._load(self.vertex_shader), GL_VERTEX_SHADER) + ) + # Load frag shader + shader_ids.append(gl_shader_utils.compileShader( + self._load(self.fragment_shader), GL_FRAGMENT_SHADER) + ) + # Load geometry shader + if self.geometry_shader is not None: + shader_ids.append(gl_shader_utils.compileShader( + self._load(self.geometry_shader), GL_GEOMETRY_SHADER) + ) + + # Bind empty VAO PYOPENGL BUG + if self._vao_id is None: + self._vao_id = glGenVertexArrays(1) + glBindVertexArray(self._vao_id) + + # Compile program + self._program_id = gl_shader_utils.compileProgram(*shader_ids) + + # Unbind empty VAO PYOPENGL BUG + glBindVertexArray(0) + + def _in_context(self): + return self._program_id is not None + + def _remove_from_context(self): + if self._program_id is not None: + glDeleteProgram(self._program_id) + glDeleteVertexArrays(1, [self._vao_id]) + self._program_id = None + self._vao_id = None + + def _load(self, shader_filename): + path, _ = os.path.split(shader_filename) + + with open(shader_filename) as f: + text = f.read() + + def ifdef(matchobj): + if matchobj.group(1) in self.defines: + return '#if 1' + else: + return '#if 0' + + def ifndef(matchobj): + if matchobj.group(1) in self.defines: + return '#if 0' + else: + return '#if 1' + + ifdef_regex = re.compile( + '#ifdef\\s+([a-zA-Z_][a-zA-Z_0-9]*)\\s*$', re.MULTILINE + ) + ifndef_regex = re.compile( + '#ifndef\\s+([a-zA-Z_][a-zA-Z_0-9]*)\\s*$', re.MULTILINE + ) + text = re.sub(ifdef_regex, ifdef, text) + text = re.sub(ifndef_regex, ifndef, text) + + for define in self.defines: + value = str(self.defines[define]) + text = text.replace(define, value) + + return text + + def _bind(self): + """Bind this shader program to the current OpenGL context. + """ + if self._program_id is None: + raise ValueError('Cannot bind program that is not in context') + # glBindVertexArray(self._vao_id) + glUseProgram(self._program_id) + + def _unbind(self): + """Unbind this shader program from the current OpenGL context. + """ + glUseProgram(0) + + def delete(self): + """Delete this shader program from the current OpenGL context. + """ + self._remove_from_context() + + def set_uniform(self, name, value, unsigned=False): + """Set a uniform value in the current shader program. + + Parameters + ---------- + name : str + Name of the uniform to set. + value : int, float, or ndarray + Value to set the uniform to. + unsigned : bool + If True, ints will be treated as unsigned values. 
+ """ + try: + # DEBUG + # self._unif_map[name] = 1, (1,) + loc = glGetUniformLocation(self._program_id, name) + + if loc == -1: + raise ValueError('Invalid shader variable: {}'.format(name)) + + if isinstance(value, np.ndarray): + # DEBUG + # self._unif_map[name] = value.size, value.shape + if value.ndim == 1: + if (np.issubdtype(value.dtype, np.unsignedinteger) or + unsigned): + dtype = 'u' + value = value.astype(np.uint32) + elif np.issubdtype(value.dtype, np.integer): + dtype = 'i' + value = value.astype(np.int32) + else: + dtype = 'f' + value = value.astype(np.float32) + self._FUNC_MAP[(value.shape[0], dtype)](loc, 1, value) + else: + self._FUNC_MAP[(value.shape[0], value.shape[1])]( + loc, 1, GL_TRUE, value + ) + + # Call correct uniform function + elif isinstance(value, float): + glUniform1f(loc, value) + elif isinstance(value, int): + if unsigned: + glUniform1ui(loc, value) + else: + glUniform1i(loc, value) + elif isinstance(value, bool): + if unsigned: + glUniform1ui(loc, int(value)) + else: + glUniform1i(loc, int(value)) + else: + raise ValueError('Invalid data type') + except Exception: + pass + + _FUNC_MAP = { + (1,'u'): glUniform1uiv, + (2,'u'): glUniform2uiv, + (3,'u'): glUniform3uiv, + (4,'u'): glUniform4uiv, + (1,'i'): glUniform1iv, + (2,'i'): glUniform2iv, + (3,'i'): glUniform3iv, + (4,'i'): glUniform4iv, + (1,'f'): glUniform1fv, + (2,'f'): glUniform2fv, + (3,'f'): glUniform3fv, + (4,'f'): glUniform4fv, + (2,2): glUniformMatrix2fv, + (2,3): glUniformMatrix2x3fv, + (2,4): glUniformMatrix2x4fv, + (3,2): glUniformMatrix3x2fv, + (3,3): glUniformMatrix3fv, + (3,4): glUniformMatrix3x4fv, + (4,2): glUniformMatrix4x2fv, + (4,3): glUniformMatrix4x3fv, + (4,4): glUniformMatrix4fv, + } diff --git a/pyrender/pyrender/shaders/debug_quad.frag b/pyrender/pyrender/shaders/debug_quad.frag new file mode 100644 index 0000000000000000000000000000000000000000..4647bb50dfa1e4510e2d4afb37959c7f57532eca --- /dev/null +++ b/pyrender/pyrender/shaders/debug_quad.frag @@ -0,0 +1,23 @@ +#version 330 core +out vec4 FragColor; + +in vec2 TexCoords; + +uniform sampler2D depthMap; +//uniform float near_plane; +//uniform float far_plane; +// +//// required when using a perspective projection matrix +//float LinearizeDepth(float depth) +//{ +// float z = depth * 2.0 - 1.0; // Back to NDC +// return (2.0 * near_plane * far_plane) / (far_plane + near_plane - z * (far_plane - near_plane)); +//} + +void main() +{ + float depthValue = texture(depthMap, TexCoords).r; + // FragColor = vec4(vec3(LinearizeDepth(depthValue) / far_plane), 1.0); // perspective + FragColor = vec4(vec3(depthValue), 1.0); // orthographic + //FragColor = vec4(1.0, 1.0, 0.0, 1.0); +} diff --git a/pyrender/pyrender/shaders/debug_quad.vert b/pyrender/pyrender/shaders/debug_quad.vert new file mode 100644 index 0000000000000000000000000000000000000000..d2f2fcb7626f6c22e0d52bf4d6c91251cbdb9f52 --- /dev/null +++ b/pyrender/pyrender/shaders/debug_quad.vert @@ -0,0 +1,25 @@ +#version 330 core +//layout (location = 0) in vec3 aPos; +//layout (location = 1) in vec2 aTexCoords; +// +//out vec2 TexCoords; +// +//void main() +//{ +// TexCoords = aTexCoords; +// gl_Position = vec4(aPos, 1.0); +//} +// +// +//layout(location = 0) out vec2 uv; + +out vec2 TexCoords; + +void main() +{ + float x = float(((uint(gl_VertexID) + 2u) / 3u)%2u); + float y = float(((uint(gl_VertexID) + 1u) / 3u)%2u); + + gl_Position = vec4(-1.0f + x*2.0f, -1.0f+y*2.0f, 0.0f, 1.0f); + TexCoords = vec2(x, y); +} diff --git a/pyrender/pyrender/shaders/flat.frag 
b/pyrender/pyrender/shaders/flat.frag new file mode 100644 index 0000000000000000000000000000000000000000..7ec01c6d095ec5dacc693accd3ad507ced61a79a --- /dev/null +++ b/pyrender/pyrender/shaders/flat.frag @@ -0,0 +1,126 @@ +#version 330 core +/////////////////////////////////////////////////////////////////////////////// +// Structs +/////////////////////////////////////////////////////////////////////////////// + +struct Material { + vec3 emissive_factor; + +#ifdef USE_METALLIC_MATERIAL + vec4 base_color_factor; + float metallic_factor; + float roughness_factor; +#endif + +#ifdef USE_GLOSSY_MATERIAL + vec4 diffuse_factor; + vec3 specular_factor; + float glossiness_factor; +#endif + +#ifdef HAS_NORMAL_TEX + sampler2D normal_texture; +#endif +#ifdef HAS_OCCLUSION_TEX + sampler2D occlusion_texture; +#endif +#ifdef HAS_EMISSIVE_TEX + sampler2D emissive_texture; +#endif +#ifdef HAS_BASE_COLOR_TEX + sampler2D base_color_texture; +#endif +#ifdef HAS_METALLIC_ROUGHNESS_TEX + sampler2D metallic_roughness_texture; +#endif +#ifdef HAS_DIFFUSE_TEX + sampler2D diffuse_texture; +#endif +#ifdef HAS_SPECULAR_GLOSSINESS_TEX + sampler2D specular_glossiness; +#endif +}; + +/////////////////////////////////////////////////////////////////////////////// +// Uniforms +/////////////////////////////////////////////////////////////////////////////// +uniform Material material; +uniform vec3 cam_pos; + +#ifdef USE_IBL +uniform samplerCube diffuse_env; +uniform samplerCube specular_env; +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Inputs +/////////////////////////////////////////////////////////////////////////////// + +in vec3 frag_position; +#ifdef NORMAL_LOC +in vec3 frag_normal; +#endif +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC +in mat3 tbn; +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC +in vec2 uv_0; +#endif +#ifdef TEXCOORD_1_LOC +in vec2 uv_1; +#endif +#ifdef COLOR_0_LOC +in vec4 color_multiplier; +#endif + +/////////////////////////////////////////////////////////////////////////////// +// OUTPUTS +/////////////////////////////////////////////////////////////////////////////// + +out vec4 frag_color; + +/////////////////////////////////////////////////////////////////////////////// +// Constants +/////////////////////////////////////////////////////////////////////////////// +const float PI = 3.141592653589793; +const float min_roughness = 0.04; + +/////////////////////////////////////////////////////////////////////////////// +// Utility Functions +/////////////////////////////////////////////////////////////////////////////// +vec4 srgb_to_linear(vec4 srgb) +{ +#ifndef SRGB_CORRECTED + // Fast Approximation + //vec3 linOut = pow(srgbIn.xyz,vec3(2.2)); + // + vec3 b_less = step(vec3(0.04045),srgb.xyz); + vec3 lin_out = mix( srgb.xyz/vec3(12.92), pow((srgb.xyz+vec3(0.055))/vec3(1.055),vec3(2.4)), b_less ); + return vec4(lin_out, srgb.w); +#else + return srgb; +#endif +} + +/////////////////////////////////////////////////////////////////////////////// +// MAIN +/////////////////////////////////////////////////////////////////////////////// +void main() +{ + + // Compute albedo + vec4 base_color = material.base_color_factor; +#ifdef HAS_BASE_COLOR_TEX + base_color = base_color * texture(material.base_color_texture, uv_0); +#endif + +#ifdef COLOR_0_LOC + base_color *= color_multiplier; +#endif + + frag_color = clamp(base_color, 0.0, 1.0); +} diff --git a/pyrender/pyrender/shaders/flat.vert b/pyrender/pyrender/shaders/flat.vert new file mode 
100644 index 0000000000000000000000000000000000000000..cfd241c3544718a261f961c3aa3c03aa13c97761 --- /dev/null +++ b/pyrender/pyrender/shaders/flat.vert @@ -0,0 +1,86 @@ +#version 330 core + +// Vertex Attributes +layout(location = 0) in vec3 position; +#ifdef NORMAL_LOC +layout(location = NORMAL_LOC) in vec3 normal; +#endif +#ifdef TANGENT_LOC +layout(location = TANGENT_LOC) in vec4 tangent; +#endif +#ifdef TEXCOORD_0_LOC +layout(location = TEXCOORD_0_LOC) in vec2 texcoord_0; +#endif +#ifdef TEXCOORD_1_LOC +layout(location = TEXCOORD_1_LOC) in vec2 texcoord_1; +#endif +#ifdef COLOR_0_LOC +layout(location = COLOR_0_LOC) in vec4 color_0; +#endif +#ifdef JOINTS_0_LOC +layout(location = JOINTS_0_LOC) in vec4 joints_0; +#endif +#ifdef WEIGHTS_0_LOC +layout(location = WEIGHTS_0_LOC) in vec4 weights_0; +#endif +layout(location = INST_M_LOC) in mat4 inst_m; + +// Uniforms +uniform mat4 M; +uniform mat4 V; +uniform mat4 P; + +// Outputs +out vec3 frag_position; +#ifdef NORMAL_LOC +out vec3 frag_normal; +#endif +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC +out mat3 tbn; +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC +out vec2 uv_0; +#endif +#ifdef TEXCOORD_1_LOC +out vec2 uv_1; +#endif +#ifdef COLOR_0_LOC +out vec4 color_multiplier; +#endif + + +void main() +{ + gl_Position = P * V * M * inst_m * vec4(position, 1); + frag_position = vec3(M * inst_m * vec4(position, 1.0)); + + mat4 N = transpose(inverse(M * inst_m)); + +#ifdef NORMAL_LOC + frag_normal = normalize(vec3(N * vec4(normal, 0.0))); +#endif + +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC + vec3 normal_w = normalize(vec3(N * vec4(normal, 0.0))); + vec3 tangent_w = normalize(vec3(N * vec4(tangent.xyz, 0.0))); + vec3 bitangent_w = cross(normal_w, tangent_w) * tangent.w; + tbn = mat3(tangent_w, bitangent_w, normal_w); +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC + uv_0 = texcoord_0; +#endif +#ifdef TEXCOORD_1_LOC + uv_1 = texcoord_1; +#endif +#ifdef COLOR_0_LOC + color_multiplier = color_0; +#endif +} diff --git a/pyrender/pyrender/shaders/mesh.frag b/pyrender/pyrender/shaders/mesh.frag new file mode 100644 index 0000000000000000000000000000000000000000..43187621b4388b18badf4e562a7ad300e59b029d --- /dev/null +++ b/pyrender/pyrender/shaders/mesh.frag @@ -0,0 +1,456 @@ +#version 330 core +/////////////////////////////////////////////////////////////////////////////// +// Structs +/////////////////////////////////////////////////////////////////////////////// + +struct SpotLight { + vec3 color; + float intensity; + float range; + vec3 position; + vec3 direction; + float light_angle_scale; + float light_angle_offset; + + #ifdef SPOT_LIGHT_SHADOWS + sampler2D shadow_map; + mat4 light_matrix; + #endif +}; + +struct DirectionalLight { + vec3 color; + float intensity; + vec3 direction; + + #ifdef DIRECTIONAL_LIGHT_SHADOWS + sampler2D shadow_map; + mat4 light_matrix; + #endif +}; + +struct PointLight { + vec3 color; + float intensity; + float range; + vec3 position; + + #ifdef POINT_LIGHT_SHADOWS + samplerCube shadow_map; + #endif +}; + +struct Material { + vec3 emissive_factor; + +#ifdef USE_METALLIC_MATERIAL + vec4 base_color_factor; + float metallic_factor; + float roughness_factor; +#endif + +#ifdef USE_GLOSSY_MATERIAL + vec4 diffuse_factor; + vec3 specular_factor; + float glossiness_factor; +#endif + +#ifdef HAS_NORMAL_TEX + sampler2D normal_texture; +#endif +#ifdef HAS_OCCLUSION_TEX + sampler2D occlusion_texture; +#endif +#ifdef HAS_EMISSIVE_TEX + sampler2D emissive_texture; +#endif +#ifdef HAS_BASE_COLOR_TEX + 
sampler2D base_color_texture;
+#endif
+#ifdef HAS_METALLIC_ROUGHNESS_TEX
+    sampler2D metallic_roughness_texture;
+#endif
+#ifdef HAS_DIFFUSE_TEX
+    sampler2D diffuse_texture;
+#endif
+#ifdef HAS_SPECULAR_GLOSSINESS_TEX
+    sampler2D specular_glossiness;
+#endif
+};
+
+struct PBRInfo {
+    float nl;
+    float nv;
+    float nh;
+    float lh;
+    float vh;
+    float roughness;
+    float metallic;
+    vec3 f0;
+    vec3 c_diff;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Uniforms
+///////////////////////////////////////////////////////////////////////////////
+uniform Material material;
+uniform PointLight point_lights[MAX_POINT_LIGHTS];
+uniform int n_point_lights;
+uniform DirectionalLight directional_lights[MAX_DIRECTIONAL_LIGHTS];
+uniform int n_directional_lights;
+uniform SpotLight spot_lights[MAX_SPOT_LIGHTS];
+uniform int n_spot_lights;
+uniform vec3 cam_pos;
+uniform vec3 ambient_light;
+
+#ifdef USE_IBL
+uniform samplerCube diffuse_env;
+uniform samplerCube specular_env;
+#endif
+
+///////////////////////////////////////////////////////////////////////////////
+// Inputs
+///////////////////////////////////////////////////////////////////////////////
+
+in vec3 frag_position;
+#ifdef NORMAL_LOC
+in vec3 frag_normal;
+#endif
+#ifdef HAS_NORMAL_TEX
+#ifdef TANGENT_LOC
+#ifdef NORMAL_LOC
+in mat3 tbn;
+#endif
+#endif
+#endif
+#ifdef TEXCOORD_0_LOC
+in vec2 uv_0;
+#endif
+#ifdef TEXCOORD_1_LOC
+in vec2 uv_1;
+#endif
+#ifdef COLOR_0_LOC
+in vec4 color_multiplier;
+#endif
+
+///////////////////////////////////////////////////////////////////////////////
+// OUTPUTS
+///////////////////////////////////////////////////////////////////////////////
+
+out vec4 frag_color;
+
+///////////////////////////////////////////////////////////////////////////////
+// Constants
+///////////////////////////////////////////////////////////////////////////////
+const float PI = 3.141592653589793;
+const float min_roughness = 0.04;
+
+///////////////////////////////////////////////////////////////////////////////
+// Utility Functions
+///////////////////////////////////////////////////////////////////////////////
+vec4 srgb_to_linear(vec4 srgb)
+{
+#ifndef SRGB_CORRECTED
+    // Fast Approximation
+    //vec3 linOut = pow(srgbIn.xyz,vec3(2.2));
+    //
+    vec3 b_less = step(vec3(0.04045),srgb.xyz);
+    vec3 lin_out = mix( srgb.xyz/vec3(12.92), pow((srgb.xyz+vec3(0.055))/vec3(1.055),vec3(2.4)), b_less );
+    return vec4(lin_out, srgb.w);
+#else
+    return srgb;
+#endif
+}
+
+// Normal computation
+vec3 get_normal()
+{
+#ifdef HAS_NORMAL_TEX
+
+#ifndef HAS_TANGENTS
+    vec3 pos_dx = dFdx(frag_position);
+    vec3 pos_dy = dFdy(frag_position);
+    vec3 tex_dx = dFdx(vec3(uv_0, 0.0));
+    vec3 tex_dy = dFdy(vec3(uv_0, 0.0));
+    vec3 t = (tex_dy.t * pos_dx - tex_dx.t * pos_dy) / (tex_dx.s * tex_dy.t - tex_dy.s * tex_dx.t);
+
+#ifdef NORMAL_LOC
+    vec3 ng = normalize(frag_normal);
+#else
+    vec3 ng = cross(pos_dx, pos_dy);
+#endif
+
+    t = normalize(t - ng * dot(ng, t));
+    vec3 b = normalize(cross(ng, t));
+    mat3 tbn_n = mat3(t, b, ng);
+
+#else
+
+    mat3 tbn_n = tbn;
+
+#endif
+
+    vec3 n = texture(material.normal_texture, uv_0).rgb;
+    n = normalize(tbn_n * ((2.0 * n - 1.0) * vec3(1.0, 1.0, 1.0)));
+    return n; // TODO NORMAL MAPPING
+
+#else
+
+#ifdef NORMAL_LOC
+    return frag_normal;
+#else
+    return normalize(cam_pos - frag_position);
+#endif
+
+#endif
+}
+
+// Fresnel
+vec3 specular_reflection(PBRInfo info)
+{
+    vec3 res = info.f0 + (1.0 - info.f0) * pow(clamp(1.0 - info.vh, 0.0, 1.0), 5.0);
+    return res;
+}
+
+// Smith
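+// (Schlick-GGX approximation of the Smith visibility term; k is remapped
+// from roughness as k = (r + 1)^2 / 8, the direct-lighting variant.)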
+float geometric_occlusion(PBRInfo info) +{ + float r = info.roughness + 1.0; + float k = r * r / 8.0; + float g1 = info.nv / (info.nv * (1.0 - k) + k); + float g2 = info.nl / (info.nl * (1.0 - k) + k); + //float k = info.roughness * sqrt(2.0 / PI); + //float g1 = info.lh / (info.lh * (1.0 - k) + k); + //float g2 = info.nh / (info.nh * (1.0 - k) + k); + return g1 * g2; +} + +float microfacet_distribution(PBRInfo info) +{ + float a = info.roughness * info.roughness; + float a2 = a * a; + float nh2 = info.nh * info.nh; + + float denom = (nh2 * (a2 - 1.0) + 1.0); + return a2 / (PI * denom * denom); +} + +vec3 compute_brdf(vec3 n, vec3 v, vec3 l, + float roughness, float metalness, + vec3 f0, vec3 c_diff, vec3 albedo, + vec3 radiance) +{ + vec3 h = normalize(l+v); + float nl = clamp(dot(n, l), 0.001, 1.0); + float nv = clamp(abs(dot(n, v)), 0.001, 1.0); + float nh = clamp(dot(n, h), 0.0, 1.0); + float lh = clamp(dot(l, h), 0.0, 1.0); + float vh = clamp(dot(v, h), 0.0, 1.0); + + PBRInfo info = PBRInfo(nl, nv, nh, lh, vh, roughness, metalness, f0, c_diff); + + // Compute PBR terms + vec3 F = specular_reflection(info); + float G = geometric_occlusion(info); + float D = microfacet_distribution(info); + + // Compute BRDF + vec3 diffuse_contrib = (1.0 - F) * c_diff / PI; + vec3 spec_contrib = F * G * D / (4.0 * nl * nv + 0.001); + + vec3 color = nl * radiance * (diffuse_contrib + spec_contrib); + return color; +} + +float texture2DCompare(sampler2D depths, vec2 uv, float compare) { + return compare > texture(depths, uv.xy).r ? 1.0 : 0.0; +} + +float texture2DShadowLerp(sampler2D depths, vec2 size, vec2 uv, float compare) { + vec2 texelSize = vec2(1.0)/size; + vec2 f = fract(uv*size+0.5); + vec2 centroidUV = floor(uv*size+0.5)/size; + + float lb = texture2DCompare(depths, centroidUV+texelSize*vec2(0.0, 0.0), compare); + float lt = texture2DCompare(depths, centroidUV+texelSize*vec2(0.0, 1.0), compare); + float rb = texture2DCompare(depths, centroidUV+texelSize*vec2(1.0, 0.0), compare); + float rt = texture2DCompare(depths, centroidUV+texelSize*vec2(1.0, 1.0), compare); + float a = mix(lb, lt, f.y); + float b = mix(rb, rt, f.y); + float c = mix(a, b, f.x); + return c; +} + +float PCF(sampler2D depths, vec2 size, vec2 uv, float compare){ + float result = 0.0; + for(int x=-1; x<=1; x++){ + for(int y=-1; y<=1; y++){ + vec2 off = vec2(x,y)/size; + result += texture2DShadowLerp(depths, size, uv+off, compare); + } + } + return result/9.0; +} + +float shadow_calc(mat4 light_matrix, sampler2D shadow_map, float nl) +{ + // Compute light texture UV coords + vec4 proj_coords = vec4(light_matrix * vec4(frag_position.xyz, 1.0)); + vec3 light_coords = proj_coords.xyz / proj_coords.w; + light_coords = light_coords * 0.5 + 0.5; + float current_depth = light_coords.z; + float bias = max(0.001 * (1.0 - nl), 0.0001) / proj_coords.w; + float compare = (current_depth - bias); + float shadow = PCF(shadow_map, textureSize(shadow_map, 0), light_coords.xy, compare); + if (light_coords.z > 1.0) { + shadow = 0.0; + } + return shadow; +} + +/////////////////////////////////////////////////////////////////////////////// +// MAIN +/////////////////////////////////////////////////////////////////////////////// +void main() +{ + + vec4 color = vec4(vec3(0.0), 1.0); +/////////////////////////////////////////////////////////////////////////////// +// Handle Metallic Materials +/////////////////////////////////////////////////////////////////////////////// +#ifdef USE_METALLIC_MATERIAL + + // Compute metallic/roughness factors + float 
roughness = material.roughness_factor;
+    float metallic = material.metallic_factor;
+#ifdef HAS_METALLIC_ROUGHNESS_TEX
+    vec2 mr = texture(material.metallic_roughness_texture, uv_0).rg;
+    roughness = roughness * mr.r;
+    metallic = metallic * mr.g;
+#endif
+    roughness = clamp(roughness, min_roughness, 1.0);
+    metallic = clamp(metallic, 0.0, 1.0);
+    // By convention, material roughness is perceptual roughness squared
+    float alpha_roughness = roughness * roughness;
+
+    // Compute albedo
+    vec4 base_color = material.base_color_factor;
+#ifdef HAS_BASE_COLOR_TEX
+    base_color = base_color * srgb_to_linear(texture(material.base_color_texture, uv_0));
+#endif
+
+    // Compute specular and diffuse colors
+    vec3 dielectric_spec = vec3(min_roughness);
+    vec3 c_diff = mix(vec3(0.0), base_color.rgb * (1 - min_roughness), 1.0 - metallic);
+    vec3 f0 = mix(dielectric_spec, base_color.rgb, metallic);
+
+    // Compute normal
+    vec3 n = normalize(get_normal());
+
+    // Loop over lights
+    for (int i = 0; i < n_directional_lights; i++) {
+        vec3 direction = directional_lights[i].direction;
+        vec3 v = normalize(cam_pos - frag_position); // Vector towards camera
+        vec3 l = normalize(-1.0 * direction);        // Vector towards light
+
+        // Compute attenuation and radiance
+        float attenuation = directional_lights[i].intensity;
+        vec3 radiance = attenuation * directional_lights[i].color;
+
+        // Compute outbound color
+        vec3 res = compute_brdf(n, v, l, roughness, metallic,
+                                f0, c_diff, base_color.rgb, radiance);
+
+        // Compute shadow
+#ifdef DIRECTIONAL_LIGHT_SHADOWS
+        float nl = clamp(dot(n,l), 0.0, 1.0);
+        float shadow = shadow_calc(
+            directional_lights[i].light_matrix,
+            directional_lights[i].shadow_map,
+            nl
+        );
+        res = res * (1.0 - shadow);
+#endif
+        color.xyz += res;
+    }
+
+    for (int i = 0; i < n_point_lights; i++) {
+        vec3 position = point_lights[i].position;
+        vec3 v = normalize(cam_pos - frag_position); // Vector towards camera
+        vec3 l = normalize(position - frag_position); // Vector towards light
+
+        // Compute attenuation and radiance
+        float dist = length(position - frag_position);
+        float attenuation = point_lights[i].intensity / (dist * dist);
+        vec3 radiance = attenuation * point_lights[i].color;
+
+        // Compute outbound color
+        vec3 res = compute_brdf(n, v, l, roughness, metallic,
+                                f0, c_diff, base_color.rgb, radiance);
+        color.xyz += res;
+    }
+    for (int i = 0; i < n_spot_lights; i++) {
+        vec3 position = spot_lights[i].position;
+        vec3 v = normalize(cam_pos - frag_position); // Vector towards camera
+        vec3 l = normalize(position - frag_position); // Vector towards light
+
+        // Compute attenuation and radiance
+        vec3 direction = spot_lights[i].direction;
+        float las = spot_lights[i].light_angle_scale;
+        float lao = spot_lights[i].light_angle_offset;
+        float dist = length(position - frag_position);
+        float cd = clamp(dot(direction, -l), 0.0, 1.0);
+        float attenuation = clamp(cd * las + lao, 0.0, 1.0);
+        attenuation = attenuation * attenuation * spot_lights[i].intensity;
+        attenuation = attenuation / (dist * dist);
+        vec3 radiance = attenuation * spot_lights[i].color;
+
+        // Compute outbound color
+        vec3 res = compute_brdf(n, v, l, roughness, metallic,
+                                f0, c_diff, base_color.rgb, radiance);
+#ifdef SPOT_LIGHT_SHADOWS
+        float nl = clamp(dot(n,l), 0.0, 1.0);
+        float shadow = shadow_calc(
+            spot_lights[i].light_matrix,
+            spot_lights[i].shadow_map,
+            nl
+        );
+        res = res * (1.0 - shadow);
+#endif
+        color.xyz += res;
+    }
+    color.xyz += base_color.xyz * ambient_light;
+
+    // Calculate lighting from environment
+#ifdef
USE_IBL + // TODO +#endif + + // Apply occlusion +#ifdef HAS_OCCLUSION_TEX + float ao = texture(material.occlusion_texture, uv_0).r; + color.xyz *= ao; +#endif + + // Apply emissive map + vec3 emissive = material.emissive_factor; +#ifdef HAS_EMISSIVE_TEX + emissive *= srgb_to_linear(texture(material.emissive_texture, uv_0)).rgb; +#endif + color.xyz += emissive * material.emissive_factor; + +#ifdef COLOR_0_LOC + color *= color_multiplier; +#endif + + frag_color = clamp(vec4(pow(color.xyz, vec3(1.0/2.2)), color.a * base_color.a), 0.0, 1.0); + +#else + // TODO GLOSSY MATERIAL BRDF +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Handle Glossy Materials +/////////////////////////////////////////////////////////////////////////////// + +} diff --git a/pyrender/pyrender/shaders/mesh.vert b/pyrender/pyrender/shaders/mesh.vert new file mode 100644 index 0000000000000000000000000000000000000000..cfd241c3544718a261f961c3aa3c03aa13c97761 --- /dev/null +++ b/pyrender/pyrender/shaders/mesh.vert @@ -0,0 +1,86 @@ +#version 330 core + +// Vertex Attributes +layout(location = 0) in vec3 position; +#ifdef NORMAL_LOC +layout(location = NORMAL_LOC) in vec3 normal; +#endif +#ifdef TANGENT_LOC +layout(location = TANGENT_LOC) in vec4 tangent; +#endif +#ifdef TEXCOORD_0_LOC +layout(location = TEXCOORD_0_LOC) in vec2 texcoord_0; +#endif +#ifdef TEXCOORD_1_LOC +layout(location = TEXCOORD_1_LOC) in vec2 texcoord_1; +#endif +#ifdef COLOR_0_LOC +layout(location = COLOR_0_LOC) in vec4 color_0; +#endif +#ifdef JOINTS_0_LOC +layout(location = JOINTS_0_LOC) in vec4 joints_0; +#endif +#ifdef WEIGHTS_0_LOC +layout(location = WEIGHTS_0_LOC) in vec4 weights_0; +#endif +layout(location = INST_M_LOC) in mat4 inst_m; + +// Uniforms +uniform mat4 M; +uniform mat4 V; +uniform mat4 P; + +// Outputs +out vec3 frag_position; +#ifdef NORMAL_LOC +out vec3 frag_normal; +#endif +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC +out mat3 tbn; +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC +out vec2 uv_0; +#endif +#ifdef TEXCOORD_1_LOC +out vec2 uv_1; +#endif +#ifdef COLOR_0_LOC +out vec4 color_multiplier; +#endif + + +void main() +{ + gl_Position = P * V * M * inst_m * vec4(position, 1); + frag_position = vec3(M * inst_m * vec4(position, 1.0)); + + mat4 N = transpose(inverse(M * inst_m)); + +#ifdef NORMAL_LOC + frag_normal = normalize(vec3(N * vec4(normal, 0.0))); +#endif + +#ifdef HAS_NORMAL_TEX +#ifdef TANGENT_LOC +#ifdef NORMAL_LOC + vec3 normal_w = normalize(vec3(N * vec4(normal, 0.0))); + vec3 tangent_w = normalize(vec3(N * vec4(tangent.xyz, 0.0))); + vec3 bitangent_w = cross(normal_w, tangent_w) * tangent.w; + tbn = mat3(tangent_w, bitangent_w, normal_w); +#endif +#endif +#endif +#ifdef TEXCOORD_0_LOC + uv_0 = texcoord_0; +#endif +#ifdef TEXCOORD_1_LOC + uv_1 = texcoord_1; +#endif +#ifdef COLOR_0_LOC + color_multiplier = color_0; +#endif +} diff --git a/pyrender/pyrender/shaders/mesh_depth.frag b/pyrender/pyrender/shaders/mesh_depth.frag new file mode 100644 index 0000000000000000000000000000000000000000..d8b1fac6091cfa457ba835ae0758e955f06d8754 --- /dev/null +++ b/pyrender/pyrender/shaders/mesh_depth.frag @@ -0,0 +1,8 @@ +#version 330 core + +out vec4 frag_color; + +void main() +{ + frag_color = vec4(1.0); +} diff --git a/pyrender/pyrender/shaders/mesh_depth.vert b/pyrender/pyrender/shaders/mesh_depth.vert new file mode 100644 index 0000000000000000000000000000000000000000..e534c058fb3e7b0efbec090513d55982db68ccaf --- /dev/null +++ b/pyrender/pyrender/shaders/mesh_depth.vert 
@@ -0,0 +1,13 @@ +#version 330 core +layout(location = 0) in vec3 position; +layout(location = INST_M_LOC) in mat4 inst_m; + +uniform mat4 P; +uniform mat4 V; +uniform mat4 M; + +void main() +{ + mat4 light_matrix = P * V; + gl_Position = light_matrix * M * inst_m * vec4(position, 1.0); +} diff --git a/pyrender/pyrender/shaders/segmentation.frag b/pyrender/pyrender/shaders/segmentation.frag new file mode 100644 index 0000000000000000000000000000000000000000..40deb92cbdef3ec9fd952632624cd5f4b5ce0c84 --- /dev/null +++ b/pyrender/pyrender/shaders/segmentation.frag @@ -0,0 +1,13 @@ +#version 330 core + +uniform vec3 color; +out vec4 frag_color; + +/////////////////////////////////////////////////////////////////////////////// +// MAIN +/////////////////////////////////////////////////////////////////////////////// +void main() +{ + frag_color = vec4(color, 1.0); + //frag_color = vec4(1.0, 0.5, 0.5, 1.0); +} diff --git a/pyrender/pyrender/shaders/segmentation.vert b/pyrender/pyrender/shaders/segmentation.vert new file mode 100644 index 0000000000000000000000000000000000000000..503382599dae3c9415845f35b99d6678cfc7f716 --- /dev/null +++ b/pyrender/pyrender/shaders/segmentation.vert @@ -0,0 +1,14 @@ +#version 330 core +layout(location = 0) in vec3 position; +layout(location = INST_M_LOC) in mat4 inst_m; + +uniform mat4 P; +uniform mat4 V; +uniform mat4 M; + +void main() +{ + mat4 light_matrix = P * V; + gl_Position = light_matrix * M * inst_m * vec4(position, 1.0); +} + diff --git a/pyrender/pyrender/shaders/text.frag b/pyrender/pyrender/shaders/text.frag new file mode 100644 index 0000000000000000000000000000000000000000..486c97dc94ed5e9083ae348bc1e85c5cb26c44dc --- /dev/null +++ b/pyrender/pyrender/shaders/text.frag @@ -0,0 +1,12 @@ +#version 330 core +in vec2 uv; +out vec4 color; + +uniform sampler2D text; +uniform vec4 text_color; + +void main() +{ + vec4 sampled = vec4(1.0, 1.0, 1.0, texture(text, uv).r); + color = text_color * sampled; +} diff --git a/pyrender/pyrender/shaders/text.vert b/pyrender/pyrender/shaders/text.vert new file mode 100644 index 0000000000000000000000000000000000000000..005bc439b3d63522df99e5db2088953eb8defcf4 --- /dev/null +++ b/pyrender/pyrender/shaders/text.vert @@ -0,0 +1,12 @@ +#version 330 core +layout (location = 0) in vec4 vertex; + +out vec2 uv; + +uniform mat4 projection; + +void main() +{ + gl_Position = projection * vec4(vertex.xy, 0.0, 1.0); + uv = vertex.zw; +} diff --git a/pyrender/pyrender/shaders/vertex_normals.frag b/pyrender/pyrender/shaders/vertex_normals.frag new file mode 100644 index 0000000000000000000000000000000000000000..edf5beb7f283dd67e1710bff922555539966cee4 --- /dev/null +++ b/pyrender/pyrender/shaders/vertex_normals.frag @@ -0,0 +1,10 @@ +#version 330 core + +out vec4 frag_color; + +uniform vec4 normal_color; + +void main() +{ + frag_color = normal_color; +} diff --git a/pyrender/pyrender/shaders/vertex_normals.geom b/pyrender/pyrender/shaders/vertex_normals.geom new file mode 100644 index 0000000000000000000000000000000000000000..57f0b0e645e72d41116f5767d66fc37d01ed2714 --- /dev/null +++ b/pyrender/pyrender/shaders/vertex_normals.geom @@ -0,0 +1,74 @@ +#version 330 core + +layout (triangles) in; + +#ifdef FACE_NORMALS + +#ifdef VERTEX_NORMALS + layout (line_strip, max_vertices = 8) out; +#else + layout (line_strip, max_vertices = 2) out; +#endif + +#else + + layout (line_strip, max_vertices = 6) out; + +#endif + +in VS_OUT { + vec3 position; + vec3 normal; + mat4 mvp; +} gs_in[]; + +uniform float normal_magnitude; + +void 
GenerateVertNormal(int index) +{ + + vec4 p0 = gs_in[index].mvp * vec4(gs_in[index].position, 1.0); + vec4 p1 = gs_in[index].mvp * vec4(normal_magnitude * normalize(gs_in[index].normal) + gs_in[index].position, 1.0); + gl_Position = p0; + EmitVertex(); + gl_Position = p1; + EmitVertex(); + EndPrimitive(); +} + +void GenerateFaceNormal() +{ + vec3 p0 = gs_in[0].position.xyz; + vec3 p1 = gs_in[1].position.xyz; + vec3 p2 = gs_in[2].position.xyz; + + vec3 v0 = p0 - p1; + vec3 v1 = p2 - p1; + + vec3 N = normalize(cross(v1, v0)); + vec3 P = (p0 + p1 + p2) / 3.0; + + vec4 np0 = gs_in[0].mvp * vec4(P, 1.0); + vec4 np1 = gs_in[0].mvp * vec4(normal_magnitude * N + P, 1.0); + + gl_Position = np0; + EmitVertex(); + gl_Position = np1; + EmitVertex(); + EndPrimitive(); +} + +void main() +{ + +#ifdef FACE_NORMALS + GenerateFaceNormal(); +#endif + +#ifdef VERTEX_NORMALS + GenerateVertNormal(0); + GenerateVertNormal(1); + GenerateVertNormal(2); +#endif + +} diff --git a/pyrender/pyrender/shaders/vertex_normals.vert b/pyrender/pyrender/shaders/vertex_normals.vert new file mode 100644 index 0000000000000000000000000000000000000000..be22eed2a0e904bcaf1ac5a4721558e574cddc62 --- /dev/null +++ b/pyrender/pyrender/shaders/vertex_normals.vert @@ -0,0 +1,27 @@ +#version 330 core + +// Inputs +layout(location = 0) in vec3 position; +layout(location = NORMAL_LOC) in vec3 normal; +layout(location = INST_M_LOC) in mat4 inst_m; + +// Output data +out VS_OUT { + vec3 position; + vec3 normal; + mat4 mvp; +} vs_out; + +// Uniform data +uniform mat4 M; +uniform mat4 V; +uniform mat4 P; + +// Render loop +void main() { + vs_out.mvp = P * V * M * inst_m; + vs_out.position = position; + vs_out.normal = normal; + + gl_Position = vec4(position, 1.0); +} diff --git a/pyrender/pyrender/shaders/vertex_normals_pc.geom b/pyrender/pyrender/shaders/vertex_normals_pc.geom new file mode 100644 index 0000000000000000000000000000000000000000..4ea4e7b8542703f64b8d28fd187e425137861fe4 --- /dev/null +++ b/pyrender/pyrender/shaders/vertex_normals_pc.geom @@ -0,0 +1,29 @@ +#version 330 core + +layout (points) in; + +layout (line_strip, max_vertices = 2) out; + +in VS_OUT { + vec3 position; + vec3 normal; + mat4 mvp; +} gs_in[]; + +uniform float normal_magnitude; + +void GenerateVertNormal(int index) +{ + vec4 p0 = gs_in[index].mvp * vec4(gs_in[index].position, 1.0); + vec4 p1 = gs_in[index].mvp * vec4(normal_magnitude * normalize(gs_in[index].normal) + gs_in[index].position, 1.0); + gl_Position = p0; + EmitVertex(); + gl_Position = p1; + EmitVertex(); + EndPrimitive(); +} + +void main() +{ + GenerateVertNormal(0); +} diff --git a/pyrender/pyrender/texture.py b/pyrender/pyrender/texture.py new file mode 100644 index 0000000000000000000000000000000000000000..477759729d7b995a4f276e81d649617d045a066e --- /dev/null +++ b/pyrender/pyrender/texture.py @@ -0,0 +1,259 @@ +"""Textures, conforming to the glTF 2.0 standards as specified in +https://github.com/KhronosGroup/glTF/tree/master/specification/2.0#reference-texture + +Author: Matthew Matl +""" +import numpy as np + +from OpenGL.GL import * + +from .utils import format_texture_source +from .sampler import Sampler + + +class Texture(object): + """A texture and its sampler. + + Parameters + ---------- + name : str, optional + The user-defined name of this object. + sampler : :class:`Sampler` + The sampler used by this texture. + source : (h,w,c) uint8 or (h,w,c) float or :class:`PIL.Image.Image` + The image used by this texture. 
If None, the texture is created
+        empty and width and height must be specified.
+    source_channels : str
+        Either `D`, `R`, `RG`, `GB`, `RGB`, or `RGBA`. Indicates the
+        channels to extract from `source`. Any missing channels will be filled
+        with `1.0`.
+    width : int, optional
+        For empty textures, the width of the texture buffer.
+    height : int, optional
+        For empty textures, the height of the texture buffer.
+    tex_type : int
+        Either GL_TEXTURE_2D or GL_TEXTURE_CUBE_MAP.
+    data_format : int
+        The data format of the source, e.g. GL_UNSIGNED_BYTE (the default)
+        or GL_FLOAT.
+    """
+
+    def __init__(self,
+                 name=None,
+                 sampler=None,
+                 source=None,
+                 source_channels=None,
+                 width=None,
+                 height=None,
+                 tex_type=GL_TEXTURE_2D,
+                 data_format=GL_UNSIGNED_BYTE):
+        self.source_channels = source_channels
+        self.name = name
+        self.sampler = sampler
+        self.source = source
+        self.width = width
+        self.height = height
+        self.tex_type = tex_type
+        self.data_format = data_format
+
+        self._texid = None
+        self._is_transparent = None
+
+    @property
+    def name(self):
+        """str : The user-defined name of this object.
+        """
+        return self._name
+
+    @name.setter
+    def name(self, value):
+        if value is not None:
+            value = str(value)
+        self._name = value
+
+    @property
+    def sampler(self):
+        """:class:`Sampler` : The sampler used by this texture.
+        """
+        return self._sampler
+
+    @sampler.setter
+    def sampler(self, value):
+        if value is None:
+            value = Sampler()
+        self._sampler = value
+
+    @property
+    def source(self):
+        """(h,w,c) uint8 or float or :class:`PIL.Image.Image` : The image
+        used in this texture.
+        """
+        return self._source
+
+    @source.setter
+    def source(self, value):
+        if value is None:
+            self._source = None
+        else:
+            self._source = format_texture_source(value, self.source_channels)
+        self._is_transparent = None
+
+    @property
+    def source_channels(self):
+        """str : The channels that were extracted from the original source.
+        """
+        return self._source_channels
+
+    @source_channels.setter
+    def source_channels(self, value):
+        self._source_channels = value
+
+    @property
+    def width(self):
+        """int : The width of the texture buffer.
+        """
+        return self._width
+
+    @width.setter
+    def width(self, value):
+        self._width = value
+
+    @property
+    def height(self):
+        """int : The height of the texture buffer.
+        """
+        return self._height
+
+    @height.setter
+    def height(self, value):
+        self._height = value
+
+    @property
+    def tex_type(self):
+        """int : The type of the texture.
+        """
+        return self._tex_type
+
+    @tex_type.setter
+    def tex_type(self, value):
+        self._tex_type = value
+
+    @property
+    def data_format(self):
+        """int : The format of the texture data.
+        """
+        return self._data_format
+
+    @data_format.setter
+    def data_format(self, value):
+        self._data_format = value
+
+    def is_transparent(self, cutoff=1.0):
+        """bool : If True, the texture is partially transparent.
+        """
+        if self._is_transparent is None:
+            self._is_transparent = False
+            if self.source_channels == 'RGBA' and self.source is not None:
+                if np.any(self.source[:,:,3] < cutoff):
+                    self._is_transparent = True
+        return self._is_transparent
+
+    def delete(self):
+        """Remove this texture from the OpenGL context.
+ """ + self._unbind() + self._remove_from_context() + + ################## + # OpenGL code + ################## + def _add_to_context(self): + if self._texid is not None: + raise ValueError('Texture already loaded into OpenGL context') + + fmt = GL_DEPTH_COMPONENT + if self.source_channels == 'R': + fmt = GL_RED + elif self.source_channels == 'RG' or self.source_channels == 'GB': + fmt = GL_RG + elif self.source_channels == 'RGB': + fmt = GL_RGB + elif self.source_channels == 'RGBA': + fmt = GL_RGBA + + # Generate the OpenGL texture + self._texid = glGenTextures(1) + glBindTexture(self.tex_type, self._texid) + + # Flip data for OpenGL buffer + data = None + width = self.width + height = self.height + if self.source is not None: + data = np.ascontiguousarray(np.flip(self.source, axis=0).flatten()) + width = self.source.shape[1] + height = self.source.shape[0] + + # Bind texture and generate mipmaps + glTexImage2D( + self.tex_type, 0, fmt, width, height, 0, fmt, + self.data_format, data + ) + if self.source is not None: + glGenerateMipmap(self.tex_type) + + if self.sampler.magFilter is not None: + glTexParameteri( + self.tex_type, GL_TEXTURE_MAG_FILTER, self.sampler.magFilter + ) + else: + if self.source is not None: + glTexParameteri(self.tex_type, GL_TEXTURE_MAG_FILTER, GL_LINEAR) + else: + glTexParameteri(self.tex_type, GL_TEXTURE_MAG_FILTER, GL_NEAREST) + if self.sampler.minFilter is not None: + glTexParameteri( + self.tex_type, GL_TEXTURE_MIN_FILTER, self.sampler.minFilter + ) + else: + if self.source is not None: + glTexParameteri(self.tex_type, GL_TEXTURE_MIN_FILTER, GL_LINEAR_MIPMAP_LINEAR) + else: + glTexParameteri(self.tex_type, GL_TEXTURE_MIN_FILTER, GL_NEAREST) + + glTexParameteri(self.tex_type, GL_TEXTURE_WRAP_S, self.sampler.wrapS) + glTexParameteri(self.tex_type, GL_TEXTURE_WRAP_T, self.sampler.wrapT) + border_color = 255 * np.ones(4).astype(np.uint8) + if self.data_format == GL_FLOAT: + border_color = np.ones(4).astype(np.float32) + glTexParameterfv( + self.tex_type, GL_TEXTURE_BORDER_COLOR, + border_color + ) + + # Unbind texture + glBindTexture(self.tex_type, 0) + + def _remove_from_context(self): + if self._texid is not None: + # TODO OPENGL BUG? + # glDeleteTextures(1, [self._texid]) + glDeleteTextures([self._texid]) + self._texid = None + + def _in_context(self): + return self._texid is not None + + def _bind(self): + # TODO HANDLE INDEXING INTO OTHER UV's + glBindTexture(self.tex_type, self._texid) + + def _unbind(self): + glBindTexture(self.tex_type, 0) + + def _bind_as_depth_attachment(self): + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, + self.tex_type, self._texid, 0) + + def _bind_as_color_attachment(self): + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + self.tex_type, self._texid, 0) diff --git a/pyrender/pyrender/trackball.py b/pyrender/pyrender/trackball.py new file mode 100644 index 0000000000000000000000000000000000000000..3e57a0e82d3f07b80754f575c28a0e05cb73fc50 --- /dev/null +++ b/pyrender/pyrender/trackball.py @@ -0,0 +1,216 @@ +"""Trackball class for 3D manipulation of viewpoints. +""" +import numpy as np + +import trimesh.transformations as transformations + + +class Trackball(object): + """A trackball class for creating camera transforms from mouse movements. + """ + STATE_ROTATE = 0 + STATE_PAN = 1 + STATE_ROLL = 2 + STATE_ZOOM = 3 + + def __init__(self, pose, size, scale, + target=np.array([0.0, 0.0, 0.0])): + """Initialize a trackball with an initial camera-to-world pose + and the given parameters. 
+
+        Parameters
+        ----------
+        pose : (4,4) float
+            An initial camera-to-world pose for the trackball.
+
+        size : (float, float)
+            The width and height of the camera image in pixels.
+
+        scale : float
+            The diagonal of the scene's bounding box --
+            used for ensuring translation motions are sufficiently
+            fast for differently-sized scenes.
+
+        target : (3,) float
+            The center of the scene in world coordinates.
+            The trackball will revolve around this point.
+        """
+        self._size = np.array(size)
+        self._scale = float(scale)
+
+        self._pose = pose
+        self._n_pose = pose
+
+        self._target = target
+        self._n_target = target
+
+        self._state = Trackball.STATE_ROTATE
+
+    @property
+    def pose(self):
+        """(4,4) float : The current camera-to-world pose.
+        """
+        return self._n_pose
+
+    def set_state(self, state):
+        """Set the state of the trackball in order to change the effect of
+        dragging motions.
+
+        Parameters
+        ----------
+        state : int
+            One of Trackball.STATE_ROTATE, Trackball.STATE_PAN,
+            Trackball.STATE_ROLL, and Trackball.STATE_ZOOM.
+        """
+        self._state = state
+
+    def resize(self, size):
+        """Resize the window.
+
+        Parameters
+        ----------
+        size : (float, float)
+            The new width and height of the camera image in pixels.
+        """
+        self._size = np.array(size)
+
+    def down(self, point):
+        """Record an initial mouse press at a given point.
+
+        Parameters
+        ----------
+        point : (2,) int
+            The x and y pixel coordinates of the mouse press.
+        """
+        self._pdown = np.array(point, dtype=np.float32)
+        self._pose = self._n_pose
+        self._target = self._n_target
+
+    def drag(self, point):
+        """Update the trackball during a drag.
+
+        Parameters
+        ----------
+        point : (2,) int
+            The current x and y pixel coordinates of the mouse during a drag.
+            This will compute a movement for the trackball with the relative
+            motion between this point and the one marked by down().
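+
+        Examples
+        --------
+        A typical event sequence (illustrative):
+
+        >>> trackball.down((x0, y0))    # on mouse press
+        >>> trackball.drag((x1, y1))    # on each mouse-move event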
+ """ + point = np.array(point, dtype=np.float32) + dx, dy = point - self._pdown + mindim = 0.3 * np.min(self._size) + + target = self._target + x_axis = self._pose[:3,0].flatten() + y_axis = self._pose[:3,1].flatten() + z_axis = self._pose[:3,2].flatten() + eye = self._pose[:3,3].flatten() + + # Interpret drag as a rotation + if self._state == Trackball.STATE_ROTATE: + x_angle = -dx / mindim + x_rot_mat = transformations.rotation_matrix( + x_angle, y_axis, target + ) + + y_angle = dy / mindim + y_rot_mat = transformations.rotation_matrix( + y_angle, x_axis, target + ) + + self._n_pose = y_rot_mat.dot(x_rot_mat.dot(self._pose)) + + # Interpret drag as a roll about the camera axis + elif self._state == Trackball.STATE_ROLL: + center = self._size / 2.0 + v_init = self._pdown - center + v_curr = point - center + v_init = v_init / np.linalg.norm(v_init) + v_curr = v_curr / np.linalg.norm(v_curr) + + theta = (-np.arctan2(v_curr[1], v_curr[0]) + + np.arctan2(v_init[1], v_init[0])) + + rot_mat = transformations.rotation_matrix(theta, z_axis, target) + + self._n_pose = rot_mat.dot(self._pose) + + # Interpret drag as a camera pan in view plane + elif self._state == Trackball.STATE_PAN: + dx = -dx / (5.0 * mindim) * self._scale + dy = -dy / (5.0 * mindim) * self._scale + + translation = dx * x_axis + dy * y_axis + self._n_target = self._target + translation + t_tf = np.eye(4) + t_tf[:3,3] = translation + self._n_pose = t_tf.dot(self._pose) + + # Interpret drag as a zoom motion + elif self._state == Trackball.STATE_ZOOM: + radius = np.linalg.norm(eye - target) + ratio = 0.0 + if dy > 0: + ratio = np.exp(abs(dy) / (0.5 * self._size[1])) - 1.0 + elif dy < 0: + ratio = 1.0 - np.exp(dy / (0.5 * (self._size[1]))) + translation = -np.sign(dy) * ratio * radius * z_axis + t_tf = np.eye(4) + t_tf[:3,3] = translation + self._n_pose = t_tf.dot(self._pose) + + def scroll(self, clicks): + """Zoom using a mouse scroll wheel motion. + + Parameters + ---------- + clicks : int + The number of clicks. Positive numbers indicate forward wheel + movement. + """ + target = self._target + ratio = 0.90 + + mult = 1.0 + if clicks > 0: + mult = ratio**clicks + elif clicks < 0: + mult = (1.0 / ratio)**abs(clicks) + + z_axis = self._n_pose[:3,2].flatten() + eye = self._n_pose[:3,3].flatten() + radius = np.linalg.norm(eye - target) + translation = (mult * radius - radius) * z_axis + t_tf = np.eye(4) + t_tf[:3,3] = translation + self._n_pose = t_tf.dot(self._n_pose) + + z_axis = self._pose[:3,2].flatten() + eye = self._pose[:3,3].flatten() + radius = np.linalg.norm(eye - target) + translation = (mult * radius - radius) * z_axis + t_tf = np.eye(4) + t_tf[:3,3] = translation + self._pose = t_tf.dot(self._pose) + + def rotate(self, azimuth, axis=None): + """Rotate the trackball about the "Up" axis by azimuth radians. + + Parameters + ---------- + azimuth : float + The number of radians to rotate. 
+ """ + target = self._target + + y_axis = self._n_pose[:3,1].flatten() + if axis is not None: + y_axis = axis + x_rot_mat = transformations.rotation_matrix(azimuth, y_axis, target) + self._n_pose = x_rot_mat.dot(self._n_pose) + + y_axis = self._pose[:3,1].flatten() + if axis is not None: + y_axis = axis + x_rot_mat = transformations.rotation_matrix(azimuth, y_axis, target) + self._pose = x_rot_mat.dot(self._pose) diff --git a/pyrender/pyrender/utils.py b/pyrender/pyrender/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..48a11faf991606ad7fb0691582f0bc6f06101a45 --- /dev/null +++ b/pyrender/pyrender/utils.py @@ -0,0 +1,115 @@ +import numpy as np +from PIL import Image + + +def format_color_vector(value, length): + """Format a color vector. + """ + if isinstance(value, int): + value = value / 255.0 + if isinstance(value, float): + value = np.repeat(value, length) + if isinstance(value, list) or isinstance(value, tuple): + value = np.array(value) + if isinstance(value, np.ndarray): + value = value.squeeze() + if np.issubdtype(value.dtype, np.integer): + value = (value / 255.0).astype(np.float32) + if value.ndim != 1: + raise ValueError('Format vector takes only 1-D vectors') + if length > value.shape[0]: + value = np.hstack((value, np.ones(length - value.shape[0]))) + elif length < value.shape[0]: + value = value[:length] + else: + raise ValueError('Invalid vector data type') + + return value.squeeze().astype(np.float32) + + +def format_color_array(value, shape): + """Format an array of colors. + """ + # Convert uint8 to floating + value = np.asanyarray(value) + if np.issubdtype(value.dtype, np.integer): + value = (value / 255.0).astype(np.float32) + + # Match up shapes + if value.ndim == 1: + value = np.tile(value, (shape[0],1)) + if value.shape[1] < shape[1]: + nc = shape[1] - value.shape[1] + value = np.column_stack((value, np.ones((value.shape[0], nc)))) + elif value.shape[1] > shape[1]: + value = value[:,:shape[1]] + return value.astype(np.float32) + + +def format_texture_source(texture, target_channels='RGB'): + """Format a texture as a float32 np array. 
+ """ + + # Pass through None + if texture is None: + return None + + # Convert PIL images into numpy arrays + if isinstance(texture, Image.Image): + if texture.mode == 'P' and target_channels in ('RGB', 'RGBA'): + texture = np.array(texture.convert(target_channels)) + else: + texture = np.array(texture) + + # Format numpy arrays + if isinstance(texture, np.ndarray): + if np.issubdtype(texture.dtype, np.floating): + texture = np.array(texture * 255.0, dtype=np.uint8) + elif np.issubdtype(texture.dtype, np.integer): + texture = texture.astype(np.uint8) + else: + raise TypeError('Invalid type {} for texture'.format( + type(texture) + )) + + # Format array by picking out correct texture channels or padding + if texture.ndim == 2: + texture = texture[:,:,np.newaxis] + if target_channels == 'R': + texture = texture[:,:,0] + texture = texture.squeeze() + elif target_channels == 'RG': + if texture.shape[2] == 1: + texture = np.repeat(texture, 2, axis=2) + else: + texture = texture[:,:,(0,1)] + elif target_channels == 'GB': + if texture.shape[2] == 1: + texture = np.repeat(texture, 2, axis=2) + elif texture.shape[2] > 2: + texture = texture[:,:,(1,2)] + elif target_channels == 'RGB': + if texture.shape[2] == 1: + texture = np.repeat(texture, 3, axis=2) + elif texture.shape[2] == 2: + raise ValueError('Cannot reformat 2-channel texture into RGB') + else: + texture = texture[:,:,(0,1,2)] + elif target_channels == 'RGBA': + if texture.shape[2] == 1: + texture = np.repeat(texture, 4, axis=2) + texture[:,:,3] = 255 + elif texture.shape[2] == 2: + raise ValueError('Cannot reformat 2-channel texture into RGBA') + elif texture.shape[2] == 3: + tx = np.empty((texture.shape[0], texture.shape[1], 4), dtype=np.uint8) + tx[:,:,:3] = texture + tx[:,:,3] = 255 + texture = tx + else: + raise ValueError('Invalid texture channel specification: {}' + .format(target_channels)) + else: + raise TypeError('Invalid type {} for texture'.format(type(texture))) + + return texture diff --git a/pyrender/pyrender/version.py b/pyrender/pyrender/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a33fc87f61f528780e3319a5160769cc84512b1b --- /dev/null +++ b/pyrender/pyrender/version.py @@ -0,0 +1 @@ +__version__ = '0.1.45' diff --git a/pyrender/pyrender/viewer.py b/pyrender/pyrender/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..d2326c38205c6eaddb4f567e3b088329187af258 --- /dev/null +++ b/pyrender/pyrender/viewer.py @@ -0,0 +1,1160 @@ +"""A pyglet-based interactive 3D scene viewer. +""" +import copy +import os +import sys +from threading import Thread, RLock +import time + +import imageio +import numpy as np +import OpenGL +import trimesh + +try: + from Tkinter import Tk, tkFileDialog as filedialog +except Exception: + try: + from tkinter import Tk, filedialog as filedialog + except Exception: + pass + +from .constants import (TARGET_OPEN_GL_MAJOR, TARGET_OPEN_GL_MINOR, + MIN_OPEN_GL_MAJOR, MIN_OPEN_GL_MINOR, + TEXT_PADDING, DEFAULT_SCENE_SCALE, + DEFAULT_Z_FAR, DEFAULT_Z_NEAR, RenderFlags, TextAlign) +from .light import DirectionalLight +from .node import Node +from .camera import PerspectiveCamera, OrthographicCamera, IntrinsicsCamera +from .trackball import Trackball +from .renderer import Renderer +from .mesh import Mesh + +import pyglet +from pyglet import clock +pyglet.options['shadow_window'] = False + + +class Viewer(pyglet.window.Window): + """An interactive viewer for 3D scenes. 
+ + The viewer's camera is separate from the scene's, but will take on + the parameters of the scene's main view camera and start in the same pose. + If the scene does not have a camera, a suitable default will be provided. + + Parameters + ---------- + scene : :class:`Scene` + The scene to visualize. + viewport_size : (2,) int + The width and height of the initial viewing window. + render_flags : dict + A set of flags for rendering the scene. Described in the note below. + viewer_flags : dict + A set of flags for controlling the viewer's behavior. + Described in the note below. + registered_keys : dict + A map from ASCII key characters to tuples containing: + + - A function to be called whenever the key is pressed, + whose first argument will be the viewer itself. + - (Optionally) A list of additional positional arguments + to be passed to the function. + - (Optionally) A dict of keyword arguments to be passed + to the function. + + kwargs : dict + Any keyword arguments left over will be interpreted as belonging to + either the :attr:`.Viewer.render_flags` or :attr:`.Viewer.viewer_flags` + dictionaries. Those flag sets will be updated appropriately. + + Note + ---- + The basic commands for moving about the scene are given as follows: + + - **Rotating about the scene**: Hold the left mouse button and + drag the cursor. + - **Rotating about the view axis**: Hold ``CTRL`` and the left mouse + button and drag the cursor. + - **Panning**: + + - Hold SHIFT, then hold the left mouse button and drag the cursor, or + - Hold the middle mouse button and drag the cursor. + + - **Zooming**: + + - Scroll the mouse wheel, or + - Hold the right mouse button and drag the cursor. + + Other keyboard commands are as follows: + + - ``a``: Toggles rotational animation mode. + - ``c``: Toggles backface culling. + - ``f``: Toggles fullscreen mode. + - ``h``: Toggles shadow rendering. + - ``i``: Toggles axis display mode + (no axes, world axis, mesh axes, all axes). + - ``l``: Toggles lighting mode + (scene lighting, Raymond lighting, or direct lighting). + - ``m``: Toggles face normal visualization. + - ``n``: Toggles vertex normal visualization. + - ``o``: Toggles orthographic mode. + - ``q``: Quits the viewer. + - ``r``: Starts recording a GIF, and pressing again stops recording + and opens a file dialog. + - ``s``: Opens a file dialog to save the current view as an image. + - ``w``: Toggles wireframe mode + (scene default, flip wireframes, all wireframe, or all solid). + - ``z``: Resets the camera to the initial view. + + Note + ---- + The valid keys for ``render_flags`` are as follows: + + - ``flip_wireframe``: `bool`, If `True`, all objects will have their + wireframe modes flipped from what their material indicates. + Defaults to `False`. + - ``all_wireframe``: `bool`, If `True`, all objects will be rendered + in wireframe mode. Defaults to `False`. + - ``all_solid``: `bool`, If `True`, all objects will be rendered in + solid mode. Defaults to `False`. + - ``shadows``: `bool`, If `True`, shadows will be rendered. + Defaults to `False`. + - ``vertex_normals``: `bool`, If `True`, vertex normals will be + rendered as blue lines. Defaults to `False`. + - ``face_normals``: `bool`, If `True`, face normals will be rendered as + blue lines. Defaults to `False`. + - ``cull_faces``: `bool`, If `True`, backfaces will be culled. + Defaults to `True`. + - ``point_size`` : float, The point size in pixels. Defaults to 1px. 
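+
+    As a brief sketch (the flag values are arbitrary), any of these may be
+    passed in the ``render_flags`` dict or directly as keyword arguments::
+
+        Viewer(scene, render_flags={'cull_faces': False, 'shadows': True})
+        Viewer(scene, all_wireframe=True, point_size=2.0)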
+ + Note + ---- + The valid keys for ``viewer_flags`` are as follows: + + - ``rotate``: `bool`, If `True`, the scene's camera will rotate + about an axis. Defaults to `False`. + - ``rotate_rate``: `float`, The rate of rotation in radians per second. + Defaults to `PI / 3.0`. + - ``rotate_axis``: `(3,) float`, The axis in world coordinates to rotate + about. Defaults to ``[0,0,1]``. + - ``view_center``: `(3,) float`, The position to rotate the scene about. + Defaults to the scene's centroid. + - ``use_raymond_lighting``: `bool`, If `True`, an additional set of three + directional lights that move with the camera will be added to the scene. + Defaults to `False`. + - ``use_direct_lighting``: `bool`, If `True`, an additional directional + light that moves with the camera and points out of it will be added to + the scene. Defaults to `False`. + - ``lighting_intensity``: `float`, The overall intensity of the + viewer's additional lights (when they're in use). Defaults to 3.0. + - ``use_perspective_cam``: `bool`, If `True`, a perspective camera will + be used. Otherwise, an orthographic camera is used. Defaults to `True`. + - ``save_directory``: `str`, A directory to open the file dialogs in. + Defaults to `None`. + - ``window_title``: `str`, A title for the viewer's application window. + Defaults to `"Scene Viewer"`. + - ``refresh_rate``: `float`, A refresh rate for rendering, in Hertz. + Defaults to `30.0`. + - ``fullscreen``: `bool`, Whether to make viewer fullscreen. + Defaults to `False`. + - ``show_world_axis``: `bool`, Whether to show the world axis. + Defaults to `False`. + - ``show_mesh_axes``: `bool`, Whether to show the individual mesh axes. + Defaults to `False`. + - ``caption``: `list of dict`, Text caption(s) to display on the viewer. + Defaults to `None`. + + Note + ---- + Animation can be accomplished by running the viewer with ``run_in_thread`` + enabled. Then, just run a loop in your main thread, updating the scene as + needed. Before updating the scene, be sure to acquire the + :attr:`.Viewer.render_lock`, and release it when your update is done. 
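+
+    Example
+    -------
+    A minimal sketch of the threaded-animation pattern described above.
+    It assumes a :class:`Scene` called ``scene`` and one of its nodes,
+    ``node``, both created elsewhere::
+
+        import time
+        import numpy as np
+
+        v = Viewer(scene, run_in_thread=True)
+        pose = np.eye(4)
+        while v.is_active:
+            pose[2, 3] += 0.01            # nudge the node upward
+            v.render_lock.acquire()       # pause rendering during the edit
+            scene.set_pose(node, pose)
+            v.render_lock.release()
+            time.sleep(1.0 / 30.0)        # roughly one refresh period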
+ """ + + def __init__(self, scene, viewport_size=None, + render_flags=None, viewer_flags=None, + registered_keys=None, run_in_thread=False, + auto_start=True, + **kwargs): + + ####################################################################### + # Save attributes and flags + ####################################################################### + if viewport_size is None: + viewport_size = (640, 480) + self._scene = scene + self._viewport_size = viewport_size + self._render_lock = RLock() + self._is_active = False + self._should_close = False + self._run_in_thread = run_in_thread + self._auto_start = auto_start + + self._default_render_flags = { + 'flip_wireframe': False, + 'all_wireframe': False, + 'all_solid': False, + 'shadows': False, + 'vertex_normals': False, + 'face_normals': False, + 'cull_faces': True, + 'point_size': 1.0, + } + self._default_viewer_flags = { + 'mouse_pressed': False, + 'rotate': False, + 'rotate_rate': np.pi / 3.0, + 'rotate_axis': np.array([0.0, 0.0, 1.0]), + 'view_center': None, + 'record': False, + 'use_raymond_lighting': False, + 'use_direct_lighting': False, + 'lighting_intensity': 3.0, + 'use_perspective_cam': True, + 'save_directory': None, + 'window_title': 'Scene Viewer', + 'refresh_rate': 30.0, + 'fullscreen': False, + 'show_world_axis': False, + 'show_mesh_axes': False, + 'caption': None + } + self._render_flags = self._default_render_flags.copy() + self._viewer_flags = self._default_viewer_flags.copy() + self._viewer_flags['rotate_axis'] = ( + self._default_viewer_flags['rotate_axis'].copy() + ) + + if render_flags is not None: + self._render_flags.update(render_flags) + if viewer_flags is not None: + self._viewer_flags.update(viewer_flags) + + for key in kwargs: + if key in self.render_flags: + self._render_flags[key] = kwargs[key] + elif key in self.viewer_flags: + self._viewer_flags[key] = kwargs[key] + + # TODO MAC OS BUG FOR SHADOWS + if sys.platform == 'darwin': + self._render_flags['shadows'] = False + + self._registered_keys = {} + if registered_keys is not None: + self._registered_keys = { + ord(k.lower()): registered_keys[k] for k in registered_keys + } + + ####################################################################### + # Save internal settings + ####################################################################### + + # Set up caption stuff + self._message_text = None + self._ticks_till_fade = 2.0 / 3.0 * self.viewer_flags['refresh_rate'] + self._message_opac = 1.0 + self._ticks_till_fade + + # Set up raymond lights and direct lights + self._raymond_lights = self._create_raymond_lights() + self._direct_light = self._create_direct_light() + + # Set up axes + self._axes = {} + self._axis_mesh = Mesh.from_trimesh( + trimesh.creation.axis(origin_size=0.1, axis_radius=0.05, + axis_length=1.0), smooth=False) + if self.viewer_flags['show_world_axis']: + self._set_axes(world=self.viewer_flags['show_world_axis'], + mesh=self.viewer_flags['show_mesh_axes']) + + ####################################################################### + # Set up camera node + ####################################################################### + self._camera_node = None + self._prior_main_camera_node = None + self._default_camera_pose = None + self._default_persp_cam = None + self._default_orth_cam = None + self._trackball = None + self._saved_frames = [] + + # Extract main camera from scene and set up our mirrored copy + znear = None + zfar = None + if scene.main_camera_node is not None: + n = scene.main_camera_node + camera = copy.copy(n.camera) + 
if isinstance(camera, (PerspectiveCamera, IntrinsicsCamera)): + self._default_persp_cam = camera + znear = camera.znear + zfar = camera.zfar + elif isinstance(camera, OrthographicCamera): + self._default_orth_cam = camera + znear = camera.znear + zfar = camera.zfar + self._default_camera_pose = scene.get_pose(scene.main_camera_node) + self._prior_main_camera_node = n + + # Set defaults as needed + if zfar is None: + zfar = max(scene.scale * 10.0, DEFAULT_Z_FAR) + if znear is None or znear == 0: + if scene.scale == 0: + znear = DEFAULT_Z_NEAR + else: + znear = min(scene.scale / 10.0, DEFAULT_Z_NEAR) + + if self._default_persp_cam is None: + self._default_persp_cam = PerspectiveCamera( + yfov=np.pi / 3.0, znear=znear, zfar=zfar + ) + if self._default_orth_cam is None: + xmag = ymag = scene.scale + if scene.scale == 0: + xmag = ymag = 1.0 + self._default_orth_cam = OrthographicCamera( + xmag=xmag, ymag=ymag, + znear=znear, + zfar=zfar + ) + if self._default_camera_pose is None: + self._default_camera_pose = self._compute_initial_camera_pose() + + # Pick camera + if self.viewer_flags['use_perspective_cam']: + camera = self._default_persp_cam + else: + camera = self._default_orth_cam + + self._camera_node = Node( + matrix=self._default_camera_pose, camera=camera + ) + scene.add_node(self._camera_node) + scene.main_camera_node = self._camera_node + self._reset_view() + + ####################################################################### + # Initialize OpenGL context and renderer + ####################################################################### + self._renderer = Renderer( + self._viewport_size[0], self._viewport_size[1], + self.render_flags['point_size'] + ) + self._is_active = True + + if self.run_in_thread: + self._thread = Thread(target=self._init_and_start_app) + self._thread.start() + else: + if auto_start: + self._init_and_start_app() + + def start(self): + self._init_and_start_app() + + @property + def scene(self): + """:class:`.Scene` : The scene being visualized. + """ + return self._scene + + @property + def viewport_size(self): + """(2,) int : The width and height of the viewing window. + """ + return self._viewport_size + + @property + def render_lock(self): + """:class:`threading.RLock` : If acquired, prevents the viewer from + rendering until released. + + Run :meth:`.Viewer.render_lock.acquire` before making updates to + the scene in a different thread, and run + :meth:`.Viewer.render_lock.release` once you're done to let the viewer + continue. + """ + return self._render_lock + + @property + def is_active(self): + """bool : `True` if the viewer is active, or `False` if it has + been closed. + """ + return self._is_active + + @property + def run_in_thread(self): + """bool : Whether the viewer was run in a separate thread. + """ + return self._run_in_thread + + @property + def render_flags(self): + """dict : Flags for controlling the renderer's behavior. + + - ``flip_wireframe``: `bool`, If `True`, all objects will have their + wireframe modes flipped from what their material indicates. + Defaults to `False`. + - ``all_wireframe``: `bool`, If `True`, all objects will be rendered + in wireframe mode. Defaults to `False`. + - ``all_solid``: `bool`, If `True`, all objects will be rendered in + solid mode. Defaults to `False`. + - ``shadows``: `bool`, If `True`, shadows will be rendered. + Defaults to `False`. + - ``vertex_normals``: `bool`, If `True`, vertex normals will be + rendered as blue lines. Defaults to `False`. 
+ - ``face_normals``: `bool`, If `True`, face normals will be rendered as + blue lines. Defaults to `False`. + - ``cull_faces``: `bool`, If `True`, backfaces will be culled. + Defaults to `True`. + - ``point_size`` : float, The point size in pixels. Defaults to 1px. + + """ + return self._render_flags + + @render_flags.setter + def render_flags(self, value): + self._render_flags = value + + @property + def viewer_flags(self): + """dict : Flags for controlling the viewer's behavior. + + The valid keys for ``viewer_flags`` are as follows: + + - ``rotate``: `bool`, If `True`, the scene's camera will rotate + about an axis. Defaults to `False`. + - ``rotate_rate``: `float`, The rate of rotation in radians per second. + Defaults to `PI / 3.0`. + - ``rotate_axis``: `(3,) float`, The axis in world coordinates to + rotate about. Defaults to ``[0,0,1]``. + - ``view_center``: `(3,) float`, The position to rotate the scene + about. Defaults to the scene's centroid. + - ``use_raymond_lighting``: `bool`, If `True`, an additional set of + three directional lights that move with the camera will be added to + the scene. Defaults to `False`. + - ``use_direct_lighting``: `bool`, If `True`, an additional directional + light that moves with the camera and points out of it will be + added to the scene. Defaults to `False`. + - ``lighting_intensity``: `float`, The overall intensity of the + viewer's additional lights (when they're in use). Defaults to 3.0. + - ``use_perspective_cam``: `bool`, If `True`, a perspective camera will + be used. Otherwise, an orthographic camera is used. Defaults to + `True`. + - ``save_directory``: `str`, A directory to open the file dialogs in. + Defaults to `None`. + - ``window_title``: `str`, A title for the viewer's application window. + Defaults to `"Scene Viewer"`. + - ``refresh_rate``: `float`, A refresh rate for rendering, in Hertz. + Defaults to `30.0`. + - ``fullscreen``: `bool`, Whether to make viewer fullscreen. + Defaults to `False`. + - ``show_world_axis``: `bool`, Whether to show the world axis. + Defaults to `False`. + - ``show_mesh_axes``: `bool`, Whether to show the individual mesh axes. + Defaults to `False`. + - ``caption``: `list of dict`, Text caption(s) to display on + the viewer. Defaults to `None`. + + """ + return self._viewer_flags + + @viewer_flags.setter + def viewer_flags(self, value): + self._viewer_flags = value + + @property + def registered_keys(self): + """dict : Map from ASCII key character to a handler function. + + This is a map from ASCII key characters to tuples containing: + + - A function to be called whenever the key is pressed, + whose first argument will be the viewer itself. + - (Optionally) A list of additional positional arguments + to be passed to the function. + - (Optionally) A dict of keyword arguments to be passed + to the function. + + """ + return self._registered_keys + + @registered_keys.setter + def registered_keys(self, value): + self._registered_keys = value + + def close_external(self): + """Close the viewer from another thread. + + This function will wait for the actual close, so you immediately + manipulate the scene afterwards. + """ + self._should_close = True + while self.is_active: + time.sleep(1.0 / self.viewer_flags['refresh_rate']) + + def save_gif(self, filename=None): + """Save the stored GIF frames to a file. + + To use this asynchronously, run the viewer with the ``record`` + flag and the ``run_in_thread`` flags set. 
+ Kill the viewer after your desired time with + :meth:`.Viewer.close_external`, and then call :meth:`.Viewer.save_gif`. + + Parameters + ---------- + filename : str + The file to save the GIF to. If not specified, + a file dialog will be opened to ask the user where + to save the GIF file. + """ + if filename is None: + filename = self._get_save_filename(['gif', 'all']) + if filename is not None: + self.viewer_flags['save_directory'] = os.path.dirname(filename) + imageio.mimwrite(filename, self._saved_frames, + fps=self.viewer_flags['refresh_rate'], + palettesize=128, subrectangles=True) + self._saved_frames = [] + + def on_close(self): + """Exit the event loop when the window is closed. + """ + # Remove our camera and restore the prior one + if self._camera_node is not None: + self.scene.remove_node(self._camera_node) + if self._prior_main_camera_node is not None: + self.scene.main_camera_node = self._prior_main_camera_node + + # Delete any lighting nodes that we've attached + if self.viewer_flags['use_raymond_lighting']: + for n in self._raymond_lights: + if self.scene.has_node(n): + self.scene.remove_node(n) + if self.viewer_flags['use_direct_lighting']: + if self.scene.has_node(self._direct_light): + self.scene.remove_node(self._direct_light) + + # Delete any axis nodes that we've attached + self._remove_axes() + + # Delete renderer + if self._renderer is not None: + self._renderer.delete() + self._renderer = None + + # Force clean-up of OpenGL context data + try: + OpenGL.contextdata.cleanupContext() + self.close() + except Exception: + pass + finally: + self._is_active = False + super(Viewer, self).on_close() + pyglet.app.exit() + + def on_draw(self): + """Redraw the scene into the viewing window. + """ + if self._renderer is None: + return + + if self.run_in_thread or not self._auto_start: + self.render_lock.acquire() + + # Make OpenGL context current + self.switch_to() + + # Render the scene + self.clear() + self._render() + + if self._message_text is not None: + self._renderer.render_text( + self._message_text, + self.viewport_size[0] - TEXT_PADDING, + TEXT_PADDING, + font_pt=20, + color=np.array([0.1, 0.7, 0.2, + np.clip(self._message_opac, 0.0, 1.0)]), + align=TextAlign.BOTTOM_RIGHT + ) + + if self.viewer_flags['caption'] is not None: + for caption in self.viewer_flags['caption']: + xpos, ypos = self._location_to_x_y(caption['location']) + self._renderer.render_text( + caption['text'], + xpos, + ypos, + font_name=caption['font_name'], + font_pt=caption['font_pt'], + color=caption['color'], + scale=caption['scale'], + align=caption['location'] + ) + + if self.run_in_thread or not self._auto_start: + self.render_lock.release() + + def on_resize(self, width, height): + """Resize the camera and trackball when the window is resized. + """ + if self._renderer is None: + return + + self._viewport_size = (width, height) + self._trackball.resize(self._viewport_size) + self._renderer.viewport_width = self._viewport_size[0] + self._renderer.viewport_height = self._viewport_size[1] + self.on_draw() + + def on_mouse_press(self, x, y, buttons, modifiers): + """Record an initial mouse press. 
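+
+        The trackball state for the upcoming drag (rotate, roll, pan, or
+        zoom) is chosen from the pressed button and its modifiers.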
+ """ + self._trackball.set_state(Trackball.STATE_ROTATE) + if (buttons == pyglet.window.mouse.LEFT): + ctrl = (modifiers & pyglet.window.key.MOD_CTRL) + shift = (modifiers & pyglet.window.key.MOD_SHIFT) + if (ctrl and shift): + self._trackball.set_state(Trackball.STATE_ZOOM) + elif ctrl: + self._trackball.set_state(Trackball.STATE_ROLL) + elif shift: + self._trackball.set_state(Trackball.STATE_PAN) + elif (buttons == pyglet.window.mouse.MIDDLE): + self._trackball.set_state(Trackball.STATE_PAN) + elif (buttons == pyglet.window.mouse.RIGHT): + self._trackball.set_state(Trackball.STATE_ZOOM) + + self._trackball.down(np.array([x, y])) + + # Stop animating while using the mouse + self.viewer_flags['mouse_pressed'] = True + + def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers): + """Record a mouse drag. + """ + self._trackball.drag(np.array([x, y])) + + def on_mouse_release(self, x, y, button, modifiers): + """Record a mouse release. + """ + self.viewer_flags['mouse_pressed'] = False + + def on_mouse_scroll(self, x, y, dx, dy): + """Record a mouse scroll. + """ + if self.viewer_flags['use_perspective_cam']: + self._trackball.scroll(dy) + else: + spfc = 0.95 + spbc = 1.0 / 0.95 + sf = 1.0 + if dy > 0: + sf = spfc * dy + elif dy < 0: + sf = - spbc * dy + + c = self._camera_node.camera + xmag = max(c.xmag * sf, 1e-8) + ymag = max(c.ymag * sf, 1e-8 * c.ymag / c.xmag) + c.xmag = xmag + c.ymag = ymag + + def on_key_press(self, symbol, modifiers): + """Record a key press. + """ + # First, check for registered key callbacks + if symbol in self.registered_keys: + tup = self.registered_keys[symbol] + callback = None + args = [] + kwargs = {} + if not isinstance(tup, (list, tuple, np.ndarray)): + callback = tup + else: + callback = tup[0] + if len(tup) == 2: + args = tup[1] + if len(tup) == 3: + kwargs = tup[2] + callback(self, *args, **kwargs) + return + + # Otherwise, use default key functions + + # A causes the frame to rotate + self._message_text = None + if symbol == pyglet.window.key.A: + self.viewer_flags['rotate'] = not self.viewer_flags['rotate'] + if self.viewer_flags['rotate']: + self._message_text = 'Rotation On' + else: + self._message_text = 'Rotation Off' + + # C toggles backface culling + elif symbol == pyglet.window.key.C: + self.render_flags['cull_faces'] = ( + not self.render_flags['cull_faces'] + ) + if self.render_flags['cull_faces']: + self._message_text = 'Cull Faces On' + else: + self._message_text = 'Cull Faces Off' + + # F toggles face normals + elif symbol == pyglet.window.key.F: + self.viewer_flags['fullscreen'] = ( + not self.viewer_flags['fullscreen'] + ) + self.set_fullscreen(self.viewer_flags['fullscreen']) + self.activate() + if self.viewer_flags['fullscreen']: + self._message_text = 'Fullscreen On' + else: + self._message_text = 'Fullscreen Off' + + # S toggles shadows + elif symbol == pyglet.window.key.H and sys.platform != 'darwin': + self.render_flags['shadows'] = not self.render_flags['shadows'] + if self.render_flags['shadows']: + self._message_text = 'Shadows On' + else: + self._message_text = 'Shadows Off' + + elif symbol == pyglet.window.key.I: + if (self.viewer_flags['show_world_axis'] and not + self.viewer_flags['show_mesh_axes']): + self.viewer_flags['show_world_axis'] = False + self.viewer_flags['show_mesh_axes'] = True + self._set_axes(False, True) + self._message_text = 'Mesh Axes On' + elif (not self.viewer_flags['show_world_axis'] and + self.viewer_flags['show_mesh_axes']): + self.viewer_flags['show_world_axis'] = True + 
self.viewer_flags['show_mesh_axes'] = True + self._set_axes(True, True) + self._message_text = 'All Axes On' + elif (self.viewer_flags['show_world_axis'] and + self.viewer_flags['show_mesh_axes']): + self.viewer_flags['show_world_axis'] = False + self.viewer_flags['show_mesh_axes'] = False + self._set_axes(False, False) + self._message_text = 'All Axes Off' + else: + self.viewer_flags['show_world_axis'] = True + self.viewer_flags['show_mesh_axes'] = False + self._set_axes(True, False) + self._message_text = 'World Axis On' + + # L toggles the lighting mode + elif symbol == pyglet.window.key.L: + if self.viewer_flags['use_raymond_lighting']: + self.viewer_flags['use_raymond_lighting'] = False + self.viewer_flags['use_direct_lighting'] = True + self._message_text = 'Direct Lighting' + elif self.viewer_flags['use_direct_lighting']: + self.viewer_flags['use_raymond_lighting'] = False + self.viewer_flags['use_direct_lighting'] = False + self._message_text = 'Default Lighting' + else: + self.viewer_flags['use_raymond_lighting'] = True + self.viewer_flags['use_direct_lighting'] = False + self._message_text = 'Raymond Lighting' + + # M toggles face normals + elif symbol == pyglet.window.key.M: + self.render_flags['face_normals'] = ( + not self.render_flags['face_normals'] + ) + if self.render_flags['face_normals']: + self._message_text = 'Face Normals On' + else: + self._message_text = 'Face Normals Off' + + # N toggles vertex normals + elif symbol == pyglet.window.key.N: + self.render_flags['vertex_normals'] = ( + not self.render_flags['vertex_normals'] + ) + if self.render_flags['vertex_normals']: + self._message_text = 'Vert Normals On' + else: + self._message_text = 'Vert Normals Off' + + # O toggles orthographic camera mode + elif symbol == pyglet.window.key.O: + self.viewer_flags['use_perspective_cam'] = ( + not self.viewer_flags['use_perspective_cam'] + ) + if self.viewer_flags['use_perspective_cam']: + camera = self._default_persp_cam + self._message_text = 'Perspective View' + else: + camera = self._default_orth_cam + self._message_text = 'Orthographic View' + + cam_pose = self._camera_node.matrix.copy() + cam_node = Node(matrix=cam_pose, camera=camera) + self.scene.remove_node(self._camera_node) + self.scene.add_node(cam_node) + self.scene.main_camera_node = cam_node + self._camera_node = cam_node + + # Q quits the viewer + elif symbol == pyglet.window.key.Q: + self.on_close() + + # R starts recording frames + elif symbol == pyglet.window.key.R: + if self.viewer_flags['record']: + self.save_gif() + self.set_caption(self.viewer_flags['window_title']) + else: + self.set_caption( + '{} (RECORDING)'.format(self.viewer_flags['window_title']) + ) + self.viewer_flags['record'] = not self.viewer_flags['record'] + + # S saves the current frame as an image + elif symbol == pyglet.window.key.S: + self._save_image() + + # W toggles through wireframe modes + elif symbol == pyglet.window.key.W: + if self.render_flags['flip_wireframe']: + self.render_flags['flip_wireframe'] = False + self.render_flags['all_wireframe'] = True + self.render_flags['all_solid'] = False + self._message_text = 'All Wireframe' + elif self.render_flags['all_wireframe']: + self.render_flags['flip_wireframe'] = False + self.render_flags['all_wireframe'] = False + self.render_flags['all_solid'] = True + self._message_text = 'All Solid' + elif self.render_flags['all_solid']: + self.render_flags['flip_wireframe'] = False + self.render_flags['all_wireframe'] = False + self.render_flags['all_solid'] = False + self._message_text = 
'Default Wireframe' + else: + self.render_flags['flip_wireframe'] = True + self.render_flags['all_wireframe'] = False + self.render_flags['all_solid'] = False + self._message_text = 'Flip Wireframe' + + # Z resets the camera viewpoint + elif symbol == pyglet.window.key.Z: + self._reset_view() + + if self._message_text is not None: + self._message_opac = 1.0 + self._ticks_till_fade + + @staticmethod + def _time_event(dt, self): + """The timer callback. + """ + # Don't run old dead events after we've already closed + if not self._is_active: + return + + if self.viewer_flags['record']: + self._record() + if (self.viewer_flags['rotate'] and not + self.viewer_flags['mouse_pressed']): + self._rotate() + + # Manage message opacity + if self._message_text is not None: + if self._message_opac > 1.0: + self._message_opac -= 1.0 + else: + self._message_opac *= 0.90 + if self._message_opac < 0.05: + self._message_opac = 1.0 + self._ticks_till_fade + self._message_text = None + + if self._should_close: + self.on_close() + else: + self.on_draw() + + def _reset_view(self): + """Reset the view to a good initial state. + + The view is initially along the positive x-axis at a + sufficient distance from the scene. + """ + scale = self.scene.scale + if scale == 0.0: + scale = DEFAULT_SCENE_SCALE + centroid = self.scene.centroid + + if self.viewer_flags['view_center'] is not None: + centroid = self.viewer_flags['view_center'] + + self._camera_node.matrix = self._default_camera_pose + self._trackball = Trackball( + self._default_camera_pose, self.viewport_size, scale, centroid + ) + + def _get_save_filename(self, file_exts): + file_types = { + 'png': ('png files', '*.png'), + 'jpg': ('jpeg files', '*.jpg'), + 'gif': ('gif files', '*.gif'), + 'all': ('all files', '*'), + } + filetypes = [file_types[x] for x in file_exts] + try: + root = Tk() + save_dir = self.viewer_flags['save_directory'] + if save_dir is None: + save_dir = os.getcwd() + filename = filedialog.asksaveasfilename( + initialdir=save_dir, title='Select file save location', + filetypes=filetypes + ) + except Exception: + return None + + root.destroy() + if filename == (): + return None + return filename + + def _save_image(self): + filename = self._get_save_filename(['png', 'jpg', 'gif', 'all']) + if filename is not None: + self.viewer_flags['save_directory'] = os.path.dirname(filename) + imageio.imwrite(filename, self._renderer.read_color_buf()) + + def _record(self): + """Save another frame for the GIF. + """ + data = self._renderer.read_color_buf() + if not np.all(data == 0.0): + self._saved_frames.append(data) + + def _rotate(self): + """Animate the scene by rotating the camera. + """ + az = (self.viewer_flags['rotate_rate'] / + self.viewer_flags['refresh_rate']) + self._trackball.rotate(az, self.viewer_flags['rotate_axis']) + + def _render(self): + """Render the scene into the framebuffer and flip. 
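+
+        Each boolean in ``render_flags`` is translated into the matching
+        :class:`RenderFlags` bit below, and the viewer-managed lights are
+        attached to (or removed from) the camera node before drawing.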
+ """ + scene = self.scene + self._camera_node.matrix = self._trackball.pose.copy() + + # Set lighting + vli = self.viewer_flags['lighting_intensity'] + if self.viewer_flags['use_raymond_lighting']: + for n in self._raymond_lights: + n.light.intensity = vli / 3.0 + if not self.scene.has_node(n): + scene.add_node(n, parent_node=self._camera_node) + else: + self._direct_light.light.intensity = vli + for n in self._raymond_lights: + if self.scene.has_node(n): + self.scene.remove_node(n) + + if self.viewer_flags['use_direct_lighting']: + if not self.scene.has_node(self._direct_light): + scene.add_node( + self._direct_light, parent_node=self._camera_node + ) + elif self.scene.has_node(self._direct_light): + self.scene.remove_node(self._direct_light) + + flags = RenderFlags.NONE + if self.render_flags['flip_wireframe']: + flags |= RenderFlags.FLIP_WIREFRAME + elif self.render_flags['all_wireframe']: + flags |= RenderFlags.ALL_WIREFRAME + elif self.render_flags['all_solid']: + flags |= RenderFlags.ALL_SOLID + + if self.render_flags['shadows']: + flags |= RenderFlags.SHADOWS_DIRECTIONAL | RenderFlags.SHADOWS_SPOT + if self.render_flags['vertex_normals']: + flags |= RenderFlags.VERTEX_NORMALS + if self.render_flags['face_normals']: + flags |= RenderFlags.FACE_NORMALS + if not self.render_flags['cull_faces']: + flags |= RenderFlags.SKIP_CULL_FACES + + self._renderer.render(self.scene, flags) + + def _init_and_start_app(self): + # Try multiple configs starting with target OpenGL version + # and multisampling and removing these options if exception + # Note: multisampling not available on all hardware + from pyglet.gl import Config + confs = [Config(sample_buffers=1, samples=4, + depth_size=24, + double_buffer=True, + major_version=TARGET_OPEN_GL_MAJOR, + minor_version=TARGET_OPEN_GL_MINOR), + Config(depth_size=24, + double_buffer=True, + major_version=TARGET_OPEN_GL_MAJOR, + minor_version=TARGET_OPEN_GL_MINOR), + Config(sample_buffers=1, samples=4, + depth_size=24, + double_buffer=True, + major_version=MIN_OPEN_GL_MAJOR, + minor_version=MIN_OPEN_GL_MINOR), + Config(depth_size=24, + double_buffer=True, + major_version=MIN_OPEN_GL_MAJOR, + minor_version=MIN_OPEN_GL_MINOR)] + for conf in confs: + try: + super(Viewer, self).__init__(config=conf, resizable=True, + width=self._viewport_size[0], + height=self._viewport_size[1]) + break + except pyglet.window.NoSuchConfigException: + pass + + if not self.context: + raise ValueError('Unable to initialize an OpenGL 3+ context') + clock.schedule_interval( + Viewer._time_event, 1.0 / self.viewer_flags['refresh_rate'], self + ) + self.switch_to() + self.set_caption(self.viewer_flags['window_title']) + pyglet.app.run() + + def _compute_initial_camera_pose(self): + centroid = self.scene.centroid + if self.viewer_flags['view_center'] is not None: + centroid = self.viewer_flags['view_center'] + scale = self.scene.scale + if scale == 0.0: + scale = DEFAULT_SCENE_SCALE + + s2 = 1.0 / np.sqrt(2.0) + cp = np.eye(4) + cp[:3,:3] = np.array([ + [0.0, -s2, s2], + [1.0, 0.0, 0.0], + [0.0, s2, s2] + ]) + hfov = np.pi / 6.0 + dist = scale / (2.0 * np.tan(hfov)) + cp[:3,3] = dist * np.array([1.0, 0.0, 1.0]) + centroid + + return cp + + def _create_raymond_lights(self): + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / 
np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3,:3] = np.c_[x,y,z] + nodes.append(Node( + light=DirectionalLight(color=np.ones(3), intensity=1.0), + matrix=matrix + )) + + return nodes + + def _create_direct_light(self): + light = DirectionalLight(color=np.ones(3), intensity=1.0) + n = Node(light=light, matrix=np.eye(4)) + return n + + def _set_axes(self, world, mesh): + scale = self.scene.scale + if world: + if 'scene' not in self._axes: + n = Node(mesh=self._axis_mesh, scale=np.ones(3) * scale * 0.3) + self.scene.add_node(n) + self._axes['scene'] = n + else: + if 'scene' in self._axes: + self.scene.remove_node(self._axes['scene']) + self._axes.pop('scene') + + if mesh: + old_nodes = [] + existing_axes = set([self._axes[k] for k in self._axes]) + for node in self.scene.mesh_nodes: + if node not in existing_axes: + old_nodes.append(node) + + for node in old_nodes: + if node in self._axes: + continue + n = Node( + mesh=self._axis_mesh, + scale=np.ones(3) * node.mesh.scale * 0.5 + ) + self.scene.add_node(n, parent_node=node) + self._axes[node] = n + else: + to_remove = set() + for main_node in self._axes: + if main_node in self.scene.mesh_nodes: + self.scene.remove_node(self._axes[main_node]) + to_remove.add(main_node) + for main_node in to_remove: + self._axes.pop(main_node) + + def _remove_axes(self): + for main_node in self._axes: + axis_node = self._axes[main_node] + self.scene.remove_node(axis_node) + self._axes = {} + + def _location_to_x_y(self, location): + if location == TextAlign.CENTER: + return (self.viewport_size[0] / 2.0, self.viewport_size[1] / 2.0) + elif location == TextAlign.CENTER_LEFT: + return (TEXT_PADDING, self.viewport_size[1] / 2.0) + elif location == TextAlign.CENTER_RIGHT: + return (self.viewport_size[0] - TEXT_PADDING, + self.viewport_size[1] / 2.0) + elif location == TextAlign.BOTTOM_LEFT: + return (TEXT_PADDING, TEXT_PADDING) + elif location == TextAlign.BOTTOM_RIGHT: + return (self.viewport_size[0] - TEXT_PADDING, TEXT_PADDING) + elif location == TextAlign.BOTTOM_CENTER: + return (self.viewport_size[0] / 2.0, TEXT_PADDING) + elif location == TextAlign.TOP_LEFT: + return (TEXT_PADDING, self.viewport_size[1] - TEXT_PADDING) + elif location == TextAlign.TOP_RIGHT: + return (self.viewport_size[0] - TEXT_PADDING, + self.viewport_size[1] - TEXT_PADDING) + elif location == TextAlign.TOP_CENTER: + return (self.viewport_size[0] / 2.0, + self.viewport_size[1] - TEXT_PADDING) + + +__all__ = ['Viewer'] diff --git a/pyrender/requirements.txt b/pyrender/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c40b74256f0dc6697754bb8609f69a39d51beba --- /dev/null +++ b/pyrender/requirements.txt @@ -0,0 +1,14 @@ +freetype-py +imageio +networkx +numpy +Pillow +pyglet==1.4.0a1 +PyOpenGL +PyOpenGL_accelerate +six +trimesh +sphinx +sphinx_rtd_theme +sphinx-automodapi + diff --git a/pyrender/setup.py b/pyrender/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b5ba0da2b0f17b759e5556597981096a80bda8 --- /dev/null +++ b/pyrender/setup.py @@ -0,0 +1,76 @@ +""" +Setup of pyrender Python codebase. 
+ +Author: Matthew Matl +""" +import sys +from setuptools import setup + +# load __version__ +exec(open('pyrender/version.py').read()) + +def get_imageio_dep(): + if sys.version[0] == "2": + return 'imageio<=2.6.1' + return 'imageio' + +requirements = [ + 'freetype-py', # For font loading + get_imageio_dep(), # For Image I/O + 'networkx', # For the scene graph + 'numpy', # Numpy + 'Pillow', # For Trimesh texture conversions + 'pyglet>=1.4.10', # For the pyglet viewer + 'PyOpenGL~=3.1.0', # For OpenGL +# 'PyOpenGL_accelerate~=3.1.0', # For OpenGL + 'scipy', # Because of trimesh missing dep + 'six', # For Python 2/3 interop + 'trimesh', # For meshes +] + +dev_requirements = [ + 'flake8', # Code formatting checker + 'pre-commit', # Pre-commit hooks + 'pytest', # Code testing + 'pytest-cov', # Coverage testing + 'tox', # Automatic virtualenv testing +] + +docs_requirements = [ + 'sphinx', # General doc library + 'sphinx_rtd_theme', # RTD theme for sphinx + 'sphinx-automodapi' # For generating nice tables +] + + +setup( + name = 'pyrender', + version=__version__, + description='Easy-to-use Python renderer for 3D visualization', + long_description='A simple implementation of Physically-Based Rendering ' + '(PBR) in Python. Compliant with the glTF 2.0 standard.', + author='Matthew Matl', + author_email='matthewcmatl@gmail.com', + license='MIT License', + url = 'https://github.com/mmatl/pyrender', + classifiers = [ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: MIT License', + 'Operating System :: POSIX :: Linux', + 'Operating System :: MacOS :: MacOS X', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Natural Language :: English', + 'Topic :: Scientific/Engineering' + ], + keywords = 'rendering graphics opengl 3d visualization pbr gltf', + packages = ['pyrender', 'pyrender.platforms'], + setup_requires = requirements, + install_requires = requirements, + extras_require={ + 'dev': dev_requirements, + 'docs': docs_requirements, + }, + include_package_data=True +) diff --git a/pyrender/tests/__init__.py b/pyrender/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyrender/tests/conftest.py b/pyrender/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyrender/tests/pytest.ini b/pyrender/tests/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyrender/tests/unit/__init__.py b/pyrender/tests/unit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyrender/tests/unit/test_cameras.py b/pyrender/tests/unit/test_cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..7544ad8f8e3ee55236fd2e32dbc12065153cbe5b --- /dev/null +++ b/pyrender/tests/unit/test_cameras.py @@ -0,0 +1,164 @@ +import numpy as np +import pytest + +from pyrender import PerspectiveCamera, OrthographicCamera + + +def test_perspective_camera(): + + # Set up constants + znear = 0.05 + zfar = 100 + yfov = np.pi / 3.0 + width = 1000.0 + height = 500.0 + aspectRatio = 640.0 / 480.0 + + # Test basics + with pytest.raises(TypeError): + p = PerspectiveCamera() + + p = PerspectiveCamera(yfov=yfov) + assert p.yfov == yfov + assert p.znear == 0.05 + assert p.zfar is None + assert 
p.aspectRatio is None + p.name = 'asdf' + p.name = None + + with pytest.raises(ValueError): + p.yfov = 0.0 + + with pytest.raises(ValueError): + p.yfov = -1.0 + + with pytest.raises(ValueError): + p.znear = -1.0 + + p.znear = 0.0 + p.znear = 0.05 + p.zfar = 100.0 + assert p.zfar == 100.0 + + with pytest.raises(ValueError): + p.zfar = 0.03 + + with pytest.raises(ValueError): + p.zfar = 0.05 + + p.aspectRatio = 10.0 + assert p.aspectRatio == 10.0 + + with pytest.raises(ValueError): + p.aspectRatio = 0.0 + + with pytest.raises(ValueError): + p.aspectRatio = -1.0 + + # Test matrix getting/setting + + # NF + p.znear = 0.05 + p.zfar = 100 + p.aspectRatio = None + + with pytest.raises(ValueError): + p.get_projection_matrix() + + assert np.allclose( + p.get_projection_matrix(width, height), + np.array([ + [1.0 / (width / height * np.tan(yfov / 2.0)), 0.0, 0.0, 0.0], + [0.0, 1.0 / np.tan(yfov / 2.0), 0.0, 0.0], + [0.0, 0.0, (zfar + znear) / (znear - zfar), + (2 * zfar * znear) / (znear - zfar)], + [0.0, 0.0, -1.0, 0.0] + ]) + ) + + # NFA + p.aspectRatio = aspectRatio + assert np.allclose( + p.get_projection_matrix(width, height), + np.array([ + [1.0 / (aspectRatio * np.tan(yfov / 2.0)), 0.0, 0.0, 0.0], + [0.0, 1.0 / np.tan(yfov / 2.0), 0.0, 0.0], + [0.0, 0.0, (zfar + znear) / (znear - zfar), + (2 * zfar * znear) / (znear - zfar)], + [0.0, 0.0, -1.0, 0.0] + ]) + ) + assert np.allclose( + p.get_projection_matrix(), p.get_projection_matrix(width, height) + ) + + # N + p.zfar = None + p.aspectRatio = None + assert np.allclose( + p.get_projection_matrix(width, height), + np.array([ + [1.0 / (width / height * np.tan(yfov / 2.0)), 0.0, 0.0, 0.0], + [0.0, 1.0 / np.tan(yfov / 2.0), 0.0, 0.0], + [0.0, 0.0, -1.0, -2.0 * znear], + [0.0, 0.0, -1.0, 0.0] + ]) + ) + + +def test_orthographic_camera(): + xm = 1.0 + ym = 2.0 + n = 0.05 + f = 100.0 + + with pytest.raises(TypeError): + c = OrthographicCamera() + + c = OrthographicCamera(xmag=xm, ymag=ym) + + assert c.xmag == xm + assert c.ymag == ym + assert c.znear == 0.05 + assert c.zfar == 100.0 + assert c.name is None + + with pytest.raises(TypeError): + c.ymag = None + + with pytest.raises(ValueError): + c.ymag = 0.0 + + with pytest.raises(ValueError): + c.ymag = -1.0 + + with pytest.raises(TypeError): + c.xmag = None + + with pytest.raises(ValueError): + c.xmag = 0.0 + + with pytest.raises(ValueError): + c.xmag = -1.0 + + with pytest.raises(TypeError): + c.znear = None + + with pytest.raises(ValueError): + c.znear = 0.0 + + with pytest.raises(ValueError): + c.znear = -1.0 + + with pytest.raises(ValueError): + c.zfar = 0.01 + + assert np.allclose( + c.get_projection_matrix(), + np.array([ + [1.0 / xm, 0, 0, 0], + [0, 1.0 / ym, 0, 0], + [0, 0, 2.0 / (n - f), (f + n) / (n - f)], + [0, 0, 0, 1.0] + ]) + ) diff --git a/pyrender/tests/unit/test_egl.py b/pyrender/tests/unit/test_egl.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f4bef39e33c2794e6837b5a1bb127d8d4dba06 --- /dev/null +++ b/pyrender/tests/unit/test_egl.py @@ -0,0 +1,16 @@ +# from pyrender.platforms import egl + + +def tmp_test_default_device(): + egl.get_default_device() + + +def tmp_test_query_device(): + devices = egl.query_devices() + assert len(devices) > 0 + + +def tmp_test_init_context(): + device = egl.query_devices()[0] + platform = egl.EGLPlatform(128, 128, device=device) + platform.init_context() diff --git a/pyrender/tests/unit/test_lights.py b/pyrender/tests/unit/test_lights.py new file mode 100644 index 
0000000000000000000000000000000000000000..ffde856b21e8cce9532f0308fcd1c7eb2d1eba90 --- /dev/null +++ b/pyrender/tests/unit/test_lights.py @@ -0,0 +1,104 @@ +import numpy as np +import pytest + +from pyrender import (DirectionalLight, SpotLight, PointLight, Texture, + PerspectiveCamera, OrthographicCamera) +from pyrender.constants import SHADOW_TEX_SZ + + +def test_directional_light(): + + d = DirectionalLight() + assert d.name is None + assert np.all(d.color == 1.0) + assert d.intensity == 1.0 + + d.name = 'direc' + with pytest.raises(ValueError): + d.color = None + with pytest.raises(TypeError): + d.intensity = None + + d = DirectionalLight(color=[0.0, 0.0, 0.0]) + assert np.all(d.color == 0.0) + + d._generate_shadow_texture() + st = d.shadow_texture + assert isinstance(st, Texture) + assert st.width == st.height == SHADOW_TEX_SZ + + sc = d._get_shadow_camera(scene_scale=5.0) + assert isinstance(sc, OrthographicCamera) + assert sc.xmag == sc.ymag == 5.0 + assert sc.znear == 0.01 * 5.0 + assert sc.zfar == 10 * 5.0 + + +def test_spot_light(): + + s = SpotLight() + assert s.name is None + assert np.all(s.color == 1.0) + assert s.intensity == 1.0 + assert s.innerConeAngle == 0.0 + assert s.outerConeAngle == np.pi / 4.0 + assert s.range is None + + with pytest.raises(ValueError): + s.range = -1.0 + + with pytest.raises(ValueError): + s.range = 0.0 + + with pytest.raises(ValueError): + s.innerConeAngle = -1.0 + + with pytest.raises(ValueError): + s.innerConeAngle = np.pi / 3.0 + + with pytest.raises(ValueError): + s.outerConeAngle = -1.0 + + with pytest.raises(ValueError): + s.outerConeAngle = np.pi + + s.range = 5.0 + s.outerConeAngle = np.pi / 2 - 0.05 + s.innerConeAngle = np.pi / 3 + s.innerConeAngle = 0.0 + s.outerConeAngle = np.pi / 4.0 + + s._generate_shadow_texture() + st = s.shadow_texture + assert isinstance(st, Texture) + assert st.width == st.height == SHADOW_TEX_SZ + + sc = s._get_shadow_camera(scene_scale=5.0) + assert isinstance(sc, PerspectiveCamera) + assert sc.znear == 0.01 * 5.0 + assert sc.zfar == 10 * 5.0 + assert sc.aspectRatio == 1.0 + assert np.allclose(sc.yfov, np.pi / 16.0 * 9.0) # Plus pi / 16 + + +def test_point_light(): + + s = PointLight() + assert s.name is None + assert np.all(s.color == 1.0) + assert s.intensity == 1.0 + assert s.range is None + + with pytest.raises(ValueError): + s.range = -1.0 + + with pytest.raises(ValueError): + s.range = 0.0 + + s.range = 5.0 + + with pytest.raises(NotImplementedError): + s._generate_shadow_texture() + + with pytest.raises(NotImplementedError): + s._get_shadow_camera(scene_scale=5.0) diff --git a/pyrender/tests/unit/test_meshes.py b/pyrender/tests/unit/test_meshes.py new file mode 100644 index 0000000000000000000000000000000000000000..7070b01171c97069fa013c6eba8eee217017f08e --- /dev/null +++ b/pyrender/tests/unit/test_meshes.py @@ -0,0 +1,133 @@ +import numpy as np +import pytest +import trimesh + +from pyrender import (Mesh, Primitive) + + +def test_meshes(): + + with pytest.raises(TypeError): + x = Mesh() + with pytest.raises(TypeError): + x = Primitive() + with pytest.raises(ValueError): + x = Primitive([], mode=10) + + # Basics + x = Mesh([]) + assert x.name is None + assert x.is_visible + assert x.weights is None + + x.name = 'str' + + # From Trimesh + x = Mesh.from_trimesh(trimesh.creation.box()) + assert isinstance(x, Mesh) + assert len(x.primitives) == 1 + assert x.is_visible + assert np.allclose(x.bounds, np.array([ + [-0.5, -0.5, -0.5], + [0.5, 0.5, 0.5] + ])) + assert np.allclose(x.centroid, np.zeros(3)) + assert 
np.allclose(x.extents, np.ones(3)) + assert np.allclose(x.scale, np.sqrt(3)) + assert not x.is_transparent + + # Test some primitive functions + x = x.primitives[0] + with pytest.raises(ValueError): + x.normals = np.zeros(10) + with pytest.raises(ValueError): + x.tangents = np.zeros(10) + with pytest.raises(ValueError): + x.texcoord_0 = np.zeros(10) + with pytest.raises(ValueError): + x.texcoord_1 = np.zeros(10) + with pytest.raises(TypeError): + x.material = np.zeros(10) + assert x.targets is None + assert np.allclose(x.bounds, np.array([ + [-0.5, -0.5, -0.5], + [0.5, 0.5, 0.5] + ])) + assert np.allclose(x.centroid, np.zeros(3)) + assert np.allclose(x.extents, np.ones(3)) + assert np.allclose(x.scale, np.sqrt(3)) + x.material.baseColorFactor = np.array([0.0, 0.0, 0.0, 0.0]) + assert x.is_transparent + + # From two trimeshes + x = Mesh.from_trimesh([trimesh.creation.box(), + trimesh.creation.cylinder(radius=0.1, height=2.0)], + smooth=False) + assert isinstance(x, Mesh) + assert len(x.primitives) == 2 + assert x.is_visible + assert np.allclose(x.bounds, np.array([ + [-0.5, -0.5, -1.0], + [0.5, 0.5, 1.0] + ])) + assert np.allclose(x.centroid, np.zeros(3)) + assert np.allclose(x.extents, [1.0, 1.0, 2.0]) + assert np.allclose(x.scale, np.sqrt(6)) + assert not x.is_transparent + + # From bad data + with pytest.raises(TypeError): + x = Mesh.from_trimesh(None) + + # With instancing + poses = np.tile(np.eye(4), (5,1,1)) + poses[:,0,3] = np.array([0,1,2,3,4]) + x = Mesh.from_trimesh(trimesh.creation.box(), poses=poses) + assert np.allclose(x.bounds, np.array([ + [-0.5, -0.5, -0.5], + [4.5, 0.5, 0.5] + ])) + poses = np.eye(4) + x = Mesh.from_trimesh(trimesh.creation.box(), poses=poses) + poses = np.eye(3) + with pytest.raises(ValueError): + x = Mesh.from_trimesh(trimesh.creation.box(), poses=poses) + + # From textured meshes + fm = trimesh.load('tests/data/fuze.obj') + x = Mesh.from_trimesh(fm) + assert isinstance(x, Mesh) + assert len(x.primitives) == 1 + assert x.is_visible + assert not x.is_transparent + assert x.primitives[0].material.baseColorTexture is not None + + x = Mesh.from_trimesh(fm, smooth=False) + fm.visual = fm.visual.to_color() + fm.visual.face_colors = np.array([1.0, 0.0, 0.0, 1.0]) + x = Mesh.from_trimesh(fm, smooth=False) + with pytest.raises(ValueError): + x = Mesh.from_trimesh(fm, smooth=True) + + fm.visual.vertex_colors = np.array([1.0, 0.0, 0.0, 0.5]) + x = Mesh.from_trimesh(fm, smooth=False) + x = Mesh.from_trimesh(fm, smooth=True) + assert x.primitives[0].color_0 is not None + assert x.is_transparent + + bm = trimesh.load('tests/data/WaterBottle.glb').dump()[0] + x = Mesh.from_trimesh(bm) + assert x.primitives[0].material.baseColorTexture is not None + assert x.primitives[0].material.emissiveTexture is not None + assert x.primitives[0].material.metallicRoughnessTexture is not None + + # From point cloud + x = Mesh.from_points(fm.vertices) + +# def test_duck(): +# bm = trimesh.load('tests/data/Duck.glb').dump()[0] +# x = Mesh.from_trimesh(bm) +# assert x.primitives[0].material.baseColorTexture is not None +# pixel = x.primitives[0].material.baseColorTexture.source[100, 100] +# yellowish = np.array([1.0, 0.7411765, 0.0, 1.0]) +# assert np.allclose(pixel, yellowish) diff --git a/pyrender/tests/unit/test_nodes.py b/pyrender/tests/unit/test_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..9857c8221b7f6fb8530699bdf5593f8f0b74e152 --- /dev/null +++ b/pyrender/tests/unit/test_nodes.py @@ -0,0 +1,124 @@ +import numpy as np +import pytest +from trimesh 
import transformations + +from pyrender import (DirectionalLight, PerspectiveCamera, Mesh, Node) + + +def test_nodes(): + + x = Node() + assert x.name is None + assert x.camera is None + assert x.children == [] + assert x.skin is None + assert np.allclose(x.matrix, np.eye(4)) + assert x.mesh is None + assert np.allclose(x.rotation, [0,0,0,1]) + assert np.allclose(x.scale, np.ones(3)) + assert np.allclose(x.translation, np.zeros(3)) + assert x.weights is None + assert x.light is None + + x.name = 'node' + + # Test node light/camera/mesh tests + c = PerspectiveCamera(yfov=2.0) + m = Mesh([]) + d = DirectionalLight() + x.camera = c + assert x.camera == c + with pytest.raises(TypeError): + x.camera = m + x.camera = d + x.camera = None + x.mesh = m + assert x.mesh == m + with pytest.raises(TypeError): + x.mesh = c + x.mesh = d + x.light = d + assert x.light == d + with pytest.raises(TypeError): + x.light = m + x.light = c + + # Test transformations getters/setters/etc... + # Set up test values + x = np.array([1.0, 0.0, 0.0]) + y = np.array([0.0, 1.0, 0.0]) + t = np.array([1.0, 2.0, 3.0]) + s = np.array([0.5, 2.0, 1.0]) + + Mx = transformations.rotation_matrix(np.pi / 2.0, x) + qx = np.roll(transformations.quaternion_about_axis(np.pi / 2.0, x), -1) + Mxt = Mx.copy() + Mxt[:3,3] = t + S = np.eye(4) + S[:3,:3] = np.diag(s) + Mxts = Mxt.dot(S) + + My = transformations.rotation_matrix(np.pi / 2.0, y) + qy = np.roll(transformations.quaternion_about_axis(np.pi / 2.0, y), -1) + Myt = My.copy() + Myt[:3,3] = t + + x = Node(matrix=Mx) + assert np.allclose(x.matrix, Mx) + assert np.allclose(x.rotation, qx) + assert np.allclose(x.translation, np.zeros(3)) + assert np.allclose(x.scale, np.ones(3)) + + x.matrix = My + assert np.allclose(x.matrix, My) + assert np.allclose(x.rotation, qy) + assert np.allclose(x.translation, np.zeros(3)) + assert np.allclose(x.scale, np.ones(3)) + x.translation = t + assert np.allclose(x.matrix, Myt) + assert np.allclose(x.rotation, qy) + x.rotation = qx + assert np.allclose(x.matrix, Mxt) + x.scale = s + assert np.allclose(x.matrix, Mxts) + + x = Node(matrix=Mxt) + assert np.allclose(x.matrix, Mxt) + assert np.allclose(x.rotation, qx) + assert np.allclose(x.translation, t) + assert np.allclose(x.scale, np.ones(3)) + + x = Node(matrix=Mxts) + assert np.allclose(x.matrix, Mxts) + assert np.allclose(x.rotation, qx) + assert np.allclose(x.translation, t) + assert np.allclose(x.scale, s) + + # Individual element getters + x.scale[0] = 0 + assert np.allclose(x.scale[0], 0) + + x.translation[0] = 0 + assert np.allclose(x.translation[0], 0) + + x.matrix = np.eye(4) + x.matrix[0,0] = 500 + assert x.matrix[0,0] == 1.0 + + # Failures + with pytest.raises(ValueError): + x.matrix = 5 * np.eye(4) + with pytest.raises(ValueError): + x.matrix = np.eye(5) + with pytest.raises(ValueError): + x.matrix = np.eye(4).dot([5,1,1,1]) + with pytest.raises(ValueError): + x.rotation = np.array([1,2]) + with pytest.raises(ValueError): + x.rotation = np.array([1,2,3]) + with pytest.raises(ValueError): + x.rotation = np.array([1,2,3,4]) + with pytest.raises(ValueError): + x.translation = np.array([1,2,3,4]) + with pytest.raises(ValueError): + x.scale = np.array([1,2,3,4]) diff --git a/pyrender/tests/unit/test_offscreen.py b/pyrender/tests/unit/test_offscreen.py new file mode 100644 index 0000000000000000000000000000000000000000..88983b0ff4e2ab6f5ef252c51f2ac669c3a0e0ca --- /dev/null +++ b/pyrender/tests/unit/test_offscreen.py @@ -0,0 +1,92 @@ +import numpy as np +import trimesh + +from pyrender import 
(OffscreenRenderer, PerspectiveCamera, DirectionalLight, + SpotLight, Mesh, Node, Scene) + + +def test_offscreen_renderer(tmpdir): + + # Fuze trimesh + fuze_trimesh = trimesh.load('examples/models/fuze.obj') + fuze_mesh = Mesh.from_trimesh(fuze_trimesh) + + # Drill trimesh + drill_trimesh = trimesh.load('examples/models/drill.obj') + drill_mesh = Mesh.from_trimesh(drill_trimesh) + drill_pose = np.eye(4) + drill_pose[0,3] = 0.1 + drill_pose[2,3] = -np.min(drill_trimesh.vertices[:,2]) + + # Wood trimesh + wood_trimesh = trimesh.load('examples/models/wood.obj') + wood_mesh = Mesh.from_trimesh(wood_trimesh) + + # Water bottle trimesh + bottle_gltf = trimesh.load('examples/models/WaterBottle.glb') + bottle_trimesh = bottle_gltf.geometry[list(bottle_gltf.geometry.keys())[0]] + bottle_mesh = Mesh.from_trimesh(bottle_trimesh) + bottle_pose = np.array([ + [1.0, 0.0, 0.0, 0.1], + [0.0, 0.0, -1.0, -0.16], + [0.0, 1.0, 0.0, 0.13], + [0.0, 0.0, 0.0, 1.0], + ]) + + boxv_trimesh = trimesh.creation.box(extents=0.1 * np.ones(3)) + boxv_vertex_colors = np.random.uniform(size=(boxv_trimesh.vertices.shape)) + boxv_trimesh.visual.vertex_colors = boxv_vertex_colors + boxv_mesh = Mesh.from_trimesh(boxv_trimesh, smooth=False) + boxf_trimesh = trimesh.creation.box(extents=0.1 * np.ones(3)) + boxf_face_colors = np.random.uniform(size=boxf_trimesh.faces.shape) + boxf_trimesh.visual.face_colors = boxf_face_colors + # Instanced + poses = np.tile(np.eye(4), (2,1,1)) + poses[0,:3,3] = np.array([-0.1, -0.10, 0.05]) + poses[1,:3,3] = np.array([-0.15, -0.10, 0.05]) + boxf_mesh = Mesh.from_trimesh(boxf_trimesh, poses=poses, smooth=False) + + points = trimesh.creation.icosphere(radius=0.05).vertices + point_colors = np.random.uniform(size=points.shape) + points_mesh = Mesh.from_points(points, colors=point_colors) + + direc_l = DirectionalLight(color=np.ones(3), intensity=1.0) + spot_l = SpotLight(color=np.ones(3), intensity=10.0, + innerConeAngle=np.pi / 16, outerConeAngle=np.pi / 6) + + cam = PerspectiveCamera(yfov=(np.pi / 3.0)) + cam_pose = np.array([ + [0.0, -np.sqrt(2) / 2, np.sqrt(2) / 2, 0.5], + [1.0, 0.0, 0.0, 0.0], + [0.0, np.sqrt(2) / 2, np.sqrt(2) / 2, 0.4], + [0.0, 0.0, 0.0, 1.0] + ]) + + scene = Scene(ambient_light=np.array([0.02, 0.02, 0.02])) + + fuze_node = Node(mesh=fuze_mesh, translation=np.array([ + 0.1, 0.15, -np.min(fuze_trimesh.vertices[:,2]) + ])) + scene.add_node(fuze_node) + boxv_node = Node(mesh=boxv_mesh, translation=np.array([-0.1, 0.10, 0.05])) + scene.add_node(boxv_node) + boxf_node = Node(mesh=boxf_mesh) + scene.add_node(boxf_node) + + _ = scene.add(drill_mesh, pose=drill_pose) + _ = scene.add(bottle_mesh, pose=bottle_pose) + _ = scene.add(wood_mesh) + _ = scene.add(direc_l, pose=cam_pose) + _ = scene.add(spot_l, pose=cam_pose) + _ = scene.add(points_mesh) + + _ = scene.add(cam, pose=cam_pose) + + r = OffscreenRenderer(viewport_width=640, viewport_height=480) + color, depth = r.render(scene) + + assert color.shape == (480, 640, 3) + assert depth.shape == (480, 640) + assert np.max(depth.data) > 0.05 + assert np.count_nonzero(depth.data) > (0.2 * depth.size) + r.delete() diff --git a/pyrender/tests/unit/test_scenes.py b/pyrender/tests/unit/test_scenes.py new file mode 100644 index 0000000000000000000000000000000000000000..d85dd714cb5d842ea12dee4140adfd7db55c9c01 --- /dev/null +++ b/pyrender/tests/unit/test_scenes.py @@ -0,0 +1,235 @@ +import numpy as np +import pytest +import trimesh + +from pyrender import (Mesh, PerspectiveCamera, DirectionalLight, + SpotLight, PointLight, Scene, Node, 
+from pyrender import (Mesh, PerspectiveCamera, DirectionalLight,
+                      SpotLight, PointLight, Scene, Node,
+                      OrthographicCamera)
+
+
+def test_scenes():
+
+    # Basics
+    s = Scene()
+    assert np.allclose(s.bg_color, np.ones(4))
+    assert np.allclose(s.ambient_light, np.zeros(3))
+    assert len(s.nodes) == 0
+    assert s.name is None
+    s.name = 'asdf'
+    s.bg_color = None
+    s.ambient_light = None
+    assert np.allclose(s.bg_color, np.ones(4))
+    assert np.allclose(s.ambient_light, np.zeros(3))
+
+    assert s.nodes == set()
+    assert s.cameras == set()
+    assert s.lights == set()
+    assert s.point_lights == set()
+    assert s.spot_lights == set()
+    assert s.directional_lights == set()
+    assert s.meshes == set()
+    assert s.camera_nodes == set()
+    assert s.light_nodes == set()
+    assert s.point_light_nodes == set()
+    assert s.spot_light_nodes == set()
+    assert s.directional_light_nodes == set()
+    assert s.mesh_nodes == set()
+    assert s.main_camera_node is None
+    assert np.all(s.bounds == 0)
+    assert np.all(s.centroid == 0)
+    assert np.all(s.extents == 0)
+    assert np.all(s.scale == 0)
+
+    # From trimesh scene
+    tms = trimesh.load('tests/data/WaterBottle.glb')
+    s = Scene.from_trimesh_scene(tms)
+    assert len(s.meshes) == 1
+    assert len(s.mesh_nodes) == 1
+
+    # Test bg color formatting
+    s = Scene(bg_color=[0, 1.0, 0])
+    assert np.allclose(s.bg_color, np.array([0.0, 1.0, 0.0, 1.0]))
+
+    # Test constructor for nodes
+    n1 = Node()
+    n2 = Node()
+    n3 = Node()
+    nodes = [n1, n2, n3]
+    s = Scene(nodes=nodes)
+    n1.children.append(n2)
+    s = Scene(nodes=nodes)
+    n3.children.append(n2)
+    with pytest.raises(ValueError):
+        s = Scene(nodes=nodes)
+    n3.children = []
+    n2.children.append(n3)
+    n3.children.append(n2)
+    with pytest.raises(ValueError):
+        s = Scene(nodes=nodes)
+
+    # Test node accessors
+    n1 = Node()
+    n2 = Node()
+    n3 = Node()
+    nodes = [n1, n2]
+    s = Scene(nodes=nodes)
+    assert s.has_node(n1)
+    assert s.has_node(n2)
+    assert not s.has_node(n3)
+
+    # Test node poses
+    for n in nodes:
+        assert np.allclose(s.get_pose(n), np.eye(4))
+    with pytest.raises(ValueError):
+        s.get_pose(n3)
+    with pytest.raises(ValueError):
+        s.set_pose(n3, np.eye(4))
+    tf = np.eye(4)
+    tf[:3,3] = np.ones(3)
+    s.set_pose(n1, tf)
+    assert np.allclose(s.get_pose(n1), tf)
+    assert np.allclose(s.get_pose(n2), np.eye(4))
+
+    nodes = [n1, n2, n3]
+    tf2 = np.eye(4)
+    tf2[:3,:3] = np.diag([-1,-1,1])
+    n1.children.append(n2)
+    n1.matrix = tf
+    n2.matrix = tf2
+    s = Scene(nodes=nodes)
+    assert np.allclose(s.get_pose(n1), tf)
+    assert np.allclose(s.get_pose(n2), tf.dot(tf2))
+    assert np.allclose(s.get_pose(n3), np.eye(4))
+
+    n1 = Node()
+    n2 = Node()
+    n3 = Node()
+    n1.children.append(n2)
+    s = Scene()
+    s.add_node(n1)
+    with pytest.raises(ValueError):
+        s.add_node(n2)
+    s.set_pose(n1, tf)
+    assert np.allclose(s.get_pose(n1), tf)
+    assert np.allclose(s.get_pose(n2), tf)
+    s.set_pose(n2, tf2)
+    assert np.allclose(s.get_pose(n2), tf.dot(tf2))
+
+    # Test node removal
+    n1 = Node()
+    n2 = Node()
+    n3 = Node()
+    n1.children.append(n2)
+    n2.children.append(n3)
+    s = Scene(nodes=[n1, n2, n3])
+    s.remove_node(n2)
+    assert len(s.nodes) == 1
+    assert n1 in s.nodes
+    assert len(n1.children) == 0
+    assert len(n2.children) == 1
+    s.add_node(n2, parent_node=n1)
+    assert len(n1.children) == 1
+    n1.matrix = tf
+    n3.matrix = tf2
+    assert np.allclose(s.get_pose(n3), tf.dot(tf2))
+
+    # Now test ADD function
+    s = Scene()
+    m = Mesh([], name='m')
+    cp = PerspectiveCamera(yfov=2.0)
+    co = OrthographicCamera(xmag=1.0, ymag=1.0)
+    dl = DirectionalLight()
+    pl = PointLight()
+    sl = SpotLight()
+
+    n1 = s.add(m, name='mn')
+    assert n1.mesh == m
+    assert len(s.nodes) == 1
+    assert len(s.mesh_nodes) == 1
+    assert n1 in s.mesh_nodes
+    assert len(s.meshes) == 1
+    assert m in s.meshes
+    assert len(s.get_nodes(node=n2)) == 0
+    n2 = s.add(m, pose=tf)
+    assert len(s.nodes) == len(s.mesh_nodes) == 2
+    assert len(s.meshes) == 1
+    assert len(s.get_nodes(node=n1)) == 1
+    assert len(s.get_nodes(node=n1, name='mn')) == 1
+    assert len(s.get_nodes(name='mn')) == 1
+    assert len(s.get_nodes(obj=m)) == 2
+    assert len(s.get_nodes(obj=m, obj_name='m')) == 2
+    assert len(s.get_nodes(obj=co)) == 0
+    nsl = s.add(sl, name='sln')
+    npl = s.add(pl, parent_name='sln')
+    assert nsl.children[0] == npl
+    ndl = s.add(dl, parent_node=npl)
+    assert npl.children[0] == ndl
+    nco = s.add(co)
+    ncp = s.add(cp)
+
+    assert len(s.light_nodes) == len(s.lights) == 3
+    assert len(s.point_light_nodes) == len(s.point_lights) == 1
+    assert npl in s.point_light_nodes
+    assert len(s.spot_light_nodes) == len(s.spot_lights) == 1
+    assert nsl in s.spot_light_nodes
+    assert len(s.directional_light_nodes) == len(s.directional_lights) == 1
+    assert ndl in s.directional_light_nodes
+    assert len(s.cameras) == len(s.camera_nodes) == 2
+    assert s.main_camera_node == nco
+    s.main_camera_node = ncp
+    s.remove_node(ncp)
+    assert len(s.cameras) == len(s.camera_nodes) == 1
+    assert s.main_camera_node == nco
+    s.remove_node(n2)
+    assert len(s.meshes) == 1
+    s.remove_node(n1)
+    assert len(s.meshes) == 0
+    s.remove_node(nsl)
+    assert len(s.lights) == 0
+    s.remove_node(nco)
+    assert s.main_camera_node is None
+
+    s.add_node(n1)
+    s.clear()
+    assert len(s.nodes) == 0
+
+    # Trigger final errors
+    with pytest.raises(ValueError):
+        s.main_camera_node = None
+    with pytest.raises(ValueError):
+        s.main_camera_node = ncp
+    with pytest.raises(ValueError):
+        s.add(m, parent_node=n1)
+    with pytest.raises(ValueError):
+        s.add(m, name='asdf')
+        s.add(m, name='asdf')
+        s.add(m, parent_name='asdf')
+    with pytest.raises(ValueError):
+        s.add(m, parent_name='asfd')
+    with pytest.raises(TypeError):
+        s.add(None)
+
+    s.clear()
+    # Test bounds
+    m1 = Mesh.from_trimesh(trimesh.creation.box())
+    m2 = Mesh.from_trimesh(trimesh.creation.box())
+    m3 = Mesh.from_trimesh(trimesh.creation.box())
+    n1 = Node(mesh=m1)
+    n2 = Node(mesh=m2, translation=[1.0, 0.0, 0.0])
+    n3 = Node(mesh=m3, translation=[0.5, 0.0, 1.0])
+    s.add_node(n1)
+    s.add_node(n2)
+    s.add_node(n3)
+    assert np.allclose(s.bounds, [[-0.5, -0.5, -0.5], [1.5, 0.5, 1.5]])
+    s.clear()
+    s.add_node(n1)
+    s.add_node(n2, parent_node=n1)
+    s.add_node(n3, parent_node=n2)
+    assert np.allclose(s.bounds, [[-0.5, -0.5, -0.5], [2.0, 0.5, 1.5]])
+    tf = np.eye(4)
+    tf[:3,3] = np.ones(3)
+    s.set_pose(n3, tf)
+    assert np.allclose(s.bounds, [[-0.5, -0.5, -0.5], [2.5, 1.5, 1.5]])
+    s.remove_node(n2)
+    assert np.allclose(s.bounds, [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]])
+    s.clear()
+    assert np.allclose(s.bounds, 0.0)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9ba1fcbfec8acbde3f407ba282affacd91b72b95
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,26 @@
+git+https://github.com/openai/CLIP.git
+numpy==1.23.3
+matplotlib==3.4.3
+matplotlib-inline==0.1.2
+transformers
+h5py
+smplx
+shapely
+freetype-py
+imageio
+networkx
+Pillow
+pyglet==1.4.0a1
+PyOpenGL
+PyOpenGL_accelerate
+six
+trimesh
+sphinx
+sphinx_rtd_theme
+sphinx-automodapi
+mapbox_earcut
+chumpy
+gdown
+MoviePy
+ffmpeg
+gradio==3.12
\ No newline at end of file
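
Editor's note: the test modules above double as a compact reference for the pyrender API this Space relies on. The following is a minimal sketch (an illustration, not part of the diff; the box size, camera offset, and light intensity are arbitrary choices) condensing the scene setup and offscreen render that test_offscreen_renderer exercises:

import numpy as np
import trimesh
from pyrender import (DirectionalLight, Mesh, OffscreenRenderer,
                      PerspectiveCamera, Scene)

# Build a one-mesh scene with the same calls the tests use.
scene = Scene(ambient_light=np.array([0.02, 0.02, 0.02]))
box = Mesh.from_trimesh(trimesh.creation.box(extents=0.2 * np.ones(3)),
                        smooth=False)
scene.add(box)

# Camera and light share a pose 0.6 units back along +z (arbitrary value);
# pyrender cameras look down their local -z axis, so the box is in view.
cam_pose = np.eye(4)
cam_pose[2, 3] = 0.6
scene.add(PerspectiveCamera(yfov=np.pi / 3.0), pose=cam_pose)
scene.add(DirectionalLight(color=np.ones(3), intensity=3.0), pose=cam_pose)

# Render offscreen and check the output shapes, as the test does.
r = OffscreenRenderer(viewport_width=640, viewport_height=480)
color, depth = r.render(scene)
assert color.shape == (480, 640, 3) and depth.shape == (480, 640)
r.delete()

On a headless machine such as a Space container, pyrender typically also needs a non-windowed OpenGL backend (e.g. setting PYOPENGL_PLATFORM=egl or PYOPENGL_PLATFORM=osmesa) before OffscreenRenderer will initialize.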