diff --git a/checkpoints/charset.txt b/checkpoints/charset.txt new file mode 100644 index 0000000000000000000000000000000000000000..525d4fec557fdabb296f89137efe5328d1353756 --- /dev/null +++ b/checkpoints/charset.txt @@ -0,0 +1 @@ +[" ", "!", "\"", "#", "$", "%", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ":", ";", "<", "=", ">", "?", "@", "A", "B", "C", "D", "E", "F", "FI", "G", "H", "I", "I\u0307", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "SS", "T", "U", "V", "W", "X", "Y", "Z", "[", "\\", "]", "^", "_", "`", "a", "b", "c", "d", "e", "f", "fi", "g", "h", "i", "i\u0307", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "ss", "t", "u", "v", "w", "x", "y", "z", "{", "|", "}", "~", "\u00b2", "\u00b3", "\u00b5", "\u00b9", "\u00ba", "\u00c0", "\u00c1", "\u00c2", "\u00c3", "\u00c4", "\u00c5", "\u00c6", "\u00c7", "\u00c8", "\u00c9", "\u00ca", "\u00cb", "\u00cc", "\u00cd", "\u00ce", "\u00cf", "\u00d0", "\u00d1", "\u00d2", "\u00d3", "\u00d4", "\u00d5", "\u00d6", "\u00d8", "\u00d9", "\u00da", "\u00db", "\u00dc", "\u00dd", "\u00de", "\u00df", "\u00e0", "\u00e1", "\u00e2", "\u00e3", "\u00e4", "\u00e5", "\u00e6", "\u00e7", "\u00e8", "\u00e9", "\u00ea", "\u00eb", "\u00ec", "\u00ed", "\u00ee", "\u00ef", "\u00f0", "\u00f1", "\u00f2", "\u00f3", "\u00f4", "\u00f5", "\u00f6", "\u00f8", "\u00f9", "\u00fa", "\u00fb", "\u00fc", "\u00fd", "\u00fe", "\u00ff", "\u0100", "\u0101", "\u0102", "\u0103", "\u0104", "\u0105", "\u0106", "\u0107", "\u010c", "\u010d", "\u010e", "\u010f", "\u0110", "\u0111", "\u0112", "\u0113", "\u0116", "\u0117", "\u0118", "\u0119", "\u011a", "\u011b", "\u011e", "\u011f", "\u0120", "\u0121", "\u0126", "\u0127", "\u0128", "\u0129", "\u012a", "\u012b", "\u0130", "\u0131", "\u0136", "\u0137", "\u013d", "\u013e", "\u0141", "\u0142", "\u0143", "\u0144", "\u0145", "\u0146", "\u0147", "\u0148", "\u014a", "\u014b", "\u014c", "\u014d", "\u014e", "\u014f", "\u0150", "\u0151", "\u0152", "\u0153", "\u0158", "\u0159", "\u015a", "\u015b", "\u015e", "\u015f", "\u0160", "\u0161", "\u0162", "\u0163", "\u0164", "\u0165", "\u0168", "\u0169", "\u016a", "\u016b", "\u016c", "\u016d", "\u016e", "\u016f", "\u0172", "\u0173", "\u0174", "\u0175", "\u0176", "\u0177", "\u0178", "\u0179", "\u017a", "\u017b", "\u017c", "\u017d", "\u017e", "\u0181", "\u0186", "\u0189", "\u018a", "\u018f", "\u0190", "\u0191", "\u0192", "\u0194", "\u0197", "\u019c", "\u019d", "\u019f", "\u01a0", "\u01a1", "\u01a6", "\u01a9", "\u01ae", "\u01af", "\u01b0", "\u01b1", "\u01b2", "\u01b7", "\u01c2", "\u01cd", "\u01ce", "\u01cf", "\u01d0", "\u01d1", "\u01d2", "\u01d3", "\u01d4", "\u01ea", "\u01eb", "\u0218", "\u0219", "\u021a", "\u021b", "\u0245", "\u0250", "\u0251", "\u0252", "\u0253", "\u0254", "\u0255", "\u0256", "\u0257", "\u0259", "\u025b", "\u025f", "\u0261", "\u0262", "\u0263", "\u0266", "\u0267", "\u0268", "\u026a", "\u026c", "\u026f", "\u0272", "\u0274", "\u0275", "\u0278", "\u027b", "\u027e", "\u0280", "\u0281", "\u0282", "\u0283", "\u0287", "\u0288", "\u028a", "\u028b", "\u028c", "\u028d", "\u028e", "\u0292", "\u0294", "\u0295", "\u0298", "\u029d", "\u029f", "\u02b0", "\u02b2", "\u02b7", "\u02bb", "\u02bc", "\u02be", "\u02bf", "\u02c0", "\u02c1", "\u02c8", "\u02cc", "\u02d0", "\u02e0", "\u02e4", "\u0386", "\u0388", "\u038a", "\u038c", "\u038e", "\u038f", "\u0391", "\u0391\u0342", "\u0392", "\u0393", "\u0394", "\u0395", "\u0396", "\u0397", "\u0397\u0342", "\u0398", "\u0399", "\u0399\u0342", "\u039a", "\u039b", "\u039c", "\u039d", "\u039e", "\u039f", "\u03a0", "\u03a1", 
"\u03a3", "\u03a4", "\u03a5", "\u03a5\u0313", "\u03a5\u0342", "\u03a6", "\u03a7", "\u03a8", "\u03a9", "\u03a9\u0342", "\u03a9\u0342\u0399", "\u03ac", "\u03ad", "\u03af", "\u03b1", "\u03b1\u0342", "\u03b2", "\u03b3", "\u03b4", "\u03b5", "\u03b6", "\u03b7", "\u03b7\u0342", "\u03b8", "\u03b9", "\u03b9\u0342", "\u03ba", "\u03bb", "\u03bc", "\u03bd", "\u03be", "\u03bf", "\u03c0", "\u03c1", "\u03c2", "\u03c3", "\u03c4", "\u03c5", "\u03c5\u0313", "\u03c5\u0342", "\u03c6", "\u03c7", "\u03c8", "\u03c9", "\u03c9\u0342", "\u03c9\u0342\u03b9", "\u03cc", "\u03cd", "\u03ce", "\u03d5", "\u0401", "\u0406", "\u0408", "\u0410", "\u0411", "\u0412", "\u0413", "\u0414", "\u0415", "\u0416", "\u0417", "\u0418", "\u0419", "\u041a", "\u041b", "\u041c", "\u041d", "\u041e", "\u041f", "\u0420", "\u0421", "\u0422", "\u0423", "\u0425", "\u0426", "\u0427", "\u0428", "\u042a", "\u042b", "\u042c", "\u042d", "\u042e", "\u042f", "\u0430", "\u0431", "\u0432", "\u0433", "\u0434", "\u0435", "\u0436", "\u0437", "\u0438", "\u0439", "\u043a", "\u043b", "\u043c", "\u043d", "\u043e", "\u043f", "\u0440", "\u0441", "\u0442", "\u0443", "\u0445", "\u0446", "\u0447", "\u0448", "\u044a", "\u044b", "\u044c", "\u044d", "\u044e", "\u044f", "\u0451", "\u0456", "\u0458", "\u05b5", "\u05b6", "\u05bc", "\u05d0", "\u05d1", "\u05d2", "\u05d3", "\u05d5", "\u05d7", "\u05d9", "\u05dc", "\u05dd", "\u05de", "\u05e0", "\u05e1", "\u05e2", "\u05e6", "\u05e8", "\u05e9", "\u05ea", "\u0621", "\u0623", "\u0625", "\u0627", "\u0628", "\u0629", "\u062a", "\u062c", "\u062d", "\u062e", "\u062f", "\u0631", "\u0632", "\u0633", "\u0634", "\u0635", "\u0637", "\u0639", "\u063a", "\u0641", "\u0642", "\u0643", "\u0644", "\u0645", "\u0646", "\u0647", "\u0648", "\u064a", "\u06cc", "\u0902", "\u0905", "\u0906", "\u0909", "\u0915", "\u0917", "\u091f", "\u0921", "\u0924", "\u0926", "\u0928", "\u092a", "\u092c", "\u092d", "\u092e", "\u092f", "\u0930", "\u0932", "\u0936", "\u0937", "\u0938", "\u0939", "\u093e", "\u093f", "\u0940", "\u0947", "\u094b", "\u0995", "\u09a4", "\u09b2", "\u09be", "\u09bf", "\u0b95", "\u0ba9", "\u0bb3", "\u0e02", "\u0e07", "\u0e08", "\u0e0a", "\u0e10", "\u0e15", "\u0e17", "\u0e19", "\u0e1b", "\u0e1e", "\u0e23", "\u0e27", "\u0e30", "\u0e31", "\u0e32", "\u0e40", "\u0e41", "\u16c3", "\u16cb", "\u16df", "\u1e0c", "\u1e0d", "\u1e24", "\u1e25", "\u1e36", "\u1e37", "\u1e3a", "\u1e3b", "\u1e42", "\u1e43", "\u1e44", "\u1e45", "\u1e46", "\u1e47", "\u1e48", "\u1e49", "\u1e5a", "\u1e5b", "\u1e5e", "\u1e5f", "\u1e62", "\u1e63", "\u1e6c", "\u1e6d", "\u1e6e", "\u1e6f", "\u1ea0", "\u1ea1", "\u1ea2", "\u1ea3", "\u1ea4", "\u1ea5", "\u1ea6", "\u1ea7", "\u1ea8", "\u1ea9", "\u1eaa", "\u1eab", "\u1eac", "\u1ead", "\u1eae", "\u1eaf", "\u1eb4", "\u1eb5", "\u1eb6", "\u1eb7", "\u1eb8", "\u1eb9", "\u1ebe", "\u1ebf", "\u1ec2", "\u1ec3", "\u1ec4", "\u1ec5", "\u1ec6", "\u1ec7", "\u1eca", "\u1ecb", "\u1ecc", "\u1ecd", "\u1ece", "\u1ecf", "\u1ed0", "\u1ed1", "\u1ed2", "\u1ed3", "\u1ed4", "\u1ed5", "\u1ed6", "\u1ed7", "\u1ed8", "\u1ed9", "\u1eda", "\u1edb", "\u1edc", "\u1edd", "\u1ede", "\u1edf", "\u1ee2", "\u1ee3", "\u1ee4", "\u1ee5", "\u1ee6", "\u1ee7", "\u1ee8", "\u1ee9", "\u1eea", "\u1eeb", "\u1eec", "\u1eed", "\u1eee", "\u1eef", "\u1ef0", "\u1ef1", "\u1ef2", "\u1ef3", "\u1ef4", "\u1ef5", "\u1ef8", "\u1ef9", "\u1f00", "\u1f04", "\u1f08", "\u1f0c", "\u1f10", "\u1f15", "\u1f18", "\u1f1d", "\u1f20", "\u1f21", "\u1f28", "\u1f29", "\u1f30", "\u1f31", "\u1f38", "\u1f39", "\u1f41", "\u1f44", "\u1f49", "\u1f4c", "\u1f50", "\u1f51", "\u1f59", "\u1f61", "\u1f69", "\u1f70", "\u1f72", 
"\u1f74", "\u1f76", "\u1f78", "\u1f7a", "\u1f7c", "\u1fb6", "\u1fba", "\u1fc6", "\u1fc8", "\u1fca", "\u1fd6", "\u1fda", "\u1fe6", "\u1fea", "\u1ff6", "\u1ff7", "\u1ff8", "\u1ffa", "\u2081", "\u2082", "\u2083", "\u2113", "\u2460", "\u2461", "\u2463", "\u2c6d", "\u2c6f", "\u2c70", "\u3044", "\u3045", "\u3046", "\u304a", "\u304b", "\u304d", "\u304f", "\u3050", "\u3053", "\u3057", "\u3059", "\u305b", "\u305f", "\u3064", "\u3069", "\u306e", "\u3070", "\u307d", "\u3088", "\u3089", "\u3093", "\u30a1", "\u30a2", "\u30a3", "\u30a4", "\u30a6", "\u30a7", "\u30a8", "\u30a9", "\u30aa", "\u30ab", "\u30ac", "\u30af", "\u30b0", "\u30b3", "\u30b4", "\u30b5", "\u30b6", "\u30b7", "\u30b8", "\u30b9", "\u30ba", "\u30bb", "\u30bc", "\u30bd", "\u30bf", "\u30c1", "\u30c3", "\u30c4", "\u30c6", "\u30c7", "\u30c8", "\u30c9", "\u30ca", "\u30cb", "\u30ce", "\u30cf", "\u30d0", "\u30d1", "\u30d2", "\u30d3", "\u30d5", "\u30d6", "\u30d7", "\u30d9", "\u30da", "\u30dc", "\u30de", "\u30df", "\u30e1", "\u30e3", "\u30e4", "\u30e5", "\u30e6", "\u30e9", "\u30ea", "\u30eb", "\u30ec", "\u30ed", "\u30ef", "\u30f3", "\u30f4", "\u30fc", "\ua7aa", "\ua7ac", "\ua7ad", "\ua7ae", "\ua7b1", "\ua7b2", "\ua7c5", "\uac70", "\ub9c8", "\ub9c9", "\ub9d0", "\uc0ac", "\uc778", "\uc804", "\uc9c0", "\uc9d3", "\ud22c", "\ufb01"] \ No newline at end of file diff --git a/checkpoints/detector.pth b/checkpoints/detector.pth new file mode 100644 index 0000000000000000000000000000000000000000..ddb8791ef3af4ab4828d38a04ac1a92ab177bc63 --- /dev/null +++ b/checkpoints/detector.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7d50c74b2dba9acb8dd76d2fbcf75e6eeae0cb3e9688edf42c91aa5550ade1 +size 181677320 diff --git a/checkpoints/recognizer.pth b/checkpoints/recognizer.pth new file mode 100644 index 0000000000000000000000000000000000000000..468d90821a2fe461ba48e015d71910d80480c232 --- /dev/null +++ b/checkpoints/recognizer.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db307d9b0dcb6cd15ab6c71e302fd62ca90ce077c3013c9f63a4ba0dbfdf3f50 +size 19823477 diff --git a/checkpoints/relational.pth b/checkpoints/relational.pth new file mode 100644 index 0000000000000000000000000000000000000000..71ec4aa18bb1bb199172b08552ed6a4a38309830 --- /dev/null +++ b/checkpoints/relational.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1db5a62853269aabd8a040eeb05038a871032e8275def77653631657cb8ca4a +size 9048309 diff --git a/example.py b/example.py new file mode 100644 index 0000000000000000000000000000000000000000..7f6195fd7a9156992e73c7535924671c77989aa2 --- /dev/null +++ b/example.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import argparse + +from nemo_retriever_ocr.inference.pipeline import NemoRetrieverOCR + + +def main(image_path, merge_level, no_visualize, model_dir): + ocr_pipeline = NemoRetrieverOCR() + + predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize) + + print(f"Found {len(predictions)} text regions.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run OCR inference and annotate image.") + parser.add_argument("image_path", type=str, help="Path to the input image.") + parser.add_argument( + "--merge-level", + type=str, + choices=["word", "sentence", "paragraph"], + default="paragraph", + help="Merge level for OCR output (word, sentence, paragraph).", + ) + parser.add_argument("--no-visualize", action="store_true", help="Do not save the annotated image.") + parser.add_argument( + "--model-dir", + type=str, + help="Path to the model checkpoints.", + default="./checkpoints", + ) + args = parser.parse_args() + + main( + args.image_path, + merge_level=args.merge_level, + no_visualize=args.no_visualize, + model_dir=args.model_dir, + ) diff --git a/nemo-retriever-ocr/cpp/.gitattributes b/nemo-retriever-ocr/cpp/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..065a67c6a2547fb676a54ab003a1eb552fd3f558 --- /dev/null +++ b/nemo-retriever-ocr/cpp/.gitattributes @@ -0,0 +1 @@ +load_png/wuffs-v0.3.c filter=lfs diff=lfs merge=lfs -text diff --git a/nemo-retriever-ocr/cpp/.gitignore b/nemo-retriever-ocr/cpp/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e0d60bc592acf67adb5a1ade34c29c776f520845 --- /dev/null +++ b/nemo-retriever-ocr/cpp/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +.vscode +build +*.egg-info +dist +.vs diff --git a/nemo-retriever-ocr/cpp/.gitmodules b/nemo-retriever-ocr/cpp/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..f646fb91ddf80d588f9b36f7c36624c415599ca6 --- /dev/null +++ b/nemo-retriever-ocr/cpp/.gitmodules @@ -0,0 +1,3 @@ +[submodule "trove"] + path = trove + url = https://github.com/bryancatanzaro/trove.git diff --git a/nemo-retriever-ocr/cpp/README.md b/nemo-retriever-ocr/cpp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f6af661f56eca3d3f43217a9542eda488ea6cb16 --- /dev/null +++ b/nemo-retriever-ocr/cpp/README.md @@ -0,0 +1,15 @@ +# Optimized Image Operations for PyTorch + +## Installation + +``` +python setup.py install +``` + +## Usage + +``` +# It's important that you do this first +import torch +from pytorch_image_ops import color_transform, spatial_transform +``` diff --git a/nemo-retriever-ocr/cpp/beam_decode/beam_decode.cpp b/nemo-retriever-ocr/cpp/beam_decode/beam_decode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13b8be9928dfa0cdcc4565a698d44c6c747d15a8 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/beam_decode.cpp @@ -0,0 +1,460 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#include "beam_decode.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "prefix.h" +#include "log_sum_exp.h" +#include "sbo_lm.h" + +using namespace std; + +template +using pred_seq_t = torch::TensorAccessor; + +struct PrefixScore +{ + float_t lProbBlank; + float_t lProbChar; + // float_t raw_lProbBlank; + // float_t raw_lProbChar; + mutable float_t _lProb; + + PrefixScore(float_t lProbBlank = NEG_INF /* log P(0) */, float_t lProbChar = NEG_INF /* log P(0) */) + : lProbBlank(lProbBlank), lProbChar(lProbChar), _lProb(NEG_INF) + // , raw_lProbBlank(lProbBlank), raw_lProbChar(lProbChar) + {} + + float_t get_lScore() const { + if (_lProb == NEG_INF) { + _lProb = log_sum_exp(lProbBlank, lProbChar); + } + return _lProb; + } + + // float_t get_raw_lScore() const { + // return log_sum_exp(raw_lProbBlank, raw_lProbChar); + // } +}; + +typedef std::unordered_map PrefixMap; +typedef std::pair BeamItem; +typedef std::vector Beam; + +/* + Allows us to get an estimate of the vision model confidence, irrespective of how the language + model guided the decoding. NOTE: This scoring could follow an entirely different path than + the returned decoded sequence. +*/ +template +scalar_t get_vision_confidence(const pred_seq_t &logProbs, scalar_t minProb) +{ + const int64_t T = logProbs.size(0); + const int64_t S = logProbs.size(1); + + scalar_t ret = 0; // log(1) + + for (size_t t = 0; t < T; ++t) { + float_t maxP = logProbs[t][0]; + int64_t maxC = 0; + for (int64_t c = 1; c < S; ++c) { + float_t p = logProbs[t][c]; + if (p > maxP) { + maxP = p; + maxC = c; + } + } + ret += maxP; + // Ignore everything past the sequence terminator + if (maxC == 1) { + break; + } + + if (ret < minProb) { + break; + } + } + + return ret; +} + + +template +pair, float_t> + ctc_beam_decode_impl(const pred_seq_t &probs, const int64_t beamSize, + const int64_t blank, scalar_t minProb, + const LanguageModel &langModel, scalar_t lmWeight) +{ + if (blank != 0) { + throw runtime_error("Currently, only ordinal 0 supported for the blank prediction"); + } + + const int64_t T = probs.size(0); + const int64_t S = probs.size(1); + + // NOTE: In log space, the following is true: + // 1. Adding two probabilities: log_sum_exp(l_p_a, l_p_b) + // 2. Multiplying two probabilities: l_p_a + l_p_b + // 3. log P(0) = -inf + // 4. log P(1) = 0 + + // Convert to log-space + if (minProb > 0) { + minProb = log(minProb); + } else { + minProb = NEG_INF; + } + + auto retScore = get_vision_confidence(probs, minProb); + + if (retScore < minProb) { + return { {}, NEG_INF }; + } + + PrefixAllocator prefixAlloc; + + Beam beam; + beam.emplace_back(prefixAlloc.GetPrefix(), PrefixScore{0, NEG_INF}); // Add a dummy first node + + Beam terminated; + + typedef tuple lm_cache_key_t; + unordered_map lmScoreCache; + + for (int64_t t = 0; t < T; ++t) { + PrefixMap nextBeam; + + // Add all of the completed paths to the next beam. 
+ // This allows us to accumulate new paths into these, + // but otherwise not process them + for (const BeamItem &prevNode : beam) { + if (prevNode.first->Token == 1) { + nextBeam.insert(prevNode); + } + } + + // Loop over vocab + for (int64_t s = 0; s < S; ++s) { + float_t lpEmit = probs[t][s]; + + if (lpEmit < minProb) { + continue; + } + + for (const BeamItem &prevNode : beam) { + Prefix *prevPrefix = prevNode.first; + const PrefixScore &prevScore = prevNode.second; + + // Ignore already completed paths + if (prevPrefix->Token == 1) { + continue; + } + + // Ignore impossible paths + if (prevScore.lProbBlank == NEG_INF && prevScore.lProbChar == NEG_INF) { + continue; + } + + // If we propose a blank the prefix doesn't change. + // Only the probability of ending in blank gets updated. + if (s == blank) { + PrefixScore &score = nextBeam[prevPrefix]; + score.lProbBlank = log_sum_exp(score.lProbBlank , prevScore.lProbBlank + lpEmit, prevScore.lProbChar + lpEmit); + // score.raw_lProbBlank = log_sum_exp(score.raw_lProbBlank, prevScore.raw_lProbBlank + lpEmit, prevScore.raw_lProbChar + lpEmit); + continue; + } + + // Extend the prefix by the new character s and add it to the beam. + // Only the probability of not ending in blank gets updated. + token_t prevToken = prevPrefix->Token; + + // NOTE: We always create a new prefix regardless of duplication because the PrefixScore + // is simultaneously tracking prefixes that do and don't end in a blank. And it's those + // that end in a blank that would cause the prefix to be extended. + auto extendPrefix = prefixAlloc.GetPrefix(s, prevPrefix); + + // Evaluate the language model, but use the cache if we've already considered this string before + auto lmCacheItem = make_tuple(prevPrefix, s); + auto lmCacheIter = lmScoreCache.find(lmCacheItem); + float_t lpLang = 0; + if (lmCacheIter == lmScoreCache.end()) { + lpLang = langModel.ScoreTransition(prevPrefix, s); + lpLang *= lmWeight; + lmCacheIter = lmScoreCache.emplace(lmCacheItem, lpLang).first; + } + lpLang = lmCacheIter->second; + + PrefixScore &extendScore = nextBeam[extendPrefix]; + // Remember, adding two log probabilities is equivalent to multiplying two probabilities + if (s != prevToken) { + extendScore.lProbChar = log_sum_exp(extendScore.lProbChar, prevScore.lProbBlank + lpEmit + lpLang, prevScore.lProbChar + lpEmit + lpLang); + // extendScore.raw_lProbChar = log_sum_exp(extendScore.raw_lProbChar, prevScore.raw_lProbBlank + lpEmit , prevScore.raw_lProbChar + lpEmit ); + } else { + // We don't include the previous probability of not ending in blank if s is repeated at the end. The CTC + // algorithm merges characters not separated by a blank. 
+ extendScore.lProbChar = log_sum_exp(extendScore.lProbChar , prevScore.lProbBlank + lpEmit + lpLang); + // extendScore.raw_lProbChar = log_sum_exp(extendScore.raw_lProbChar, prevScore.raw_lProbBlank + lpEmit ); + } + + // If the token is repeated, we also have to deal with the unchanged prefix since repeated characters are collapsed + if (s == prevToken) { + PrefixScore &collapseScore = nextBeam[prevPrefix]; + collapseScore.lProbChar = log_sum_exp(collapseScore.lProbChar , prevScore.lProbChar + lpEmit); + // collapseScore.raw_lProbChar = log_sum_exp(collapseScore.raw_lProbChar, prevScore.raw_lProbChar + lpEmit); + } + + } + } + + Beam vecNextBeam(begin(nextBeam), end(nextBeam)); + + if (vecNextBeam.size() > beamSize) { + partial_sort(begin(vecNextBeam), begin(vecNextBeam) + beamSize, end(vecNextBeam), + [] (const BeamItem &a, const BeamItem &b) { + return a.second.get_lScore() > b.second.get_lScore(); + } + ); + vecNextBeam.resize(beamSize); + } + + beam = move(vecNextBeam); + } + + // Find the best raw score + const BeamItem *bestItem = nullptr; + // for (const BeamItem &b : beam) { + // if (bestItem == nullptr or b.second.get_raw_lScore() > bestItem->second.get_raw_lScore()) { + // bestItem = &b; + // } + // } + if (! beam.empty()) { + bestItem = &beam[0]; + } + + if (bestItem != nullptr) { + auto retList = bestItem->first->ToList(); + + return { move(retList), retScore }; + } else { + return { {}, NEG_INF }; + } +} + +typedef std::pair RegBeamItem; + +bool operator<(const RegBeamItem &a, const RegBeamItem &b) { + return a.second > b.second; +} + +template +pair, float_t> + reg_beam_decode_impl(const pred_seq_t &logProbs, const int64_t beamSize, + scalar_t minProb, + const LanguageModel &langModel, scalar_t lmWeight) +{ + const int64_t T = logProbs.size(0); + const int64_t S = logProbs.size(1); + + // NOTE: In log space, the following is true: + // 1. Adding two probabilities: log_sum_exp(l_p_a, l_p_b) + // 2. Multiplying two probabilities: l_p_a + l_p_b + // 3. log P(0) = -inf + // 4. 
log P(1) = 0 + + // Convert to log-space + if (minProb > 0) { + minProb = log(minProb); + } else { + minProb = NEG_INF; + } + + auto retScore = get_vision_confidence(logProbs, minProb); + + if (retScore < minProb) { + return { {}, NEG_INF }; + } + + PrefixAllocator prefixAlloc; + + vector beam, nextBeam; + beam.emplace_back(prefixAlloc.GetPrefix(), 0); // log(1) = 0 + + for (int64_t t = 0; t < T && !beam.empty(); ++t) { + nextBeam.clear(); + + auto addToBeam = [&nextBeam, beamSize] (const RegBeamItem &rbi) { + nextBeam.push_back(rbi); + }; + + // Expand each path in the beam + for (const RegBeamItem &prevNode : beam) { + if (prevNode.first->Token == 1) { + // Move completed paths along without processing further + addToBeam(prevNode); + continue; + } + + Prefix *prevPrefix = prevNode.first; + float_t prevScore = prevNode.second; + + // Loop over vocab + for (int64_t s = 0; s < S; ++s) { + float_t lpEmit = logProbs[t][s]; + + if (lpEmit < minProb) { + // The probability dropped below threshold, so stop processing this path + continue; + } + + auto extendPrefix = prefixAlloc.GetPrefix(s, prevPrefix); + + float_t lpLang = langModel.ScoreTransition(prevPrefix, s); + + float_t lpNext = prevScore + lpLang + lpEmit; + + addToBeam({extendPrefix, lpNext}); + } + } + + if (nextBeam.size() > beamSize) { + // Find the top-k items, and then truncate the rest + partial_sort(begin(nextBeam), begin(nextBeam) + beamSize, end(nextBeam)); + nextBeam.resize(beamSize); + } + + std::swap(beam, nextBeam); + } + + if (!beam.empty()) { + // The highest probability element will always be in the back + RegBeamItem rbi{ nullptr, NEG_INF }; + for (auto &rb : beam) { + if (rbi.first == nullptr || rb.second > rbi.second) { + rbi = rb; + } + } + + auto retList = rbi.first->ToList(); + + return { move(retList), retScore }; + } else { + return { {}, NEG_INF }; + } +} + + + +template +void dp_beam_decode_impl(const torch::TensorAccessor &probsAccess, + torch::TensorAccessor retAccess, + torch::TensorAccessor confAccess, + int64_t beamSize, int64_t blank, + scalar_t minProb, + const LanguageModel *langModel, + scalar_t lmWeight, + bool combineDuplicates) +{ + const int64_t N = probsAccess.size(0); + + #pragma omp parallel for num_threads(8) + for (int64_t i = 0; i < N; ++i) { + vector seq; + float_t lConf; + if (combineDuplicates) { + tie(seq, lConf) = ctc_beam_decode_impl(probsAccess[i], beamSize, blank, + minProb, + *langModel, lmWeight); + } else { + tie(seq, lConf) = reg_beam_decode_impl(probsAccess[i], beamSize, + minProb, + *langModel, lmWeight); + } + + int64_t sz = min(seq.size(), retAccess.size(1)); + + for (int64_t k = 0; k < sz; ++k) { + retAccess[i][k] = seq[k]; + } + + confAccess[i] = exp(lConf); + } +} + +std::tuple + beam_decode(torch::Tensor probs, int64_t beamSize, int64_t blank, + float minProb, + const LanguageModel *langModel, + float lmWeight, + bool combineDuplicates) +{ + if (langModel == nullptr) { + langModel = &NullLanguageModel; + } + + auto tStart = chrono::high_resolution_clock::now(); + + probs = probs.contiguous(); + + bool collapse = false; + if (probs.dim() == 2) { + // N,T,C + probs = probs.unsqueeze(0); + collapse = true; + } + + probs = probs.log(); + + torch::Tensor ret = torch::ones({ probs.size(0), probs.size(1) }, torch::kInt64); + torch::Tensor conf = torch::zeros({ probs.size(0) }, probs.options()); + + auto retAccess = ret.accessor(); + + AT_DISPATCH_FLOATING_TYPES( + probs.scalar_type(), + "cpu_beam_decode", + ([&] { + dp_beam_decode_impl( + probs.accessor(), + retAccess, + 
conf.accessor(), + beamSize, blank, + static_cast(minProb), + langModel, + static_cast(lmWeight), + combineDuplicates + ); + }) + ); + + if (collapse) { + ret = ret.squeeze(0); + conf = conf[0]; + } + + auto tEnd = chrono::high_resolution_clock::now(); + + typedef chrono::duration tp_t; + tp_t totalElapsed = tEnd - tStart; + + cout << "Beam Decode " << probs.size(0) << " - " + << "Total: " << totalElapsed.count() << "ms" + << endl; + + return { ret, conf }; +} + +std::unique_ptr create_sbo_lm(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoffWeight) +{ + return make_unique(dataFilePath, move(tokenMapping), backoffWeight); +} diff --git a/nemo-retriever-ocr/cpp/beam_decode/beam_decode.h b/nemo-retriever-ocr/cpp/beam_decode/beam_decode.h new file mode 100644 index 0000000000000000000000000000000000000000..70e10273e205c3a2ad426e26037f321860c0ab7f --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/beam_decode.h @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "language_model.h" + +std::tuple + beam_decode(torch::Tensor probs, int64_t beamSize, int64_t blank, + float minProb, + const LanguageModel *langModel, + float lmWeight, + bool combineDuplicates); + +std::unique_ptr create_sbo_lm(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoffWeight); diff --git a/nemo-retriever-ocr/cpp/beam_decode/kn_lm.cpp b/nemo-retriever-ocr/cpp/beam_decode/kn_lm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a6b645800380770ab4c81d67c8980c5bfba34d1 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/kn_lm.cpp @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#include "kn_lm.h" + +using namespace std; + + +KN_LanguageModel::KN_LanguageModel(const string &dataFilePath, token_mapping_t tokenMapping, float_t knDelta) + : NGramLMBase(dataFilePath, move(tokenMapping)), m_knDelta(knDelta) +{ +} + +float KN_LanguageModel::ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const +{ + if (prefix.empty()) { + return ScoreUnigram(suffix); + } else { + return ScoreTransition(prefix, suffix); + } +} + +float_t KN_LanguageModel::ScoreUnigram(const std::wstring &uni) const +{ + auto lIter = m_lookup[1].find(L""s); + if (lIter == m_lookup[1].end()) { + throw std::runtime_error("Unigrams not supported by this model!"); + } + + auto uniIter = lIter->second.find(uni); + float_t ctUni = 1e-8; + if (uniIter != lIter->second.end()) { + ctUni = uniIter->second; + } + + float_t ctSuffixes = GetPrefixSum(L""s); + + return ctUni / ctSuffixes; +} + +float_t KN_LanguageModel::ScoreTransition(const std::wstring &prefix, const std::wstring &suffix) const +{ + if (prefix.empty()) { + // The number of distinct bigrams that end with this token + auto rlIter = m_reverseLookup.find(suffix); + + float_t ctEndingBigrams = 0; + if (rlIter != m_reverseLookup.end()) { + ctEndingBigrams = rlIter->second[2].size(); + } + + float_t ctAllBigrams = m_lookup[2].size(); + + return ctEndingBigrams / ctAllBigrams; + } + + auto lIter = m_lookup[prefix.size() + 1].find(prefix); + float_t ctUqSuffixes = 0; + float_t ctSuffixes = 0; + float_t ctSuffix = 0; + if (lIter != m_lookup[prefix.size() + 1].end()) { + ctUqSuffixes = lIter->second.size(); + + ctSuffixes = GetPrefixSum(prefix); + + auto sIter = lIter->second.find(suffix); + if (sIter != lIter->second.end()) { + ctSuffix = sIter->second; + } + } + + float_t factor = 0; + float_t main = 0; + if (ctSuffixes != 0) { + factor = m_knDelta * ctUqSuffixes / ctSuffixes; + // TODO: Figure out how to make this call without copying the string! + factor *= ScoreTransition({begin(prefix) + 1, end(prefix)}, suffix); + + main = max(ctSuffix - m_knDelta, 0) / ctSuffixes; + } + + float_t total = main + factor; + + return total; +} diff --git a/nemo-retriever-ocr/cpp/beam_decode/kn_lm.h b/nemo-retriever-ocr/cpp/beam_decode/kn_lm.h new file mode 100644 index 0000000000000000000000000000000000000000..cb59c3f7673b67b22a1fa37d42413853637fb7c3 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/kn_lm.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ngram_lm_base.h" + + +class KN_LanguageModel + : public NGramLMBase +{ +public: + KN_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t knDelta); + +protected: + virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const override; + +private: + float_t ScoreUnigram(const std::wstring &uni) const; + float_t ScoreTransition(const std::wstring &prefix, const std::wstring &suffix) const; + + float_t m_knDelta; +}; diff --git a/nemo-retriever-ocr/cpp/beam_decode/language_model.cpp b/nemo-retriever-ocr/cpp/beam_decode/language_model.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d69596e0d6c6439e9e129de81cfbf876fc98c8a --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/language_model.cpp @@ -0,0 +1,147 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. 
+// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "language_model.h" + +#include +#include + +using namespace std; + +const NullLanguageModel_t NullLanguageModel; + +NullLanguageModel_t::NullLanguageModel_t() + : LanguageModel({}) +{ +} + +TokenMappingWrapper::TokenMappingWrapper(token_mapping_t mapping) + : token_mapping(move(mapping)) +{ + for (const auto &mp : token_mapping) { + if (mp.second.size() == 1) { + wchar_t c = mp.second.front(); + reverse_token_mapping.emplace(c, mp.first); + } + } +} + +TokenMappingWrapper::Ptr create_token_mapping(token_mapping_t tokenMapping) +{ + return make_shared(move(tokenMapping)); +} + + +template +vector> + decode_sequences_impl(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping, + c10::optional probs) +{ + const token_mapping_t &mapping = tokenMapping->token_mapping; + + auto tokensAccess = tokens.accessor(); + + torch::Tensor pTens = probs.value_or(torch::ones({ tokens.size(0) }, torch::kFloat32)); + if (pTens.dim() == 1) { + pTens = pTens.unsqueeze(1); + } + + auto probsAccess = pTens.accessor(); + + const int64_t B = tokens.size(0); + const int64_t T = tokens.size(1); + + vector> ret; + + for (int64_t b = 0; b < B; ++b) { + wstring buff; + + float logProb = 0.0f; // log 1 + bool done = false; + for (int64_t t = 0; t < T && ! done; ++t) { + typename token_mapping_t::key_type tokIdx = tokensAccess[b][t]; + + if (t < probsAccess.size(1)) { + logProb += log(probsAccess[b][t]); + } + + switch (tokIdx) { + case 0: + // Blank char + continue; + case 1: + // End of sequence char + done = true; + break; + case 2: + buff.push_back('^'); + break; + default: + auto iter = mapping.find(tokIdx); + if (iter == mapping.end()) { + throw std::runtime_error("The token mapping doesn't contain an entry for index " + to_string(tokIdx)); + } + buff += iter->second; + break; + } + } + + ret.emplace_back(move(buff), exp(logProb)); + } + + return ret; +} + +vector> + decode_sequences(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping, + c10::optional probs) +{ + if (tokens.dim() != 2) { + throw std::runtime_error("`tokens` must be 2-dimensions of type B,T!"); + } + + if (tokenMapping == nullptr) { + throw std::runtime_error("Cannot supply a null token mapping!"); + } + + const token_mapping_t &mapping = tokenMapping->token_mapping; + + if (mapping.empty()) { + throw std::runtime_error("The token mapping hasn't been initialized!"); + } + + if (probs.has_value()) { + if (probs.value().scalar_type() != torch::kFloat32) { + throw std::runtime_error("If the probability distribution is specified, then it must be of type `torch.float32`"); + } + if (probs.value().size(0) != tokens.size(0)) { + throw std::runtime_error("The probability distribution batch size doesn't match the tokens batch size!"); + } + if (probs.value().dim() == 2 && probs.value().size(1) != tokens.size(1)) { + throw std::runtime_error("Invalid probability distribution shape!"); + } + } + + vector> ret; + + AT_DISPATCH_INTEGRAL_TYPES( + tokens.scalar_type(), + "decode_sequences_impl", + ([&] { + ret = decode_sequences_impl(tokens, tokenMapping, probs); + }) + ); + + return ret; +} + + +std::string ws2s(const std::wstring& wstr) +{ + using convert_typeX = std::codecvt_utf8; + std::wstring_convert converterX; + + return converterX.to_bytes(wstr); +} + diff --git a/nemo-retriever-ocr/cpp/beam_decode/language_model.h b/nemo-retriever-ocr/cpp/beam_decode/language_model.h new file mode 100644 index 
0000000000000000000000000000000000000000..1aeeea0998075ac2808fbfdbb7aded077eca2e04 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/language_model.h @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "prefix.h" +#include "log_sum_exp.h" + +typedef std::unordered_map token_mapping_t; +typedef std::unordered_map reverse_token_mapping_t; + + +class LanguageModel +{ +public: + virtual ~LanguageModel() {} + + virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const = 0; + + const token_mapping_t &TokenMapping() const { return m_tokenMapping; } + +protected: + LanguageModel(token_mapping_t tokenMapping) + : m_tokenMapping(std::move(tokenMapping)) + {} + + token_mapping_t m_tokenMapping; +}; + + +class NullLanguageModel_t + : public LanguageModel +{ +public: + NullLanguageModel_t(); + + virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const override + { + // log P(1) + // Which means the probability is unchanged + return 0; + } +}; + +extern const NullLanguageModel_t NullLanguageModel; + +struct TokenMappingWrapper +{ + typedef std::shared_ptr Ptr; + + TokenMappingWrapper(token_mapping_t mapping); + + token_mapping_t token_mapping; + reverse_token_mapping_t reverse_token_mapping; +}; + +TokenMappingWrapper::Ptr create_token_mapping(token_mapping_t tokenMapping); + +std::vector> + decode_sequences(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping, + c10::optional probs = torch::nullopt); diff --git a/nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.cpp b/nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..07529620b225f5d96bb82955fae6ddb0ccf67851 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.cpp @@ -0,0 +1,7 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "log_sum_exp.h" + +const float_t NEG_INF = -std::numeric_limits::infinity(); diff --git a/nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.h b/nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.h new file mode 100644 index 0000000000000000000000000000000000000000..074837b0f5dbef94995aa342e78f28b84f3f0541 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/log_sum_exp.h @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +typedef float float_t; +extern const float_t NEG_INF; + +template +inline T max_val(T v) +{ + return v; +} + +template +inline T max_val(T v, Args... rest) +{ + auto restMax = max_val(rest...); + + return std::max(v, restMax); +} + +template +inline T sum_exp(T maxVal, T v) +{ + return std::exp(v - maxVal); +} + +template +inline T sum_exp(T maxVal, T v, Args... 
rest) +{ + auto restSum = sum_exp(maxVal, rest...); + + return sum_exp(maxVal, v) + restSum; +} + +template +inline T log_sum_exp(T v, Args ...args) +{ + auto maxVal = max_val(v, args...); + + if (maxVal == -std::numeric_limits::infinity()) { + return -std::numeric_limits::infinity(); + } + + auto sumExp = sum_exp(maxVal, v, args...); + + return maxVal + std::log(sumExp); +} diff --git a/nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.cpp b/nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.cpp new file mode 100644 index 0000000000000000000000000000000000000000..57fbee815394216ad05c94620001fd45fa380ee2 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.cpp @@ -0,0 +1,330 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "ngram_lm_base.h" + +#include +#include + +#if defined( USE_BOOST ) + +#include +#include +#include +#include +#include + +#endif // USE_BOOST + +using namespace std; + +const std::wstring WORD_END(1, 2); +const std::wstring NUMERIC(1, 3); +const std::wstring UNMODELED(1, 4); + +struct LMStorage +{ + lookup_t Lookup; + reverse_lookup_t ReverseLookup; + + template + void serialize(Archive &ar, const unsigned int version) { + ar & Lookup; + ar & ReverseLookup; + } +}; + +void save_suffix_map(std::fstream& fs, const suffix_map_t& suffix_map) +{ + // write out number of elements for Lookup + std::size_t suffix_map_count = suffix_map.size(); + fs.write((char*)(&suffix_map_count), sizeof(suffix_map_count)); + for (suffix_map_t::const_iterator reverse_lookup_it = suffix_map.begin(); reverse_lookup_it != suffix_map.end(); ++reverse_lookup_it) + { + // write out the key + size_t key_len = reverse_lookup_it->first.length(); + fs.write((char*)(&key_len), sizeof(key_len)); + fs.write((char*)(reverse_lookup_it->first.data()), key_len * sizeof(wchar_t)); + + // write out value + fs.write((char*)(&reverse_lookup_it->second), sizeof(reverse_lookup_it->second)); + } +} + +void save_lookup(std::fstream& fs, const lookup_t& lookup) +{ + // write out number of elements for Lookup + std::size_t lookup_count = lookup.size(); + fs.write((char*)(&lookup_count), sizeof(lookup_count)); + for (lookup_t::const_iterator lookup_it = lookup.begin(); lookup_it != lookup.end(); ++lookup_it) + { + // write out element map size + std::size_t map_elem_count = lookup_it->size(); + fs.write((char*)(&map_elem_count), sizeof(map_elem_count)); + + for (string_suffix_map_t::const_iterator str_sfx_it = lookup_it->begin(); str_sfx_it != lookup_it->end(); ++str_sfx_it) + { + // write out key + size_t key_len = str_sfx_it->first.length(); + fs.write((char*)(&key_len), sizeof(key_len)); + fs.write((char*)(str_sfx_it->first.data()), key_len * sizeof(wchar_t)); + save_suffix_map(fs, str_sfx_it->second); + } + } +} + +void save_reverse_lookup(std::fstream& fs, const reverse_lookup_t& reverse_lookup) +{ + // write out number of elements for Lookup + std::size_t reverse_lookup_count = reverse_lookup.size(); + fs.write((char*)(&reverse_lookup_count), sizeof(reverse_lookup_count)); + for (reverse_lookup_t::const_iterator reverse_lookup_it = reverse_lookup.begin(); reverse_lookup_it != reverse_lookup.end(); ++reverse_lookup_it) + { + // write out the key + size_t key_len = reverse_lookup_it->first.length(); + fs.write((char*)(&key_len), sizeof(key_len)); + fs.write((char*)(reverse_lookup_it->first.data()), key_len * sizeof(wchar_t)); + + // write out value vector length + size_t val_vec_len = 
reverse_lookup_it->second.size(); + fs.write((char*)(&val_vec_len), sizeof(val_vec_len)); + + for (suffix_map_vec_t::const_iterator val_vec_it = reverse_lookup_it->second.begin(); + val_vec_it != reverse_lookup_it->second.end(); + ++val_vec_it) + { + save_suffix_map(fs, *val_vec_it); + } + } +} + +void load_suffix_map(std::fstream& fs, suffix_map_t& suffix_map) +{ + // read in number of elements + std::size_t suffix_map_count = 0; + fs.read((char*)(&suffix_map_count), sizeof(suffix_map_count)); + for (size_t suffix_map_index = 0; suffix_map_index < suffix_map_count; ++suffix_map_index ) + { + // read in key + std::size_t key_len = 0; + fs.read((char*)(&key_len), sizeof(key_len)); + + std::wstring wkey(key_len, 0); + fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t)); + uint32_t value = 0; + fs.read((char*)(&value), sizeof(value)); + + suffix_map.insert(std::make_pair(wkey, value)); + } +} + +void load_lookup(std::fstream& fs, lookup_t& lookup) +{ + // read in number of elements + std::size_t lookup_count = 0; + fs.read((char*)(&lookup_count), sizeof(lookup_count)); + for (size_t lookup_index = 0; lookup_index < lookup_count; ++lookup_index) + { + std::size_t map_elem_count = 0; + fs.read((char*)(&map_elem_count), sizeof(map_elem_count)); + + lookup.push_back(string_suffix_map_t()); + string_suffix_map_t& str_sfx_map = lookup.back(); + + for (size_t str_sfx_map_index = 0; str_sfx_map_index < map_elem_count; ++str_sfx_map_index) + { + std::size_t key_len = 0; + fs.read((char*)(&key_len), sizeof(key_len)); + + std::wstring wkey(key_len, 0); + fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t)); + str_sfx_map.insert(std::make_pair(std::wstring(wkey), suffix_map_t())); + suffix_map_t& suffix_map = str_sfx_map[wkey]; + + load_suffix_map(fs, suffix_map); + } + } +} + +void load_reverse_lookup(std::fstream& fs, reverse_lookup_t& reverse_lookup) +{ + // read in number of elements + std::size_t reverse_lookup_count = 0; + fs.read((char*)(&reverse_lookup_count), sizeof(reverse_lookup_count)); + for (size_t rev_lookup_index = 0; rev_lookup_index < reverse_lookup_count; ++rev_lookup_index ) + { + // read in the key + std::size_t key_len = 0; + fs.read((char*)(&key_len), sizeof(key_len)); + + std::wstring wkey(key_len, 0); + fs.read((char*)(wkey.data()), key_len * sizeof(wchar_t)); + reverse_lookup.insert(std::make_pair(wkey, suffix_map_vec_t())); + suffix_map_vec_t& val_vec = reverse_lookup[wkey]; + + std::size_t val_vec_len = 0; + fs.read((char*)(&val_vec_len), sizeof(val_vec_len)); + + for (size_t val_vec_index = 0; val_vec_index < val_vec_len; ++val_vec_index) + { + val_vec.push_back(suffix_map_t()); + suffix_map_t& suffix_map = val_vec.back(); + load_suffix_map(fs, suffix_map); + } + } +} + +#if ! 
defined( USE_BOOST ) + +NGramLMBase::NGramLMBase(const string &dataFilePath, token_mapping_t tokenMapping) + : LanguageModel(move(tokenMapping)) +{ + std::fstream in(dataFilePath, std::ios::in | std::ios::binary); + load_lookup(in, m_lookup); + load_reverse_lookup(in, m_reverseLookup); + + if (m_lookup.size() >= 10) { + throw runtime_error("Only N-Grams of 9 or less are supported!"); + } + + for (auto &ngLevel : m_lookup) { + for (auto &kvPrefixLevel : ngLevel) { + uint32_t ct = 0; + for (auto &kvSfx : kvPrefixLevel.second) { + ct += kvSfx.second; + } + m_prefixSumLookup.emplace(kvPrefixLevel.first, ct); + } + } +} + +void save_ngram_data_file(const lookup_t& lookup, const reverse_lookup_t& reverseLookup, const std::string &outputPath) +{ + std::fstream out(outputPath, std::ios::out | std::ios::binary); + + save_lookup(out, lookup); + save_reverse_lookup(out, reverseLookup); +} + +#else // USE_BOOST + +NGramLMBase::NGramLMBase(const string &dataFilePath, token_mapping_t tokenMapping) + : LanguageModel(move(tokenMapping)) +{ + { + ifstream dfStr(dataFilePath, ios_base::in | ios_base::binary); + boost::archive::binary_iarchive ia(dfStr); + + LMStorage s; + ia >> s; + + + m_lookup = move(s.Lookup); + + m_reverseLookup = move(s.ReverseLookup); + } + + if (m_lookup.size() >= 10) { + throw runtime_error("Only N-Grams of 9 or less are supported!"); + } + + for (auto &ngLevel : m_lookup) { + for (auto &kvPrefixLevel : ngLevel) { + uint32_t ct = 0; + for (auto &kvSfx : kvPrefixLevel.second) { + ct += kvSfx.second; + } + m_prefixSumLookup.emplace(kvPrefixLevel.first, ct); + } + } +} + +void save_ngram_data_file(lookup_t lookup, reverse_lookup_t reverseLookup, const std::string &outputPath) +{ + ofstream ofs(outputPath, ios_base::out | ios_base::binary); + + LMStorage s; + s.Lookup = move(lookup); + s.ReverseLookup = move(reverseLookup); + + boost::archive::binary_oarchive oa(ofs); + oa << s; +} + +#endif // USE_BOOST + +float_t NGramLMBase::ScoreTransition(const Prefix *p, token_t nextToken) const +{ + std::wstring prefix; + if (! 
ConvertToString(p, prefix)) { + return NEG_INF; + } + + const std::wstring *pSuffix = nullptr; + + if (nextToken != 1) { + auto iter = m_tokenMapping.find(nextToken); + if (iter == m_tokenMapping.end()) { + pSuffix = &UNMODELED; + } else { + pSuffix = &iter->second; + + if (iswdigit(pSuffix->at(0))) { + pSuffix = &NUMERIC; + } + } + + } else { + pSuffix = &WORD_END; + } + + float_t ret = ScoreTransitionImpl(prefix, *pSuffix); + + if (ret > 0) { + return log(ret); + } else { + return NEG_INF; + } +} + +bool NGramLMBase::ConvertToString(const Prefix *p, std::wstring &prefix) const +{ + const Prefix *stk[10]; + int32_t sz = -1; + const Prefix *curr = p; + decltype(sz) mlSz{(int)m_lookup.size() - 2}; + while (curr && sz < mlSz) { + stk[++sz] = curr; + curr = curr->Parent; + } + + // Either blank or empty prefix + if (sz < 1) { return true; } + + --sz; + for (; sz >= 0; --sz) { + token_t tok = stk[sz]->Token; + // End of word token, which maps to the null character + if (tok == 1) { + prefix.push_back(WORD_END[0]); + } else if (tok == 0) { + // Do nothing + } else { + auto iter = m_tokenMapping.find(tok); + if (iter == m_tokenMapping.end()) { + prefix += UNMODELED; + } else { + const std::wstring &wChar = iter->second; + + if (iswdigit(wChar[0])) { + prefix += NUMERIC; + } else { + prefix += wChar; + } + } + } + } + + return true; +} diff --git a/nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.h b/nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.h new file mode 100644 index 0000000000000000000000000000000000000000..2a4a8cf73ca4405dd25bb4e25d002b58f0fc1fed --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/ngram_lm_base.h @@ -0,0 +1,80 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "language_model.h" + +// #define USE_BOOST 1 + + +typedef std::unordered_map suffix_map_t; + +/* Tells us the number of suffixes for a given ngram of order K + Keys: + 1. NGram Order + 2. Prefix + 3. Suffix +Value: + Count +*/ +typedef std::unordered_map string_suffix_map_t; +typedef std::vector lookup_t; +/* Tells us the number of K-gram prefixes found for a given suffix + Keys: + 1. Suffix + 2. NGram Order + 3. Prefix +Values: + Count +*/ +typedef std::vector suffix_map_vec_t; +typedef std::unordered_map reverse_lookup_t; + + + +extern const std::wstring WORD_END; +extern const std::wstring NUMERIC; +extern const std::wstring UNMODELED; + +class NGramLMBase + : public LanguageModel +{ +public: + virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const override; + +protected: + NGramLMBase(const std::string &dataFilePath, token_mapping_t tokenMapping); + + virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const = 0; + + bool ConvertToString(const Prefix *p, std::wstring &prefix) const; + + float_t GetPrefixSum(const std::wstring &prefix) const; + + lookup_t m_lookup; + reverse_lookup_t m_reverseLookup; + + std::unordered_map m_prefixSumLookup; +}; + +#if ! 
defined( USE_BOOST ) +void save_ngram_data_file(const lookup_t& lookup, const reverse_lookup_t& reverseLookup, const std::string &output_path); +#else // USE_BOOST +void save_ngram_data_file(lookup_t lookup, reverse_lookup_t reverseLookup, const std::string &output_path); +#endif // USE_BOOST + +inline float_t NGramLMBase::GetPrefixSum(const std::wstring &prefix) const +{ + auto iter = m_prefixSumLookup.find(prefix); + + if (iter == m_prefixSumLookup.end()) { + return 0; + } else { + return iter->second; + } +} diff --git a/nemo-retriever-ocr/cpp/beam_decode/prefix.cpp b/nemo-retriever-ocr/cpp/beam_decode/prefix.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b6bf04598e66372c4939bb797b0c57d83f5c7924 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/prefix.cpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "prefix.h" + +using namespace std; + +vector Prefix::ToList() const +{ + vector ret; + + auto curr = this; + + while (curr) { + if (curr->Token != 0) { + ret.push_back(curr->Token); + } + curr = curr->Parent; + } + + return { rbegin(ret), rend(ret) }; +} diff --git a/nemo-retriever-ocr/cpp/beam_decode/prefix.h b/nemo-retriever-ocr/cpp/beam_decode/prefix.h new file mode 100644 index 0000000000000000000000000000000000000000..d9983df8919f33de75119a6973d9c01c6ec03a86 --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/prefix.h @@ -0,0 +1,158 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include + +typedef int32_t token_t; + +class Prefix; + +// typedef std::shared_ptr PrefixPtr; + +class Prefix +{ +public: + token_t Token; + Prefix *Parent; + + Prefix(token_t token = 0 /* blank */, Prefix *parent = nullptr) + : Token(token), Parent(parent) + {} + + std::vector ToList() const; + + size_t size() const; +}; + + +///// Borrowed from Boost libraries +template +void hash_combine(size_t & seed, T const& v) +{ + seed ^= std::hash()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} +///// + +namespace std { +template<> +struct hash +{ + size_t operator()(const Prefix *p) const noexcept + { + size_t seed = 0; + + while (p) { + if (p->Token != 0) { + hash_combine(seed, p->Token); + } + p = p->Parent; + } + return seed; + } +}; + +template<> +struct hash> +{ + size_t operator()(const tuple &t) const noexcept + { + size_t seed = 0; + hash_combine(seed, get<0>(t)); + hash_combine(seed, get<1>(t)); + return seed; + } +}; + +template<> +struct equal_to +{ + bool operator()(const Prefix *a, const Prefix *b) const noexcept + { + while (a != nullptr && b != nullptr) { + if (a->Token != b->Token) { + return false; + } + a = a->Parent; + b = b->Parent; + } + // If one chain is shorter than the other + return a == b; + } +}; +} + +inline size_t Prefix::size() const +{ + size_t ret = 0; + auto p = this; + while (p != nullptr) { + ret += 1; + p = p->Parent; + } + return ret; +} + + +class PrefixAllocator +{ +public: + PrefixAllocator() = default; + ~PrefixAllocator(); + + template + Prefix *GetPrefix(Args&& ...ctorArgs); + +private: + void AllocateNextBuffer(); + + std::list m_buffers; + size_t m_allocSize = 0; + size_t m_currOff = 0; +}; + +inline PrefixAllocator::~PrefixAllocator() +{ + for (auto p : m_buffers) { + // Prefix is a POD, and are allocated without initializing + // to prevent redundant 
work upfront + // delete[] p; + free(p); + } +} + +inline void PrefixAllocator::AllocateNextBuffer() +{ + size_t nextSize = m_allocSize == 0 ? 1000 : 2 * m_allocSize; + + // Using malloc here to prevent the ctor of Prefix being called for each item. + // Instead, the ctor will be called upon first access using GetPrefix + auto pBuff = reinterpret_cast(malloc(sizeof(Prefix) * nextSize)); + + m_buffers.push_back(pBuff); + + m_allocSize = nextSize; + m_currOff = 0; +} + +template +Prefix *PrefixAllocator::GetPrefix(Args&& ...ctorArgs) +{ + if (m_currOff == m_allocSize) { + AllocateNextBuffer(); + } + + auto buff = m_buffers.back() + m_currOff; + + auto ret = new (buff) Prefix(std::forward(ctorArgs)...); + + ++m_currOff; + + return ret; +} diff --git a/nemo-retriever-ocr/cpp/beam_decode/sbo_lm.cpp b/nemo-retriever-ocr/cpp/beam_decode/sbo_lm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7b2a68556cefbcd42773eab135812696071852cc --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/sbo_lm.cpp @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "sbo_lm.h" + +#include + +// Reference paper: https://www.aclweb.org/anthology/D07-1090.pdf + + +SBO_LanguageModel::SBO_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoff) + : NGramLMBase(dataFilePath, move(tokenMapping)), m_backoff(backoff) +{ +} + +float SBO_LanguageModel::ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const +{ + auto lIter = m_lookup[prefix.size() + 1].find(prefix); + + // This prefix doesn't exist. Shrink it! + if (lIter == m_lookup[prefix.size() + 1].end()) { + return m_backoff * ScoreTransitionImpl({ begin(prefix) + 1, end(prefix) }, suffix); + } + + const suffix_map_t &suffixMap = lIter->second; + + auto sfIter = suffixMap.find(suffix); + + if (sfIter == suffixMap.end()) { + // This is a novel character entirely! + if (prefix.empty()) { + return 1e-8; + } else { + return m_backoff * ScoreTransitionImpl({ begin(prefix) + 1, end(prefix) }, suffix); + } + } + + float_t ctSuffix = sfIter->second; + float_t ctNgram = GetPrefixSum(prefix); + + float_t score = ctSuffix / ctNgram; + + assert(score >= 0 && score <= 1); + + return score; +} diff --git a/nemo-retriever-ocr/cpp/beam_decode/sbo_lm.h b/nemo-retriever-ocr/cpp/beam_decode/sbo_lm.h new file mode 100644 index 0000000000000000000000000000000000000000..3fe48d73833520db8883e9bde7d5f4c27280431c --- /dev/null +++ b/nemo-retriever-ocr/cpp/beam_decode/sbo_lm.h @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "kn_lm.h" + + +class SBO_LanguageModel + : public NGramLMBase +{ +public: + SBO_LanguageModel(const std::string &dataFilePath, token_mapping_t tokenMapping, float_t backoff); + +protected: + virtual float_t ScoreTransitionImpl(const std::wstring &prefix, const std::wstring &suffix) const override; + +private: + float_t m_backoff; +}; diff --git a/nemo-retriever-ocr/cpp/better_grid_sample/cpu_indirect_grid_sample.cpp b/nemo-retriever-ocr/cpp/better_grid_sample/cpu_indirect_grid_sample.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c1c38ef30b9365eedb50fa9dcc606f703346e84 --- /dev/null +++ b/nemo-retriever-ocr/cpp/better_grid_sample/cpu_indirect_grid_sample.cpp @@ -0,0 +1,94 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "grid_sample.h" +#include "gpu_grid_sample_utils.cuh" + +template +void indirect_grid_sample_forward_bilinear(torch::TensorAccessor input, + torch::TensorAccessor grid, + torch::TensorAccessor inputIndices, + torch::TensorAccessor output) +{ + const int64_t N = inputIndices.size(0); + const int64_t C = output.size(1); + + T fInputHeight = input.size(2); + T fInputWidth = input.size(3); + int64_t outputHeight = output.size(2); + int64_t outputWidth = output.size(3); + + #pragma omp parallel for num_threads(8) + for (int64_t i = 0; i < N; ++i) { + int64_t inputIdx = inputIndices[i]; + + for (int64_t c = 0; c < C; ++c) { + for (int64_t outY = 0; outY < outputHeight; ++outY) { + for (int64_t outX = 0; outX < outputWidth; ++outX) { + T u = grid[i][outY][outX][0]; + T v = grid[i][outY][outX][1]; + + if (u < -1 || u > 1 || v < -1 || v > 1) { + output[i][c][outY][outX] = 0; + continue; + } + + // Denormalize the coordinates + u = (u + 1) * ((fInputWidth - 1) / 2); + v = (v + 1) * ((fInputHeight - 1) / 2); + + // Calculate coordinates + const T inX = u; + const T inXint = std::floor(inX); + const T inXfrac = inX - inXint; + + const T inY = v; + const T inYint = std::floor(inY); + const T inYfrac = inY - inYint; + + T ps[] = { 1 - inXfrac, inXfrac }; + T rs[] = { 1 - inYfrac, inYfrac }; + T opVal = 0; + + #pragma unroll + for (int64_t row = 0; row < 2; ++row) { + #pragma unroll + for (int64_t col = 0; col < 2; ++col) { + T Tpx = utils::get_pixel_clamped(input, inputIdx, c, inXint + col, inYint + row); + opVal += rs[row] * ps[col] * Tpx; + } + } + + output[i][c][outY][outX] = opVal; + } + } + } + } +} + +torch::Tensor cpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, + torch::Tensor inputIndices, const std::string &method) +{ + auto output = input.new_empty({ inputIndices.size(0), input.size(1), grid.size(1), grid.size(2) }); + + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), + "cpu_indirect_grid_sample_forward_impl", + ([&] { + typedef scalar_t T; + if (method == "bilinear") { + indirect_grid_sample_forward_bilinear( + input.accessor(), + grid.accessor(), + inputIndices.accessor(), + output.accessor() + ); + } else { + throw std::runtime_error("Unsupported resample method: " + method); + } + }) + ); + + return output; +} diff --git a/nemo-retriever-ocr/cpp/better_grid_sample/gpu_grid_sample_utils.cuh b/nemo-retriever-ocr/cpp/better_grid_sample/gpu_grid_sample_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1bf528a914623c5321125234ddc5fb9f410675bb --- /dev/null +++ 
b/nemo-retriever-ocr/cpp/better_grid_sample/gpu_grid_sample_utils.cuh @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "../cuda_intellisense.cuh" + +#ifndef __NVCC__ +#include +#define __device__ +#endif + +namespace utils { + +#ifdef __NVCC__ + +template +__device__ __lib_inline__ +T clamp(T val, T minVal, T maxVal) +{ + return max(minVal, min(val, maxVal)); +} + +#else +using std::clamp; +#endif + +template +__device__ __lib_inline__ +auto &get_pixel_clamped(accessor_t &inputs, + int64_t n, int64_t c, int64_t x, int64_t y) +{ + x = clamp(x, 0, inputs.size(3) - 1); + y = clamp(y, 0, inputs.size(2) - 1); + + return inputs[n][c][y][x]; +} + +} diff --git a/nemo-retriever-ocr/cpp/better_grid_sample/gpu_indirect_grid_sample.cu b/nemo-retriever-ocr/cpp/better_grid_sample/gpu_indirect_grid_sample.cu new file mode 100644 index 0000000000000000000000000000000000000000..c56eb194fdcf37185120f5afbee358a88b7d3e17 --- /dev/null +++ b/nemo-retriever-ocr/cpp/better_grid_sample/gpu_indirect_grid_sample.cu @@ -0,0 +1,328 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "grid_sample.h" + +#include "../cuda_intellisense.cuh" +#include "../half_ops.cuh" +#include "gpu_grid_sample_utils.cuh" + +using namespace std; + +template +__device__ __lib_inline__ +auto &my_get_pixel_clamped(accessor_t &inputs, index_t x, index_t y) +{ + x = utils::clamp(x, 0, inputs.size(1) - 1); + y = utils::clamp(y, 0, inputs.size(0) - 1); + + return inputs[y][x]; +} + +__global__ +void single_ex_grid_sample_bilinear_kernel(const float *pInputImage, + uint32_t imgHeight, uint32_t imgWidth, uint32_t numChannels, + const float2 *pGrid, + uint32_t numGridCells, + float *pOutputImage) +{ + const uint32_t z = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t c = blockDim.y * blockIdx.y + threadIdx.y; + + if (c >= numChannels || z >= numGridCells) { + return; + } + + const uint32_t g = blockIdx.z; + + const float2 uv = pGrid[g * numGridCells + z]; + + float &outPx = pOutputImage[(g * numChannels + c) * numGridCells + z]; + if (abs(uv.x) > 1.0f || abs(uv.y) > 1.0f) { + outPx = 0.0f; + } else { + const uint32_t maxX = imgWidth - 1; + const uint32_t maxY = imgHeight - 1; + + const float u = (uv.x + 1.0f) * maxX * 0.5f; + const float v = (uv.y + 1.0f) * maxY * 0.5f; + + // calculate coordinates + const float inX = u; + const uint32_t inXint = inX; + const float inXfrac = inX - inXint; + + const float inY = v; + const uint32_t inYint = inY; + const float inYfrac = inY - inYint; + + const float *pChanImage = pInputImage + c * imgHeight * imgWidth; + + // By being in this conditional block, we know that u and v are >= 0, which means + // that their truncated value is also >= 0. 
Instead of clamping the value to within the buffer, + // we set the multiplication factor to be 0 if the interpolated value is outside the buffer + const float ps[] = { 1.0f - inXfrac, inXfrac * (inXint < maxX) }; + const float rs[] = { 1.0f - inYfrac, inYfrac * (inYint < maxY) }; + float opVal = 0.0f; + #pragma unroll + for (uint32_t row = 0; row < 2; ++row) { + const float *pRowImage = pChanImage + (inYint + row) * imgWidth; + + #pragma unroll + for (uint32_t col = 0; col < 2; ++col) { + const float px = pRowImage[inXint + col]; + opVal += rs[row] * ps[col] * px; + } + } + + outPx = opVal; + } +} + +template +__global__ +void indirect_grid_sample_forward_bilinear_kernel(torch::PackedTensorAccessor32 inputs, + torch::PackedTensorAccessor32 grid, + torch::PackedTensorAccessor32 inputIndices, + torch::PackedTensorAccessor32 outputs) +{ + static_assert(std::is_same::value, "Currently only float32 is supported!"); + //typedef typename fp_promote::type accum_t; + typedef float accum_t; + constexpr T NEG_ONE = -1; + constexpr T ONE = 1; + constexpr T ZERO = 0; + constexpr T TWO = 2; + constexpr T ZERO_PT_5 = 0.5; + typedef decltype(inputs.stride(0)) index_t; + + const index_t n = blockDim.z * blockIdx.z + threadIdx.z; + + if (n >= inputIndices.size(0)) return; + + const index_t c = blockDim.y * blockIdx.y + threadIdx.y; + + const index_t z = blockDim.x * blockIdx.x + threadIdx.x; + + const accum_t inputHeight = inputs.size(2); + const accum_t inputWidth = inputs.size(3); + const index_t outputHeight = outputs.size(2); + const index_t outputWidth = outputs.size(3); + + const index_t outY = z / outputWidth; + //const index_t outX = z % outputWidth; + const index_t outX = z - (outY * outputWidth); + + if (outY >= outputHeight) return; + + index_t inputIdx = inputIndices[n]; + const float2 f2uv = *reinterpret_cast(grid[n][outY][outX].data()); + float u = f2uv.x; + float v = f2uv.y; + + if (u < NEG_ONE || u > ONE || v < NEG_ONE || v > ONE) { + outputs[n][c][outY][outX] = ZERO; + return; + } + + // Denormalize the coordinates + u = (u + ONE) * ((inputWidth - ONE) * ZERO_PT_5); + v = (v + ONE) * ((inputHeight - ONE) * ZERO_PT_5); + + // calculate coordinates + const accum_t inX = u; + const index_t inXint = inX; + const accum_t inXfrac = inX - inXint; + + const accum_t inY = v; + const index_t inYint = inY; + const accum_t inYfrac = inY - inYint; + + accum_t ps[] = { ONE - inXfrac, inXfrac }; + accum_t rs[] = { ONE - inYfrac, inYfrac }; + accum_t opVal = ZERO; + + auto localInputs = inputs[inputIdx][c]; + + #pragma unroll + for (index_t row = 0; row < 2; ++row) { + #pragma unroll + for (index_t col = 0; col < 2; ++col) { + T Tpx = my_get_pixel_clamped(localInputs, inXint + col, inYint + row); + opVal += rs[row] * ps[col] * Convert::LeftToRight(Tpx); + } + } + + outputs[n][c][outY][outX] = Convert::RightToLeft(opVal); +} + +template +__global__ +void indirect_grid_sample_backward_bilinear_kernel(torch::PackedTensorAccessor64 inputs, + torch::PackedTensorAccessor64 grid, + torch::PackedTensorAccessor64 inputIndices, + torch::PackedTensorAccessor64 gradOutput, + torch::PackedTensorAccessor64 gradInput, + torch::PackedTensorAccessor64 gradGrid) +{ + typedef typename fp_promote::type accum_t; + constexpr T NEG_ONE = -1; + constexpr T ONE = 1; + + const int64_t n = blockDim.z * blockIdx.z + threadIdx.z; + + if (n >= inputIndices.size(0)) return; + + const int64_t c = blockDim.y * blockIdx.y + threadIdx.y; + + const int64_t z = blockDim.x * blockIdx.x + threadIdx.x; + + const accum_t inputHeight = 
inputs.size(2); + const accum_t inputWidth = inputs.size(3); + const int64_t outputHeight = gradOutput.size(2); + const int64_t outputWidth = gradOutput.size(3); + + const int64_t outY = z / outputWidth; + const int64_t outX = z % outputWidth; + + if (outY >= outputHeight) return; + + int64_t inputIdx = inputIndices[n]; + const float2 f2uv = *reinterpret_cast(grid[n][outY][outX].data()); + float u = f2uv.x; + float v = f2uv.y; + + // No output gradient contribution from this position + if (u < NEG_ONE || u > ONE || v < NEG_ONE || v > ONE) { + return; + } + + // Denormalize the coordinates + u = (u + 1) * ((inputWidth - 1) / 2); + v = (v + 1) * ((inputHeight - 1) / 2); + + // calculate coordinates + const accum_t inX = u; + const accum_t inXint = floor(inX); + const accum_t inXfrac = inX - inXint; + + const accum_t inY = v; + const accum_t inYint = floor(inY); + const accum_t inYfrac = inY - inYint; + + accum_t ps[] = { 1 - inXfrac, inXfrac }; + accum_t rs[] = { 1 - inYfrac, inYfrac }; + + const accum_t gOut = Convert::LeftToRight(gradOutput[n][c][outY][outX]); + + #pragma unroll + for (size_t row = 0; row < 2; ++row) { + #pragma unroll + for (size_t col = 0; col < 2; ++col) { + T &gIn = utils::get_pixel_clamped(gradInput, inputIdx, c, inXint + col, inYint + row); + + T gContrib = Convert::RightToLeft(rs[row] * ps[col] * gOut); + + atomicAdd(&gIn, gContrib); + } + } +} + +torch::Tensor gpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method) +{ + auto output = input.new_empty({ inputIndices.size(0), input.size(1), grid.size(1), grid.size(2) }); + + + if (method != "bilinear"s) { + throw runtime_error("Only 'bilinear' sampling is currently supported!"); + } + + if (input.size(0) == 1 && input.is_contiguous() && grid.is_contiguous()) { + uint32_t gridNumCells = grid.size(1) * grid.size(2); + dim3 blockDim(32, 3, 1); + dim3 gridDim(div_up(gridNumCells, blockDim.x), + div_up(input.size(1), blockDim.y), + div_up(grid.size(0), blockDim.z)); + single_ex_grid_sample_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) ( + input.data_ptr(), + input.size(2), input.size(3), input.size(1), + reinterpret_cast(grid.data_ptr()), + gridNumCells, + output.data_ptr() + ); + + } else { + // z is batch idx + // y is channel + // x is w*h + dim3 blockDim(32, 1, 3); + dim3 gridDim(div_up(grid.size(1) * grid.size(2), blockDim.x), + div_up(input.size(1), blockDim.y), + div_up(inputIndices.size(0), blockDim.z)); + indirect_grid_sample_forward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) ( + input.packed_accessor32(), + grid.packed_accessor32(), + inputIndices.packed_accessor32(), + output.packed_accessor32() + ); + } + + //AT_DISPATCH_FLOATING_TYPES_AND_HALF( + // input.scalar_type(), + // "gpu_indirect_grid_sample_forward", + // ([&] { + // typedef typename remap_half::type T; + // // typedef scalar_t T; + // if (method == "bilinear") { + // indirect_grid_sample_forward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) ( + // input.packed_accessor64(), + // grid.packed_accessor64(), + // inputIndices.packed_accessor64(), + // output.packed_accessor64() + // ); + // } else { + // throw runtime_error("Unsupported resample method: " + method); + // } + // }) + //); + + return output; +} + +std::vector gpu_indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method) +{ + auto gradInput = torch::zeros_like(input); + auto gradGrid = torch::zeros_like(grid); + + // z 
is batch idx + // y is channel + // x is w*h + dim3 blockDim(32, 1, 1); + dim3 gridDim(div_up(grid.size(1) * grid.size(2), blockDim.x), + div_up(input.size(1), blockDim.y), + div_up(inputIndices.size(0), blockDim.z)); + + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), + "gpu_indirect_grid_sample_backward", + ([&] { + typedef typename remap_half::type T; + // typedef scalar_t T; + if (method == "bilinear") { + indirect_grid_sample_backward_bilinear_kernel KERNEL_ARG2(gridDim, blockDim) ( + input.packed_accessor64(), + grid.packed_accessor64(), + inputIndices.packed_accessor64(), + gradOutput.packed_accessor64(), + gradInput.packed_accessor64(), + gradGrid.packed_accessor64() + ); + } else { + throw runtime_error("Unsupported resample method: " + method); + } + }) + ); + + return { gradInput, gradGrid }; +} diff --git a/nemo-retriever-ocr/cpp/better_grid_sample/grid_sample.h b/nemo-retriever-ocr/cpp/better_grid_sample/grid_sample.h new file mode 100644 index 0000000000000000000000000000000000000000..53f683063953b03f24dee77acf681d1df88ef294 --- /dev/null +++ b/nemo-retriever-ocr/cpp/better_grid_sample/grid_sample.h @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +inline +torch::Tensor region_counts_to_indices(torch::Tensor regionCounts, int64_t numOutputs) +{ + // If there's only one example, we can trivially return idx 0 for all + if (regionCounts.size(0) == 1) { + return torch::zeros({ numOutputs }, regionCounts.options().dtype(torch::kInt64)); + } + + // regionCounts will be some tensor like [ 5, 1, 10, 2 ] which means that the first 5 outputs + // correspond to the first input, the next output to the second input, 10 to the third, and so on. + + // We want to convert this to instead have an entry for each output which specifies the index of the corresponding input. + // To do this, we can count the number of times the output index exceeds the cumulative input counts. + // e.g. the cumulative region count for the above tensor is [ 5, 6, 16, 18 ]. + // The output indices 0-4 are not greater than or equal to any cumulative count, so they get the input index of 0. + // The output index 5 is equal to a single count, therefore index 1. + // The outputs 6-15 are all greater than or equal to two cumulative counts, therefore index 2. + // And so on. 
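    // Worked example (values assumed): regionCounts = [ 5, 1, 10, 2 ], numOutputs = 18.
    //   indices = arange(4)                        -> [ 0, 1, 2, 3 ]
    //   repeat_interleave(indices, regionCounts)   -> [ 0,0,0,0,0, 1, 2,2,2,2,2,2,2,2,2,2, 3,3 ]
    // i.e. output slot k is labeled with the index of the input region that produced it, which is
    // exactly the "count the cumulative sums that k meets or exceeds" rule described above.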
+ + auto indices = torch::arange(regionCounts.size(0), regionCounts.options().dtype(torch::kInt64)); + + auto outputIndices = torch::repeat_interleave(indices, regionCounts, /*dim=*/ 0, /*output_size=*/ numOutputs); + + return outputIndices; +} + +torch::Tensor gpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method); +torch::Tensor cpu_indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method); +std::vector gpu_indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method); + +inline +torch::Tensor indirect_grid_sample_forward(torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method) +{ + if (input.is_cuda() != grid.is_cuda() || input.is_cuda() != inputIndices.is_cuda()) { + throw std::runtime_error("Input tensors must all be on the same device!"); + } + if (inputIndices.size(0) != grid.size(0)) { + throw std::runtime_error("The batch dimensions must match!"); + } + if (grid.size(-1) != 2) { + throw std::runtime_error("The final grid dimension must be 2."); + } + + if (input.is_cuda()) { + return gpu_indirect_grid_sample_forward(std::move(input), std::move(grid), std::move(inputIndices), method); + } else { + return cpu_indirect_grid_sample_forward(std::move(input), std::move(grid), std::move(inputIndices), method); + } +} + +inline +std::vector indirect_grad_sample_backward(torch::Tensor gradOutput, torch::Tensor input, torch::Tensor grid, torch::Tensor inputIndices, const std::string &method) +{ + if (gradOutput.is_cuda()) { + return gpu_indirect_grad_sample_backward(std::move(gradOutput), std::move(input), std::move(grid), std::move(inputIndices), method); + } else { + throw std::runtime_error("Not implemented!"); + } +} diff --git a/nemo-retriever-ocr/cpp/common.cpp b/nemo-retriever-ocr/cpp/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..096fef757e3c79628f2669ef29ac529efc01bb52 --- /dev/null +++ b/nemo-retriever-ocr/cpp/common.cpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "common.h" + +#include + +using namespace std; + +void print_tensor(const torch::Tensor &t) { + cout << t << endl; +} diff --git a/nemo-retriever-ocr/cpp/common.h b/nemo-retriever-ocr/cpp/common.h new file mode 100644 index 0000000000000000000000000000000000000000..682fc0127a5972845a3aa3f4072963a825bd80d5 --- /dev/null +++ b/nemo-retriever-ocr/cpp/common.h @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include + +template +inline +std::ostream &operator<<(std::ostream &os, const std::vector &v) { + os << "["; + if (! 
v.empty()) { + os << v[0]; + for (size_t i = 1; i < v.size(); ++i) { + os << ", " << v[i]; + } + } + os << "]"; + return os; +} + +template +struct _inner_tuple_print +{ + inline + static std::ostream &print(std::ostream &os, const std::tuple &t) { + _inner_tuple_print::print(os, t); + + os << ", " << std::get(t); + return os; + } +}; + +template +struct _inner_tuple_print<0, Args...> +{ + inline + static std::ostream &print(std::ostream &os, const std::tuple &t) { + os << std::get<0>(t); + return os; + } +}; + + +template +inline +std::ostream &operator<<(std::ostream &os, const std::tuple &t) { + os << "("; + _inner_tuple_print::print(os, t); + os << ")"; + return os; +} + +void print_tensor(const torch::Tensor &t); diff --git a/nemo-retriever-ocr/cpp/cuda_intellisense.cuh b/nemo-retriever-ocr/cpp/cuda_intellisense.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3a7b396b69c4124192d9819b447023bfe641f799 --- /dev/null +++ b/nemo-retriever-ocr/cpp/cuda_intellisense.cuh @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#if defined(__INTELLISENSE__) || !defined(__NVCC__) +#ifndef KERNEL_ARG2 +#define KERNEL_ARG2(grid, block) +#define KERNEL_ARG3(grid, block, sh_mem) +#define KERNEL_ARG4(grid, block, sh_mem, stream) +#define __global__ +#define __device__ +#define __host__ +#endif +#endif + +#ifdef __INTELLISENSE__ +#define __CUDACC__ +#include + +void __syncthreads(); // workaround __syncthreads warning + +dim3 threadIdx; +dim3 blockIdx; +dim3 blockDim; +dim3 gridDim; + +#else +#ifndef KERNEL_ARG2 +#define KERNEL_ARG2(grid, block) <<< grid, block >>> +#define KERNEL_ARG3(grid, block, sh_mem) <<< grid, block, sh_mem >>> +#define KERNEL_ARG4(grid, block, sh_mem, stream) <<< grid, block, sh_mem, stream >>> +#endif +#endif + +#define __any_device__ __host__ __device__ + +#ifdef __NVCC__ +#define __lib_inline__ __forceinline__ + +#else +#define __lib_inline__ inline +#endif + +template +__any_device__ +inline auto div_up(T1 n, T2 d) +{ + return (n + d - 1) / d; +} diff --git a/nemo-retriever-ocr/cpp/geometry.h b/nemo-retriever-ocr/cpp/geometry.h new file mode 100644 index 0000000000000000000000000000000000000000..afac774a91a054bc0a932d8f918bde525e5b7fc7 --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry.h @@ -0,0 +1,1101 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
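div_up in cuda_intellisense.cuh above is the usual ceiling division used throughout the CUDA sources to size launch grids; a small usage sketch with illustrative numbers:

// div_up(n, d) == ceil(n / d) for positive integers, so every work item is covered by a thread.
dim3 blockDim(96);
dim3 gridDim(div_up(1000, blockDim.x));   // 1000 items / 96 threads per block -> 11 blocks; kernels bounds-check the tail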
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +#ifndef _GEOMETRY_NO_TORCH +#include +#endif + +#include "cuda_intellisense.cuh" + +#ifndef __NVCC__ +#define SORT_ALGO std::sort +#define SWAP std::swap + +template +using tuple_t = std::tuple; + +#else + +#include +#include + +#define SORT_ALGO thrust::sort +#define SWAP thrust::swap + +template +using tuple_t = thrust::tuple; +#endif + +template +struct Point_ { + typedef T inner_type; + + T X, Y; + + Point_() = default; + + __any_device__ + Point_(T x, T y) : X(x), Y(y) {} + + __any_device__ + Point_(T *ptr) : X(ptr[0]), Y(ptr[1]) {} + +#ifndef _GEOMETRY_NO_TORCH + template + __any_device__ + Point_(const torch::TensorAccessor &accessor) : X(accessor[0]), Y(accessor[1]) {} + + template + __any_device__ + Point_(const torch::PackedTensorAccessor64 &accessor) : X(accessor[0]), Y(accessor[1]) {} +#endif + + __any_device__ + Point_ &operator+=(const Point_ &other); + + __any_device__ + Point_ &operator-=(const Point_ &other); + + __any_device__ + Point_ &operator*=(const Point_ &other); + + __any_device__ + Point_ &operator/=(const Point_ &other); + + template + __any_device__ + Point_ &operator/=(W w); + + template + __any_device__ + Point_ &operator*=(W w); + + __any_device__ + Point_ operator-() { + return { -X, -Y }; + } + + __any_device__ + T Sum() const { return X + Y; } + + __any_device__ + T Angle() const; + + __any_device__ + void swap(Point_ &other) noexcept { + SWAP(X, other.X); + SWAP(Y, other.Y); + } +}; + +template +__lib_inline__ __any_device__ +void swap(Point_ &a, Point_ &b) { + a.swap(b); +} + + +template +__any_device__ +__lib_inline__ T Point_::Angle() const { +#ifndef __NVCC__ + using std::atan2; +#endif + return atan2(Y, X); +} + +template +__any_device__ +__lib_inline__ Point_ min(const Point_ &a, const Point_ &b) { +#ifndef __NVCC__ + using std::min; +#endif + return { + min(a.X, b.X), + min(a.Y, b.Y) + }; +} + +template +__any_device__ +__lib_inline__ Point_ max(const Point_ &a, const Point_ &b) { +#ifndef __NVCC__ + using std::max; +#endif + return { + max(a.X, b.X), + max(a.Y, b.Y) + }; +} + +template +struct AABB_ { + typedef T inner_type; + + T X; + T Y; + T MaxX; + T MaxY; + + AABB_() = default; + __any_device__ + AABB_(T x, T y, T maxX, T maxY) + : X(x), Y(y), MaxX(maxX), MaxY(maxY) {} + + __any_device__ + bool Contains(const Point_ &p) const { + return p.X >= X && p.X < MaxX && + p.Y >= Y && p.Y < MaxY; + } + + __any_device__ __lib_inline__ + AABB_ Union(const AABB_ &other) const { +#ifndef __NVCC__ + using std::min; + using std::max; +#endif + T minX = min(X, other.X); + T maxX = max(MaxX, other.MaxX); + T minY = min(Y, other.Y); + T maxY = max(MaxY, other.MaxY); + + return { minX, minY, maxX, maxY }; + } + + __any_device__ + AABB_ &operator-=(const Point_ &offset) { + X -= offset.X; + MaxX -= offset.X; + Y -= offset.Y; + MaxY -= offset.Y; + return *this; + } + + __any_device__ + __lib_inline__ T Width() const { return MaxX - X; } + __any_device__ + __lib_inline__ T Height() const { return MaxY - Y; } + __any_device__ + __lib_inline__ T Area() const { return Width() * Height(); } + + __lib_inline__ T &operator[] (int64_t idx) + { + static_assert(std::is_standard_layout>::value, "This function is only valid for standard layout"); + return (&X)[idx]; + } + __lib_inline__ T operator[] (int64_t idx) const + { + static_assert(std::is_standard_layout>::value, "This function is only valid for standard layout"); + return (&X)[idx]; + } + + __any_device__ 
__lib_inline__ + AABB_ Intersection(const AABB_ &other) const { +#ifndef __NVCC__ + using std::min; + using std::max; +#endif + T minX = max(X, other.X); + T minY = max(Y, other.Y); + T maxX = min(MaxX, other.MaxX); + T maxY = min(MaxY, other.MaxY); + // Prevent negative area + minX = min(minX, maxX); + minY = min(minY, maxY); + return { minX, minY, maxX, maxY }; + } + + __any_device__ __lib_inline__ + T IntersectionArea(const AABB_ &other) const { return Intersection(other).Area(); } +}; + +template +struct QuadBase_ { + typedef T inner_type; + + __any_device__ + AABB_ Bounds() const; + + __any_device__ + bool Contains(const Point_ &p) const; + + __any_device__ + T Area() const; + + __any_device__ + T Height() const; + + __any_device__ + T Width() const; + + template + __any_device__ + T IntersectionArea(const QuadBase_ &other) const; + + template + __any_device__ + T IOU(const QuadBase_ &other) const; + + template + __any_device__ + T IOU_UpperBound(const QuadBase_ &other) const; + + __any_device__ + Point_ Center() const; + + template + __any_device__ + /* + Returns 3 geometric associations between the two quads: + 0: The percent shared area between this and other relative to this (e.g. if other contains this, then it returns 1) + 1: The percent shared area between other and this relative to other (e.g. if this contains other, then it return 1) + 2: The IOU of the two quads + */ + tuple_t RegionSizes(const QuadBase_ &other) const; + + template + __any_device__ + tuple_t RegionSizes_UpperBound(const QuadBase_ &other) const; + + __any_device__ + Derived &operator/=(T val) { + auto rcp = 1 / val; + return *this *= rcp; + } + + __any_device__ + Derived &operator*=(T val) { + auto dThis = static_cast(this); + #pragma unroll + for (size_t i = 0; i < 4; ++i) { + dThis->Vertices[i] *= val; + } + return *dThis; + } + + friend auto begin(const QuadBase_ &q) { return static_cast(q).Vertices; } + friend auto begin(QuadBase_& q) { return static_cast(q).Vertices; } + friend auto end(const QuadBase_ &q) { return static_cast(q).Vertices + 4; } + friend auto end(QuadBase_ &q) { return static_cast(q).Vertices + 4; } +}; + +template +struct Quad_ : QuadBase_> { + Point_ *Vertices = nullptr; + + Quad_() = default; + __any_device__ + Quad_(T *dataPtr) + : Vertices(reinterpret_cast*>(dataPtr)) {} + __any_device__ + Quad_(Point_ *dataPtr) + : Vertices(dataPtr) {} + + template + __any_device__ __lib_inline__ + const Point_ &operator[](index_t offset) const { return Vertices[offset]; } + template + __any_device__ __lib_inline__ + Point_ &operator[](index_t offset) { return Vertices[offset]; } +}; + +template +struct InPlaceQuad_ : public QuadBase_> { + Point_ Vertices[4]; + + InPlaceQuad_() = default; + __any_device__ + InPlaceQuad_(const T *dataPtr) + { +#if defined(__NVCC__) + T *pVals = reinterpret_cast(Vertices); + #pragma unroll + for (uint32_t i = 0; i < 8; ++i) { + pVals[i] = dataPtr[i]; + } +#else + using std::copy; + copy(dataPtr, dataPtr + 8, reinterpret_cast(Vertices)); +#endif + } + __any_device__ + InPlaceQuad_(const Point_ *dataPtr) + { +#if defined(__NVCC__) + #pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + Vertices[i] = dataPtr[i]; + } +#else + using std::copy; + copy(dataPtr, dataPtr + 4, Vertices); +#endif + } + + template + __any_device__ __lib_inline__ + const Point_ &operator[](index_t v) const { return Vertices[v]; } + + template + __any_device__ __lib_inline__ + Point_ &operator[](index_t v) { return Vertices[v]; } +}; + +template +struct PolygonBase_ { + typedef T inner_type; + + 
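    // Like QuadBase_ above, this is a CRTP base: the Derived type supplies the Vertices storage and
    // Count, while the shared Bounds/Contains/EdgeLength/Center/Area logic lives in the polygon_*
    // free functions defined further below.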
__any_device__ + AABB_ Bounds() const; + + __any_device__ + bool Contains(const Point_ &p) const; + + __any_device__ + T EdgeLength() const; + + __any_device__ + Point_ Center() const; + + __any_device__ + T Area() const; +}; + +template +struct Polygon_ : PolygonBase_> { + Point_ *Vertices = nullptr; + size_t Count = 0; + + Polygon_() = default; + __any_device__ + Polygon_(T *dataPtr, size_t vertexCount) + : Vertices(reinterpret_cast*>(dataPtr)), Count(vertexCount) {} + __any_device__ + Polygon_(Point_ *dataPtr, size_t vertexCount) + : Vertices(dataPtr), Count(vertexCount) {} + + __any_device__ + const Point_ &operator[](size_t offset) const { return Vertices[offset]; } + __any_device__ + Point_ &operator[](size_t offset) { return Vertices[offset]; } +}; + +template +struct Segment_ { + Point_ A, B; + + Segment_() = default; + __any_device__ + Segment_(const Point_ &a, const Point_ &b) : A(a), B(b) {} + + __any_device__ + T Length() const; + __any_device__ + T LengthSq() const; + __any_device__ + bool Intersection(const Segment_ &other, Point_ &out_ptAlong) const; +}; + +template +__any_device__ +__lib_inline__ Point_ operator+(const Point_ &a, const Point_ &b) { + return { a.X + b.X, a.Y + b.Y }; +} + +template +__any_device__ +__lib_inline__ Point_ operator-(const Point_ &a, const Point_ &b) { + return { a.X - b.X, a.Y - b.Y }; +} + +template +__any_device__ +__lib_inline__ Point_ operator*(W scale, const Point_ &p) { + return { scale * p.X, scale * p.Y }; +} + +template +__any_device__ +__lib_inline__ Point_ operator*(const Point_ &p, W scale) { + return { scale * p.X, scale * p.Y }; +} + +template +__any_device__ +__lib_inline__ Point_ operator/(const Point_ &p, W divisor) { + return { p.X / divisor, p.Y / divisor }; +} + +template +__any_device__ +__lib_inline__ Point_ operator*(const Point_ &a, const Point_ &b) { + return { a.X * b.X, a.Y * b.Y }; +} + +template +__any_device__ +__lib_inline__ Point_ operator-(const Point_ &p, W v) { + return { p.X - v, p.Y - v }; +} + +template +__any_device__ +__lib_inline__ Point_ &Point_::operator+=(const Point_ &p) { + X = X + p.X; + Y = Y + p.Y; + return *this; +} + +template +__any_device__ +__lib_inline__ Point_ &Point_::operator-=(const Point_ &p) { + X = X - p.X; + Y = Y - p.Y; + return *this; +} + +template +__any_device__ +__lib_inline__ Point_ &Point_::operator*=(const Point_ &p) { + X = X * p.X; + Y = Y * p.Y; + return *this; +} + +template +__any_device__ +__lib_inline__ Point_ &Point_::operator/=(const Point_ &p) { + X = X / p.X; + Y = Y / p.Y; + return *this; +} + +template +template +__any_device__ +__lib_inline__ Point_ &Point_::operator/=(W val) { + // TODO: This can be more efficient for float types by computing the reciprocal + X /= val; + Y /= val; + return *this; +} + +template +template +__any_device__ +__lib_inline__ Point_ &Point_::operator*=(W val) { + X *= val; + Y *= val; + return *this; +} + +template +__any_device__ +__lib_inline__ T dot(const Point_ &a, const Point_ &b) { + return a.X * b.X + a.Y * b.Y; +} + +template +__any_device__ +__lib_inline__ T dot(const Point_ &p) { + return dot(p, p); +} + +template +__any_device__ +__lib_inline__ T length(const Point_ &p) { +#ifndef __NVCC__ + using std::sqrt; +#endif + return sqrt(dot(p)); +} + +template +__any_device__ +__lib_inline__ Point_ normalize(const Point_ &p) { + static constexpr T epsilon = std::numeric_limits::epsilon(); + auto len = length(p) + epsilon; + return { p.X / len, p.Y / len }; +} + +template +__any_device__ +__lib_inline__ Point_ ortho_2d(const 
Point_ &p) { + return { -p.Y, p.X }; +} + +template +__host__ +__lib_inline__ std::ostream &operator<<(std::ostream &os, const Point_ &p) { + return os << "(" << p.X << ", " << p.Y << ")"; +} + +template +__host__ +__lib_inline__ std::ostream &operator<<(std::ostream &os, const AABB_ &b) { + return os << "[(" << b.X << ", " << b.Y << "), (" << b.MaxX << ", " << b.MaxY << ")]"; +} + +template +__host__ +__lib_inline__ std::ostream &operator<<(std::ostream &os, const Segment_ &s) { + return os << "[(" << s.A.X << ", " << s.A.Y << "), (" << s.B.X << ", " << s.B.Y << ")]"; +} + +template +__host__ +__lib_inline__ std::ostream &operator<<(std::ostream &os, const Quad_ &q) { + os << "[" << q.Vertices[0]; + for (size_t i = 1; i < 4; ++i) { + os << ", " << q.Vertices[i]; + } + return os << "]"; +} + +template +__any_device__ +__lib_inline__ int _signum(T val) { + return (T(0) < val) - (val < T(0)); +} + +template +__any_device__ +__lib_inline__ T sign(const Point_ &p1, const Point_ &p2, const Point_ &p3) { + T ret = (p1.X - p3.X) * (p2.Y - p3.Y) - (p2.X - p3.X) * (p1.Y - p3.Y); + auto sgn = _signum(ret); + return sgn; +} + +template +__any_device__ +__lib_inline__ T Segment_::Length() const +{ +#ifndef __NVCC__ + using std::sqrt; +#endif + return sqrt(LengthSq()); +} + +template +__any_device__ +__lib_inline__ T Segment_::LengthSq() const +{ + return dot(B - A); +} + +template +__any_device__ +inline bool Segment_::Intersection(const Segment_ &other, Point_ &out_ptAlong) const +{ + auto p1 = A, p2 = B, p3 = other.A, p4 = other.B; + + auto denom = (p4.Y - p3.Y) * (p2.X - p1.X) - (p4.X - p3.X) * (p2.Y - p1.Y); + + if (abs(denom) < 1e-8) { + return false; + } + + auto numer = (p4.X - p3.X) * (p1.Y - p3.Y) - (p4.Y - p3.Y) * (p1.X - p3.X); + + auto t = numer / denom; + + auto Bnumer = (p2.X - p1.X) * (p1.Y - p3.Y) - (p2.Y - p1.Y) * (p1.X - p3.X); + + auto Bt = Bnumer / denom; + + if (t < 0 || t > 1 || Bt < 0 || Bt > 1) { + return false; + } + + out_ptAlong = A + t * (B - A); + + return true; +} + +template +__any_device__ +auto quad_center(const quad_t &quad) -> Point_ +{ + typedef typename quad_t::inner_type T; + + Point_ center = quad[0]; + for (size_t i = 1; i < 4; ++i) { + center += quad[i]; + } + + return center / T{ 4 }; +} + +template +__any_device__ +Point_ QuadBase_::Center() const { + return quad_center(static_cast(*this)); +} + +template +__any_device__ +auto quad_bounds(const quad_t &quad) -> AABB_ +{ +#ifndef __NVCC__ + using std::min; + using std::max; +#endif + auto minP = quad[0]; + auto maxP = minP; + for (size_t i = 1; i < 4; ++i) { + auto qp = quad[i]; + minP = min(minP, qp); + maxP = max(maxP, qp); + } + return { minP.X, minP.Y, maxP.X, maxP.Y }; +} + +template +__any_device__ +AABB_ QuadBase_::Bounds() const { + return quad_bounds(static_cast(*this)); +} + +template +__any_device__ +inline bool quad_contains(const Quad_t &quad, const point_t &pt) +{ +#ifndef __NVCC__ + using std::abs; +#endif + + // Checks that the point lies on the interior side of each half plane + auto d1 = sign(pt, quad[0], quad[1]); + auto d2 = sign(pt, quad[1], quad[2]); + auto d3 = sign(pt, quad[2], quad[3]); + auto d4 = sign(pt, quad[3], quad[0]); + + // bool has_neg = (d1 < 0) || (d2 < 0) || (d3 < 0) || (d4 < 0); + // bool has_pos = (d1 > 0) || (d2 > 0) || (d3 > 0) || (d4 > 0); + int tot = d1 + d2 + d3 + d4; + + // return !(has_neg && has_pos); + return abs(tot) == 4; +} + +template +__any_device__ +__lib_inline__ bool QuadBase_::Contains(const Point_ &pt) const +{ + return 
quad_contains(static_cast(*this), pt); +} + +template +__any_device__ +inline auto shoelace_area(const PtList &points, size_t numPts, bool isSigned=false) -> decltype(points[0].X) +{ +#ifndef __NVCC__ + using std::abs; +#endif + + decltype(points[0].X) area = 0; + + size_t j = numPts - 1; + for (size_t i = 0; i < numPts; ++i) { + auto Pi = points[i]; + auto Pj = points[j]; + + area += (Pj.X + Pi.X) * (Pj.Y - Pi.Y); + j = i; + } + + area = area / 2; + + if (! isSigned) { + area = abs(area); + } + + return area; +} + +template +__any_device__ +__lib_inline__ T QuadBase_::Height() const +{ + auto &d = static_cast(*this); + auto h1 = Segment_(d[1], d[2]).Length(); + auto h2 = Segment_(d[3], d[0]).Length(); + return (h1 + h2) / 2; +} + +template +__any_device__ +__lib_inline__ T QuadBase_::Width() const +{ + auto &d = static_cast(*this); + auto w1 = Segment_(d[0], d[1]).Length(); + auto w2 = Segment_(d[3], d[2]).Length(); + return (w1 + w2) / 2; +} + +// A quad can be defined as the sum of the area of two triangles +template +__any_device__ +inline T QuadBase_::Area() const +{ + // auto vertices = static_cast(this)->Vertices; + return shoelace_area(static_cast(*this), 4); +} + +template +__any_device__ +inline auto intersection_area(const Quad_t1 &quadsA, const Quad_t2 &quadsB) -> typename Quad_t1::inner_type +{ +#ifndef __NVCC__ + using std::atan2; +#endif + + typedef typename Quad_t1::inner_type T; + + static const size_t MAX_PTS = 32; + + Point_ points[MAX_PTS], sortedPoints[MAX_PTS]; + T angles[MAX_PTS]; + size_t indices[MAX_PTS]; + size_t numPts = 0; + + auto addPt = [&] (const Point_ &p) { + points[numPts] = p; + ++numPts; + }; + + for (size_t i = 0; i < 4; ++i) { + Point_ aPt = quadsA[i]; + Point_ bPt = quadsB[i]; + + if (quadsA.Contains(bPt)) { + addPt(bPt); + } + if (quadsB.Contains(aPt)) { + addPt(aPt); + } + } + + for (size_t i = 0; i < 4; ++i) { + Segment_ segA{ quadsA[i], quadsA[(i + 1) % 4] }; + + for (size_t j = 0; j < 4; ++j) { + Segment_ segB{ quadsB[j], quadsB[(j + 1) % 4] }; + + Point_ ptAlong; + if (segA.Intersection(segB, ptAlong)) { + addPt(ptAlong); + } + } + } + + if (numPts == 0) { + return 0; + } + + Point_ center{ 0, 0 }; + for (size_t i = 0; i < numPts; ++i) { + center += points[i]; + } + center /= numPts; + + for (size_t i = 0; i < numPts; ++i) { + points[i] -= center; + + angles[i] = atan2(points[i].Y, points[i].X); + + indices[i] = i; + } + + // Perform an argsort over the angles + SORT_ALGO(indices, indices + numPts, + [&] (size_t a, size_t b) { + return angles[a] < angles[b]; + } + ); + + for (size_t i = 0; i < numPts; ++i) { + sortedPoints[i] = points[indices[i]]; + } + + // Finally, we can compute the area of this polygon using the shoelace formula + T area = shoelace_area(sortedPoints, numPts); + + return area; +} + +template +template +__any_device__ +__lib_inline__ T QuadBase_::IntersectionArea(const QuadBase_ &other) const +{ + return intersection_area( + static_cast(*this), + static_cast(other) + ); +} + +template +__any_device__ +__lib_inline__ auto geometry_iou(const T1 &a, const T2 &b) -> decltype(a.Area()) +{ + auto aArea = a.Area(); + auto bArea = b.Area(); + auto ixArea = a.IntersectionArea(b); + + auto unionArea = aArea + bArea - ixArea; + + return ixArea / unionArea; +} + +template +template +__any_device__ +__lib_inline__ T QuadBase_::IOU(const QuadBase_ &other) const +{ + return geometry_iou( + static_cast(*this), + static_cast(other) + ); +} + +template +template +__any_device__ +__lib_inline__ T QuadBase_::IOU_UpperBound(const QuadBase_ 
&other) const +{ + return geometry_iou( + Bounds(), + other.Bounds() + ); +} + +template +__any_device__ __lib_inline__ +auto geometry_region_sizes(const T1 &a, const T2 &b) -> tuple_t +{ + auto aArea = a.Area(); + auto bArea = b.Area(); + auto ixArea = a.IntersectionArea(b); + + auto unionArea = aArea + bArea - ixArea; + auto iou = ixArea / unionArea; + + return { ixArea / aArea, ixArea / bArea, iou }; +} + + +template +template +__any_device__ __lib_inline__ +tuple_t QuadBase_::RegionSizes(const QuadBase_ &other) const +{ + return geometry_region_sizes( + static_cast(*this), + static_cast(other) + ); +} + +template +template +__any_device__ __lib_inline__ +tuple_t QuadBase_::RegionSizes_UpperBound(const QuadBase_ &other) const +{ + return geometry_region_sizes( + Bounds(), + other.Bounds() + ); +} + +template +__any_device__ +auto polygon_bounds(const polygon_t &poly) -> AABB_ +{ +#ifndef __NVCC__ + using std::min; + using std::max; +#endif + auto minP = poly[0]; + auto maxP = minP; + for (size_t i = 1; i < poly.Count; ++i) { + auto qp = poly[i]; + minP = min(minP, qp); + maxP = max(maxP, qp); + } + return { minP.X, minP.Y, maxP.X, maxP.Y }; +} + +template +__any_device__ +AABB_ PolygonBase_::Bounds() const { + return polygon_bounds(static_cast(*this)); +} + +template +__any_device__ +bool polygon_contains(const polygon_t &poly, const point_t &pt) +{ + typedef typename polygon_t::inner_type T; + + // Some arbitrary segment. Technically this should be a ray, but functionally this will work + Segment_ testSeg{ pt, { -1e6, -2e6 }}; + Point_ trash; + + int32_t ixCount = 0; + for (size_t i = 0; i < poly.Count; ++i) { + Segment_ polySeg{ poly[i], poly[(i + 1) % poly.Count] }; + + if (testSeg.Intersection(polySeg, trash)) { + ++ixCount; + } + } + + // If there are an odd number of intersections, then the point is inside + return (ixCount % 2) == 1; +} + +template +__any_device__ +bool PolygonBase_::Contains(const Point_ &pt) const { + return polygon_contains(static_cast(*this), pt); +} + +template +__any_device__ +auto polygon_edge_length(const polygon_t &poly) -> typename polygon_t::inner_type +{ + typedef typename polygon_t::inner_type T; + + T ret = 0; + + for (size_t i = 0; i < poly.Count; ++i) { + Segment_ seg{ poly[i], poly[(i + 1) % poly.Count] }; + + ret += seg.Length(); + } + + return ret; +} + +template +__any_device__ +T PolygonBase_::EdgeLength() const { + return polygon_edge_length(static_cast(*this)); +} + +template +__any_device__ +auto polygon_center(const polygon_t &poly) -> Point_ +{ + typedef typename polygon_t::inner_type T; + + T cx = 0, cy = 0, a = 0; + size_t j = poly.Count - 1; + for (size_t i = 0; i < poly.Count; ++i) { + Point_ p0 = poly[i]; + Point_ p1 = poly[j]; + + T common = (p0.X * p1.Y - p1.X * p0.Y); + cx += (p0.X + p1.X) * common; + cy += (p0.Y + p1.Y) * common; + a += common; + + j = i; + } + + a /= 2; + + Point_ center{ cx / (6 * a), cy / (6 * a) }; + + return center; +} + +template +__any_device__ +Point_ PolygonBase_::Center() const { + return polygon_center(static_cast(*this)); +} + +template +__any_device__ +T PolygonBase_::Area() const { + const Derived &dThis = static_cast(*this); + return shoelace_area(dThis, dThis.Count); +} + + +template +__any_device__ +Point_ nearest_point_on_segment(const Point_ &pt, const Segment_ &seg) +{ +#ifndef __NVCC__ + using std::max; + using std::min; +#endif + + const T l2 = seg.LengthSq(); + + if (l2 == 0.0) { + return seg.A; + } + + const auto v = seg.A; + const auto w = seg.B; + // Consider the line extending the 
segment, parameterized as v + t*(w-v) + // Find projection of point p onto the line + auto t = dot(pt - v, w - v) / l2; + + // Clamp between t=0 and t=1 + t = max(static_cast(0), min(static_cast(1), t)); + + const auto projection = v + t * (w - v); + + return projection; +} + + +template +__any_device__ +Segment_ shortest_line_between_segments(const Segment_ &a, const Segment_ &b) +{ + Segment_ segs[] = { + { a.A, nearest_point_on_segment(a.A, b) }, + { a.B, nearest_point_on_segment(a.B, b) }, + { nearest_point_on_segment(b.A, a), b.A }, + { nearest_point_on_segment(b.B, a), b.B } + }; + + T minDist = std::numeric_limits::max(); + size_t idx; + + #pragma unroll + for (size_t i = 0; i < 4; ++i) { + T dist = segs[i].LengthSq(); + if (dist < minDist) { + minDist = dist; + idx = i; + } + } + + return segs[idx]; +} + +// Find the distance between a point and the nearest point along the specified segment +template +__any_device__ +T distance_to_segment(const Point_ &pt, const Segment_ &seg) +{ + auto projection = nearest_point_on_segment(pt, seg); + + auto dist = length(pt - projection); + + return dist; +} diff --git a/nemo-retriever-ocr/cpp/geometry_api/calc_poly_min_rrect.cpp b/nemo-retriever-ocr/cpp/geometry_api/calc_poly_min_rrect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ca87b0b0d887c291ed57abdac15ec25d9584d3f --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry_api/calc_poly_min_rrect.cpp @@ -0,0 +1,165 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "geometry_api.h" + +#include "../graph_detection/encode_util.h" + +#include "../geometry.h" +#include "matrix2x2.h" + +using namespace std; + +template +void _calc_poly_min_rrect(const torch::TensorAccessor vertices, torch::TensorAccessor outRRect); +template +void _calc_quad_min_rrect(const torch::TensorAccessor vertices, torch::TensorAccessor outRRect); + +torch::Tensor calc_poly_min_rrect(torch::Tensor vertices) +{ + if (vertices.size(0) < 3) { + throw runtime_error("Invalid polygon! 
Expected >= 3 vertices, got " + to_string(vertices.size(0))); + } + + auto ret = torch::empty({ 4, 2 }, vertices.options()); + + auto retAcc = ret.accessor(); + + if (vertices.size(0) != 4) { + // OpenCV requires this to be a contiguous buffer + vertices = vertices.contiguous(); + _calc_poly_min_rrect(vertices.accessor(), retAcc); + } else { + _calc_quad_min_rrect(vertices.accessor(), retAcc); + } + + return ret; +} + + +template +void _calc_bounds(const torch::TensorAccessor &vertices, torch::TensorAccessor &outRRect, + const Point_ &leftCenter, const Point_ &rightCenter) +{ + typedef Point_ Pointf; + + Pointf vecAlong = rightCenter - leftCenter; + auto alongMag = length(vecAlong); + + if (alongMag == 0.0f) { + throw runtime_error("Invalid polygon!"); + } + + vecAlong /= alongMag; + + Pointf dOrtho{ -vecAlong.Y, vecAlong.X }; + + Pointf center = (leftCenter + rightCenter) / 2.0f; + + Matrix2x2 rotMat{ vecAlong, dOrtho }; + + auto get_fn = [&vertices, ¢er] (int64_t i) { + return Pointf{ vertices[i] } - center; + }; + + // All we care about it getting the bounds in the normalized space, so this saves + // us from having to do any memory allocation + Pointf minPt{ 0, 0 }, maxPt{ 0, 0 }; + auto tx_fn = [&minPt, &maxPt] (int64_t i, const Pointf &pt) { + minPt = min(minPt, pt); + maxPt = max(maxPt, pt); + }; + + matmul_fn(vertices.size(0), get_fn, rotMat, tx_fn, transpose_tag{}); + + Pointf rotBox[4] = { + minPt, + { maxPt.X, minPt.Y }, + maxPt, + { minPt.X, maxPt.Y } + }; + + auto get_fn2 = [&rotBox] (int64_t i) { + return rotBox[i]; + }; + + auto assign_fn = [¢er, &outRRect] (int64_t i, const Pointf &pt) { + outRRect[i][0] = pt.X + center.X; + outRRect[i][1] = pt.Y + center.Y; + }; + + matmul_fn(4, get_fn2, rotMat, assign_fn, contiguous_tag{}); +} + + +template +void _calc_poly_min_rrect(const torch::TensorAccessor vertices, torch::TensorAccessor outRRect) +{ + typedef Point_ Pointf; + typedef Polygon_ Polygonf; + + Polygonf poly{ vertices.data(), vertices.size(0) }; + + vector bottoms = graph_detection::find_bottom(poly, false); + + if (bottoms.size() != 2) { + throw runtime_error("Invalid polygon!"); + } + + vector longEdges[2]; + graph_detection::find_long_edges(poly, bottoms.data(), longEdges[0], longEdges[1]); + + //// + // Determine which edge is above the other + Pointf cpts[2]; + for (size_t i = 0; i < 2; ++i) { + auto &pedge = longEdges[i]; + + cpts[i] = Pointf{0.0f, 0.0f}; + float ct = 0; + for (size_t z = 0; z < pedge.size(); ++z) { + auto edge = pedge[z]; + Pointf p1 = poly[edge.A]; + Pointf p2 = poly[edge.B]; + cpts[i] += (p1 + p2) / 2.0f; + ct += 1.0f; + } + + if (ct < 1.0f) { + throw runtime_error("Edge was empty!"); + } + cpts[i] /= ct; + } + + float vpp = graph_detection::vector_sin(cpts[0] - cpts[1]); + if (vpp >= 0) { + swap(bottoms[0], bottoms[1]); + } + //// + + Pointf edge1[2] = { poly[bottoms[0].A], poly[bottoms[0].B] }; + Pointf edge2[2] = { poly[bottoms[1].A], poly[bottoms[1].B] }; + + Pointf c0 = (edge1[0] + edge1[1]) / 2.0f; + Pointf c1 = (edge2[0] + edge2[1]) / 2.0f; + + _calc_bounds(vertices, outRRect, c0, c1); +} + +template +void _calc_quad_min_rrect(const torch::TensorAccessor vertices, torch::TensorAccessor outRRect) +{ + typedef Point_ Pointf; + + // Instead of finding an arbitrary rotated box, find a reasonable + // fit for the quadrangle + Pointf pts[4] = { + vertices[0], vertices[1], vertices[2], vertices[3] + }; + + Pointf c0 = (pts[0] + pts[3]) / 2.0f; + Pointf c1 = (pts[1] + pts[2]) / 2.0f; + + _calc_bounds(vertices, outRRect, c0, c1); +} diff --git 
a/nemo-retriever-ocr/cpp/geometry_api/geometry_api.cpp b/nemo-retriever-ocr/cpp/geometry_api/geometry_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ccf852f53d4d6354fe524db238e8d6d32ee7970d --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry_api/geometry_api.cpp @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "geometry_api.h" + +#include "geometry_api_common.h" + +using namespace std; + +torch::Tensor rrect_to_quads_gpu(torch::Tensor rrects, float cellSize); + +template +torch::Tensor rrect_to_quads_impl(torch::Tensor rrects, T cellSize) +{ + // BHW(5) + auto rrectAccess = rrects.accessor(); + + T cellOff = cellSize / 2; + + auto quads = torch::empty({ rrects.size(0), rrects.size(1), rrects.size(2), 4, 2 }, rrects.options()); + + auto quadsAccess = quads.accessor(); + + for (long b = 0; b < rrects.size(0); ++b) { + for (long y = 0; y < rrects.size(1); ++y) { + for (long x = 0; x < rrects.size(2); ++x) { + auto rrect = rrectAccess[b][y][x]; + + auto quad = quadsAccess[b][y][x]; + + assign_rrect_to_quad(rrect, quad, cellSize, cellOff, + static_cast(x), + static_cast(y)); + } + } + } + + return quads; +} + +torch::Tensor rrect_to_quads(torch::Tensor rrects, float cellSize) +{ + if (rrects.is_cuda()) { + return rrect_to_quads_gpu(rrects, cellSize); + } + + torch::Tensor quads; + AT_DISPATCH_FLOATING_TYPES( + rrects.scalar_type(), + "rrect_to_quads_impl", + ([&] { + quads = rrect_to_quads_impl(rrects, scalar_t(cellSize)); + }) + ); + + return quads; +} + + +template +torch::Tensor rrect_to_quads_backward_impl(torch::Tensor rrects, torch::Tensor gradOutput) +{ + // BHW(5) + auto gradInput = torch::empty_like(rrects); + + auto rrectAccess = rrects.accessor(); + // BHW42 + auto gradOutputAccess = gradOutput.accessor(); + auto gradInputAccess = gradInput.accessor(); + + for (long b = 0; b < rrects.size(0); ++b) { + for (long y = 0; y < rrects.size(1); ++y) { + for (long x = 0; x < rrects.size(2); ++x) { + assign_grad_rrect_to_quad(rrectAccess[b][y][x], gradOutputAccess[b][y][x], gradInputAccess[b][y][x]); + } + } + } + + return gradInput; +} + +torch::Tensor rrect_to_quads_backward_gpu(torch::Tensor rrects, torch::Tensor gradOutput); + +torch::Tensor rrect_to_quads_backward(torch::Tensor rrects, torch::Tensor gradOutput) +{ + if (rrects.is_cuda()) { + return rrect_to_quads_backward_gpu(rrects, gradOutput); + } + + torch::Tensor gradInput; + AT_DISPATCH_FLOATING_TYPES( + rrects.scalar_type(), + "rrect_to_quads_backward_impl", + ([&] { + gradInput = rrect_to_quads_backward_impl(rrects, gradOutput); + }) + ); + + return gradInput; +} diff --git a/nemo-retriever-ocr/cpp/geometry_api/geometry_api.h b/nemo-retriever-ocr/cpp/geometry_api/geometry_api.h new file mode 100644 index 0000000000000000000000000000000000000000..15e190b3e39ef62a909b265019220561c7f1789d --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry_api/geometry_api.h @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
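A short usage sketch of the rrect_to_quads entry points implemented above and declared below; the tensor shapes follow the accessors in rrect_to_quads_impl, (B, H, W, 5) -> (B, H, W, 4, 2), and the values here are illustrative:

// Each grid cell stores a 5-channel rotated-rect encoding; the op expands it to 4 corner points.
auto rrects = torch::zeros({ 2, 32, 32, 5 });                             // (B, H, W, 5), CPU or CUDA
auto quads  = rrect_to_quads(rrects, /*cellSize=*/ 4.0f);                 // (B, H, W, 4, 2)
auto gradIn = rrect_to_quads_backward(rrects, torch::ones_like(quads));   // gradient w.r.t. the 5 channels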
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +torch::Tensor rrect_to_quads(torch::Tensor rrects, float cellSize); +torch::Tensor rrect_to_quads_backward(torch::Tensor rrects, torch::Tensor gradOutput); + +torch::Tensor calc_poly_min_rrect(torch::Tensor vertices); + +float get_rel_continuation_cos(torch::Tensor rrectA, torch::Tensor rrectB); + +torch::Tensor get_poly_bounds_quad(torch::Tensor poly); diff --git a/nemo-retriever-ocr/cpp/geometry_api/geometry_api_common.h b/nemo-retriever-ocr/cpp/geometry_api/geometry_api_common.h new file mode 100644 index 0000000000000000000000000000000000000000..79217663fc843471430ce75001ff6d0c4b2b0def --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry_api/geometry_api_common.h @@ -0,0 +1,121 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "../cuda_intellisense.cuh" +#include "../geometry.h" + +#if defined(__NVCC__) +#include +#define GEO_PI CUDART_PI_F +#else +#include +#define GEO_PI M_PI +#endif + + +template +__device__ +inline +void pt_assign(access_t acc, const point_t &p) { + acc[0] = p.X; + acc[1] = p.Y; +} + +template +__device__ __lib_inline__ +InPlaceQuad_ cvt_rrect_to_quad(const rrect_access_t &rrect, T cellSize, T cellOff, T x, T y) +{ + typedef Point_ Pointf; + + Pointf prior{ + x * cellSize + cellOff, + y * cellSize + cellOff + }; + + T dTop = rrect[0]; + T dRight = rrect[1]; + T dBottom = rrect[2]; + T dLeft = rrect[3]; + T theta = rrect[4]; + + T piOver2{GEO_PI / 2.0f}; + Pointf vX{ cos(theta), sin(theta) }; + Pointf vY{ cos(theta - piOver2), sin(theta - piOver2) }; + + InPlaceQuad_ ret; + + ret[0] = prior - vX * dLeft + vY * dTop; + ret[1] = prior + vX * dRight + vY * dTop; + ret[2] = prior + vX * dRight - vY * dBottom; + ret[3] = prior - vX * dLeft - vY * dBottom; + + return ret; +} + +template +__device__ __lib_inline__ +void assign_rrect_to_quad(const rrect_access_t &rrect, quad_access_t &quad, + T cellSize, T cellOff, T x, T y) +{ + const InPlaceQuad_ cvQuad = cvt_rrect_to_quad(rrect, cellSize, cellOff, x, y); + + const T *pInQuad = reinterpret_cast(&cvQuad); + T *pOutQuad = reinterpret_cast(quad.data()); + + #pragma unroll + for (uint32_t i = 0; i < 8; ++i) { + pOutQuad[i] = pInQuad[i]; + } +} + +template +__device__ +inline +void assign_grad_rrect_to_quad(const rrect_access_t &rrect, + const quad_access_t &gradOutput, + rrect_access_t gradInput) +{ + typedef Point_ Pointf; + + T Top = rrect[0]; + T Right = rrect[1]; + T Bottom = rrect[2]; + T Left = rrect[3]; + T theta = rrect[4]; + + T piOver2{GEO_PI / 2.0f}; + Pointf vX{ cos(theta), sin(theta) }; + Pointf vY{ cos(theta - piOver2), sin(theta - piOver2) }; + + Pointf dVX{ -vX.Y, vX.X }; + Pointf dVY{ -vY.Y, vY.X }; + + Pointf gP0 = gradOutput[0], + gP1 = gradOutput[1], + gP2 = gradOutput[2], + gP3 = gradOutput[3]; + + // Top + gradInput[0] = (gP0 * vY + gP1 * vY).Sum(); + // Right + gradInput[1] = (gP1 * vX + gP2 * vX).Sum(); + // Bottom + gradInput[2] = -(gP2 * vY + gP3 * vY).Sum(); + // Left + gradInput[3] = -(gP0 * vX + gP3 * vX).Sum(); + + // Theta + gradInput[4] = ( + gP0 * (-Left * dVX + Top * dVY) + + gP1 * (Right * dVX + Top * dVY) + + gP2 * (Right * dVX - Bottom * dVY) + + gP3 * (-Left * dVX - Bottom * dVY) + ).Sum(); +} + +#undef GEO_PI diff --git a/nemo-retriever-ocr/cpp/geometry_api/geometry_api_gpu.cu b/nemo-retriever-ocr/cpp/geometry_api/geometry_api_gpu.cu new file mode 100644 index 
0000000000000000000000000000000000000000..555d41e0f66744bc2a7459fe2ed0254c08d3f132 --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry_api/geometry_api_gpu.cu @@ -0,0 +1,142 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "geometry_api.h" + +#include "../geometry.h" +#include "../cuda_intellisense.cuh" +#include "geometry_api_common.h" + +#include + +using namespace std; + + +template +struct RRect_ { + T Data[5]; + + template + __device__ + const T &operator[](index_t i) const { return Data[i]; } + template + __device__ + T &operator[](index_t i) { return Data[i]; } +}; + +template +__global__ +void device_rrect_to_quads_gpu(torch::PackedTensorAccessor64 rrectAccess, + torch::PackedTensorAccessor64 quadsAccess, + int64_t numRows, int64_t numCols, + T cellSize) +{ + typedef Point_ Pointf; + typedef RRect_ RRectf; + typedef InPlaceQuad_ Quadf; + constexpr T TWO = 2; + + const int64_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x; + + if (jobIdx >= rrectAccess.size(0)) { + return; + } + + int64_t row = jobIdx / numCols; + const int64_t col = jobIdx - (row * numCols); + row = row % numRows; + + auto rawRRect = reinterpret_cast(rrectAccess.data()); + auto rawQuad = reinterpret_cast(quadsAccess.data()); +#if defined(NDEBUG) + trove::coalesced_ptr pRRect(rawRRect); + trove::coalesced_ptr pQuad(rawQuad); +#else + auto pRRect = rawRRect; + auto pQuad = rawQuad; +#endif + + RRectf rrect = pRRect[jobIdx]; + + T cellOff = cellSize / TWO; + Quadf cvQuad = cvt_rrect_to_quad(rrect, cellSize, cellOff, col, row); + + pQuad[jobIdx] = cvQuad; +} + +torch::Tensor rrect_to_quads_gpu(torch::Tensor rrects, float cellSize) +{ + if (!rrects.is_contiguous()) { + throw std::runtime_error("Expected the rrects to be contiguous!"); + } + + torch::Tensor quads = torch::empty({ rrects.size(0), rrects.size(1), rrects.size(2), 4, 2 }, rrects.options()); + + auto rrFlat = rrects.flatten(0, 2); + auto qFlat = quads.flatten(0, 2); + + dim3 blockSize(96); + dim3 gridSize(div_up(qFlat.size(0), blockSize.x)); + + if (quads.numel() > 0) { + AT_DISPATCH_FLOATING_TYPES( + quads.scalar_type(), + "cuda_rrect_to_quads", + ([&] { + + device_rrect_to_quads_gpu KERNEL_ARG2(gridSize, blockSize) ( + rrFlat.packed_accessor64(), + qFlat.packed_accessor64(), + rrects.size(1), rrects.size(2), + cellSize + ); + + }) + ); + } + + return quads; +} + +template +__global__ +void device_rrect_to_quads_backward_gpu(torch::PackedTensorAccessor64 rrect, + torch::PackedTensorAccessor64 gradOutput, + torch::PackedTensorAccessor64 gradInput) +{ + const int64_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x; + + if (jobIdx >= rrect.size(0)) return; + + assign_grad_rrect_to_quad(rrect[jobIdx], gradOutput[jobIdx], gradInput[jobIdx]); +} + + +torch::Tensor rrect_to_quads_backward_gpu(torch::Tensor rrects, torch::Tensor gradOutput) +{ + auto gradInput = torch::empty_like(rrects); + + auto flatRRects = rrects.reshape({ -1, 5 }); + auto flatGradOutput = gradOutput.reshape({ -1, 4, 2 }); + auto flatGradInput = gradInput.reshape({ -1, 5 }); + + dim3 blockSize(32); + dim3 gridSize(div_up(rrects.size(0) * rrects.size(1) * rrects.size(2), blockSize.x)); + + if (rrects.numel() > 0) { + AT_DISPATCH_FLOATING_TYPES( + rrects.scalar_type(), + "cuda_rrect_to_quads_backward", + ([&] { + device_rrect_to_quads_backward_gpu KERNEL_ARG2(gridSize, blockSize) ( + flatRRects.packed_accessor64(), + flatGradOutput.packed_accessor64(), + flatGradInput.packed_accessor64() 
+ ); + }) + ); + } + + return gradInput; +} diff --git a/nemo-retriever-ocr/cpp/geometry_api/get_rel_continuation_cos.cpp b/nemo-retriever-ocr/cpp/geometry_api/get_rel_continuation_cos.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e095f3d7545c65d762a02e1acd04e967a9004a5c --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry_api/get_rel_continuation_cos.cpp @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "geometry_api.h" + +#include "../geometry.h" + +using namespace std; + + +float get_rel_continuation_cos(torch::Tensor rrectATensor, torch::Tensor rrectBTensor) +{ + typedef Point_ Pointf; + + if (rrectATensor.size(0) != 4 || rrectBTensor.size(0) != 4) { + throw runtime_error("Invalid rrect arguments. Both must have 4 vertices! A=" + + to_string(rrectATensor.size(0)) + ", B=" + to_string(rrectBTensor.size(0))); + } + + auto rrectA = rrectATensor.accessor(); + auto rrectB = rrectBTensor.accessor(); + + Pointf aPts[4] = { + rrectA[0], rrectA[1], rrectA[2], rrectA[3] + }; + + auto c1 = (aPts[0] + aPts[3]) / 2.0f; + auto c2 = (aPts[1] + aPts[2]) / 2.0f; + + auto aDir = c2 - c1; + auto aLen = length(aDir); + + if (aLen > 0) { + aDir /= aLen; + } else { + aDir = Pointf{ 1, 0 }; + } + + auto centerA = (c1 + c2) / 2.0f; + + Pointf bPts[4] = { + rrectB[0], rrectB[1], rrectB[2], rrectB[3] + }; + + auto centerB = (bPts[0] + bPts[1] + bPts[2] + bPts[3]) / 4.0f; + + auto connDir = centerB - centerA; + auto connLen = length(connDir); + + if (connLen == 0.0f) { + return 1.0f; + } + + connDir /= connLen; + + auto cosT = dot(aDir, connDir); + + return cosT; +} diff --git a/nemo-retriever-ocr/cpp/geometry_api/matrix2x2.h b/nemo-retriever-ocr/cpp/geometry_api/matrix2x2.h new file mode 100644 index 0000000000000000000000000000000000000000..70a1279fd26688d77b8e11c54265dc89dd568feb --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry_api/matrix2x2.h @@ -0,0 +1,93 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
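get_rel_continuation_cos above measures how well quad B continues quad A's reading direction: it returns the cosine between A's left-to-right axis (the midpoint of edge {v0, v3} to the midpoint of edge {v1, v2}) and the direction from A's center to B's center. An illustrative reading of the return value:

// ~ 1  : B lies where a continuation of A's text line would be expected (straight ahead of A).
// ~ 0  : B sits beside, above, or below A (perpendicular to A's reading direction).
// ~ -1 : B lies behind A.
// Degenerate inputs are handled explicitly: a zero-length A falls back to direction (1, 0),
// and coincident centers return 1.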
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../geometry.h" + + +struct contiguous_tag{}; + +struct transpose_tag{}; + +template +struct Matrix2x2_Offset; + +template +struct Matrix2x2_Offset +{ + static const uint32_t OFFSET = R * 2 + C; +}; + +template +struct Matrix2x2_Offset +{ + static const uint32_t OFFSET = C * 2 + R; +}; + + +template +struct Matrix2x2_Indexor +{ + static const uint32_t OFFSET = Matrix2x2_Offset::OFFSET; + + static T &get(T *data) { return data[OFFSET]; } + static const T get(const T *data) { return data[OFFSET]; } +}; + + +template +struct Matrix2x2 +{ + Matrix2x2() = default; + Matrix2x2(T r0c0, T r0c1, T r1c0, T r1c1) + : m_data{ r0c0, r0c1, r1c0, r1c1 } + { + } + Matrix2x2(const Point_ &r0, const Point_ &r1) + : m_data{ r0.X, r0.Y, r1.X, r1.Y } + { + } + Matrix2x2(const Point_ &r0, const Point_ &r1, transpose_tag) + : m_data{ r0.X, r1.X, r0.Y, r1.Y } + { + } + + inline T &operator[](uint32_t i) { return m_data[i]; } + inline const T operator[](uint32_t i) const { return m_data[i]; } + + T m_data[4]; +}; + +template +struct Matrix2x2_View +{ + Matrix2x2_View(const Matrix2x2 &m) : m_data(m.m_data) {} + + const T *m_data; +}; + +template +const T get(const Matrix2x2_View &m) +{ + return Matrix2x2_Indexor::get(m.m_data); +} + +template +inline +void matmul_fn(int64_t N, const get_pt_t &get_fn, const Matrix2x2 &mat, const callback_t &callback, + layout_t lt = layout_t{}) +{ + Matrix2x2_View m{ mat }; + + #pragma omp simd + for (int64_t i = 0; i < N; ++i) { + Point_ pt = get_fn(i); + + T x = pt.X * get<0, 0>(m) + pt.Y * get<1, 0>(m); + T y = pt.X * get<0, 1>(m) + pt.Y * get<1, 1>(m); + + callback(i, Point_{ x, y }); + } +} diff --git a/nemo-retriever-ocr/cpp/geometry_api/poly_bounds_quad.cpp b/nemo-retriever-ocr/cpp/geometry_api/poly_bounds_quad.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9cc8f5edb02417fab6b164998affe860655b14d --- /dev/null +++ b/nemo-retriever-ocr/cpp/geometry_api/poly_bounds_quad.cpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#include "geometry_api.h" + +using namespace std; + + +template<typename T> +void pt_assign(torch::TensorAccessor<T, 1> acc, T x, T y) +{ + acc[0] = x; + acc[1] = y; +} + + +template<typename T> +void poly_bounds_quad_impl(torch::TensorAccessor<T, 2> poly, torch::TensorAccessor<T, 2> outBounds) +{ + T minX = poly[0][0], + minY = poly[0][1], + maxX = poly[0][0], + maxY = poly[0][1]; + + const int64_t numVertices = poly.size(0); + + for (int64_t i = 0; i < numVertices; ++i) { + auto vert = poly[i]; + + minX = min(minX, vert[0]); + maxX = max(maxX, vert[0]); + + minY = min(minY, vert[1]); + maxY = max(maxY, vert[1]); + } + + pt_assign(outBounds[0], minX, minY); + pt_assign(outBounds[1], maxX, minY); + pt_assign(outBounds[2], maxX, maxY); + pt_assign(outBounds[3], minX, maxY); +} + + +torch::Tensor get_poly_bounds_quad(torch::Tensor poly) +{ + auto ret = torch::empty({ 4, 2 }, poly.options()); + + AT_DISPATCH_FLOATING_TYPES( + poly.scalar_type(), + "poly_bounds_quad_impl", + ([&] { + poly_bounds_quad_impl( + poly.accessor<scalar_t, 2>(), + ret.accessor<scalar_t, 2>() + ); + }) + ); + + return ret; +} diff --git a/nemo-retriever-ocr/cpp/graph_detection/encode_util.cpp b/nemo-retriever-ocr/cpp/graph_detection/encode_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad94f4c7db561611dbd2ad1f430203a34026704a --- /dev/null +++ b/nemo-retriever-ocr/cpp/graph_detection/encode_util.cpp @@ -0,0 +1,272 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "encode_util.h" + +#include <algorithm> +#include <cmath> +#include <numeric> +#include <sstream> + +#include "../third_party/clipper/clipper.hpp" + +using namespace std; + +namespace graph_detection { + +template<typename T> +struct Candidate : Edge { + T C; + + Candidate() = default; + Candidate(int32_t a, int32_t b, T c) : Edge(a, b), C(c) {} +}; + +struct DistStruct { + Candidate<Pointf> A; + Candidate<Pointf> B; + float Dist; + + DistStruct() = default; + DistStruct(Candidate<Pointf> a, Candidate<Pointf> b, float dist) : A(a), B(b), Dist(dist) {} +}; + +template<typename T> +float vec_cos(const Point_<T> &a, const Point_<T> &b) +{ + return dot(a, b) / (length(a) * length(b) + 1e-8); +} + +template<typename T, typename Fn = less<T>> +vector<size_t> arg_sort(const vector<T> &vec, Fn comp = Fn()) +{ + vector<size_t> ret; + ret.reserve(vec.size()); + for (size_t i = 0; i < vec.size(); ++i) { + ret.push_back(i); + } + + sort(begin(ret), end(ret), + [&vec, &comp] (size_t idxA, size_t idxB) { + return comp(vec[idxA], vec[idxB]); + } + ); + + return ret; +} + + +float edge_length(const Polygon_<float> &poly, const vector<Edge> &edges); + +vector<Edge> find_bottom(const Polygon_<float> &poly, bool useVertexOrder) +{ + if (poly.Count < 4) { + throw runtime_error("Invalid polygon.
Fewer than 4 vertices!"); + } + + // If we trust the source of the geometries, then this saves us both computation, + // but can also be more reliable since we won't reorder the vertices + if (useVertexOrder) { + if ((poly.Count % 2) == 1) { + throw runtime_error("Can't use trusted vertex order when the vertex count is odd!"); + } + int32_t halfCt = poly.Count / 2; + return { { halfCt - 1, halfCt }, + { static_cast(poly.Count) - 1, 0 } }; + } + + if (poly.Count == 4) { + float d1 = length(poly[1] - poly[0]) + length(poly[2] - poly[3]); + float d2 = length(poly[2] - poly[1]) + length(poly[0] - poly[3]); + + if (4 * d1 < d2) { + return { { 0, 1 }, { 2, 3 } }; + } else { + return { { 1, 2 }, { 3, 0 } }; + } + } + + auto idx_wrap = [&poly] (size_t idx) { + return poly[idx % poly.Count]; + }; + + vector> candidates; + for (size_t i = 1; i < (poly.Count + 1); ++i) { + auto vPrev = idx_wrap(i) - idx_wrap(i - 1); + auto vNext = idx_wrap(i + 2) - idx_wrap(i + 1); + + // We're looking for the segment where the preceding and following segment + // essentially travel in opposite directions + if (vec_cos(vPrev, vNext) < -0.875f) { + auto currSeg = idx_wrap(i) - idx_wrap(i + 1); + candidates.emplace_back(i % poly.Count, (i + 1) % poly.Count, length(currSeg)); + } + } + + if (candidates.size() != 2 || candidates[0].A == candidates[1].B || candidates[0].B == candidates[1].A) { + // If candidate number < 2, or two bottom are joined, select 2 farthest edge + vector> midList; + for (size_t i = 0; i < poly.Count; ++i) { + Pointf midPoint = (idx_wrap(i) + idx_wrap(i + 1)) / 2.0f; + midList.emplace_back(i, (i + 1) % poly.Count, midPoint); + } + + vector distList; + + // Only found one good candidate, so search for the edge that's the furthest from this candidate + if (candidates.size() == 1) { + auto idx1a = candidates.back().A; + auto idx1b = candidates.back().B; + Candidate cand1{ idx1a, idx1b, (idx_wrap(idx1a) + idx_wrap(idx1b)) / 2.0f }; + for (size_t j = 0; j < poly.Count; ++j) { + auto &cand2 = midList[j]; + + if (cand1.Touches(cand2)) continue; + + float dist = length(cand1.C - cand2.C); + distList.emplace_back(cand1, cand2, dist); + } + } else { + for (size_t i = 0; i < poly.Count; ++i) { + for (size_t j = i + 1; j < poly.Count; ++j) { + auto &cand1 = midList[i]; + auto &cand2 = midList[j]; + + if (cand1.Touches(cand2)) continue; + + float dist = length(cand1.C - cand2.C); + distList.emplace_back(cand1, cand2, dist); + } + } + } + sort(begin(distList), end(distList), [] (auto a, auto b) { return a.Dist < b.Dist; }); + + if (distList.empty()) { + throw runtime_error("No valid bottom candidates found for this polygon!"); + } + + auto &bEdge = distList.back(); + return vector{ bEdge.A, bEdge.B }; + + } else { + return vector{ candidates[0], candidates[1] }; + } +} + +void find_long_edges(const Polygon_ &poly, Edge *bottoms, vector &outLongEdge1, vector &outLongEdge2) +{ + int32_t b1End = bottoms[0].B; + int32_t b2End = bottoms[1].B; + + int32_t nPoints = poly.Count; + + auto accum_into = [nPoints] (int32_t end1, int32_t end2, vector &outEdge) { + int32_t i = (end1 + 1) % nPoints; + while ((i % nPoints) != end2) { + int32_t start = i > 0 ? 
i - 1 : nPoints - 1; + int32_t end = i % nPoints; + outEdge.emplace_back(start, end); + i = (i + 1) % nPoints; + } + }; + + accum_into(b1End, b2End, outLongEdge1); + accum_into(b2End, b1End, outLongEdge2); +} + +float edge_length(const Polygon_ &poly, const vector &edges) +{ + float ret = 0.0f; + for (const Edge &e : edges) { + ret += length(poly[e.B] - poly[e.A]); + } + return ret; +} + +vector edge_lengths(const Polygon_ &poly, const vector &edges) +{ + if (edges.empty()) { + throw runtime_error("Found an empty edge!"); + } + + vector ret; + ret.reserve(edges.size()); + + for (const Edge &e : edges) { + ret.push_back(length(poly[e.B] - poly[e.A])); + } + + return ret; +} + +void split_edge_sequence(const Polygon_ &poly, const vector &edges, + const vector &edgeLengths, float nParts, + vector &outPts); + +void split_edge_sequence_by_step(const Polygon_ &poly, const vector &longEdge1, const vector &longEdge2, + float step, vector &outInnerPoints1, vector &outInnerPoints2) +{ + auto edgeLengths1 = edge_lengths(poly, longEdge1); + auto edgeLengths2 = edge_lengths(poly, longEdge2); + + float totalLength = (accumulate(begin(edgeLengths1), end(edgeLengths1), 0.0f) + accumulate(begin(edgeLengths2), end(edgeLengths2), 0.0f)) / 2; + + float nParts = max(ceil(totalLength / step), 2); + + split_edge_sequence(poly, longEdge1, edgeLengths1, nParts, outInnerPoints1); + split_edge_sequence(poly, longEdge2, edgeLengths2, nParts, outInnerPoints2); +} + +void split_edge_sequence(const Polygon_ &poly, const vector &edges, + const vector &edgeLengths, float nParts, + vector &outPts) +{ + vector elCumSum = vec_cumsum(edgeLengths); + + float totalLength = elCumSum.back(); + float lengthPerPart = totalLength / (nParts - 1); + + size_t iNumParts = nParts; + size_t currNode = 0; + size_t ctr = 0; + for (float i = 0.0f; ctr < iNumParts; i += 1.0f, ++ctr) { + float t = min(i * lengthPerPart, totalLength); + + while (t > elCumSum[currNode + 1]) { + ++currNode; + } + + Edge currEdge = edges[currNode]; + Pointf e1 = poly[currEdge.A]; + Pointf e2 = poly[currEdge.B]; + + float currLen = edgeLengths[currNode]; + + Pointf sampledPt; + + if (currLen > 0) { + float deltaT = t - elCumSum[currNode]; + float ratio = deltaT / currLen; + sampledPt = e1 + ratio * (e2 - e1); + } else { + sampledPt = e1; + } + + outPts.push_back(sampledPt); + } +} + +string print_poly(const Polyf &poly) { + ostringstream oss; + oss << "["; + for (size_t i = 0; i < poly.Count; ++i) { + if (i > 0) { + oss << ", "; + } + oss << "(" << poly[i].X << ", " << poly[i].Y << ")"; + } + oss << "]"; + return oss.str(); +} + +} // namespace graph_detection diff --git a/nemo-retriever-ocr/cpp/graph_detection/encode_util.h b/nemo-retriever-ocr/cpp/graph_detection/encode_util.h new file mode 100644 index 0000000000000000000000000000000000000000..640992085e710e09a7c8f455e51abeff39800257 --- /dev/null +++ b/nemo-retriever-ocr/cpp/graph_detection/encode_util.h @@ -0,0 +1,184 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include <algorithm> +#include <numeric> +#include <random> +#include <stdexcept> +#include <string> +#include <vector> + +#include "../geometry.h" + +namespace graph_detection { + + + +struct Edge { + int32_t A; + int32_t B; + + Edge() = default; + Edge(int32_t a, int32_t b) : A(a), B(b) {} + + bool Touches(int32_t idx) const { return A == idx || B == idx; } + bool Touches(const Edge &other) const; +}; + +inline +bool edge_touches(const Edge &edge, int32_t vertex) { + return edge.A == vertex || edge.B == vertex; +} + +inline +bool Edge::Touches(const Edge &other) const { + return edge_touches(other, A) || edge_touches(other, B); +} + +typedef Point_<float> Pointf; +typedef AABB_<float> AABBf; +typedef Polygon_<float> Polyf; +typedef std::vector<Pointf> Polyline; + +std::vector<Edge> find_bottom(const Polygon_<float> &poly, bool useVertexOrder); + +void find_long_edges(const Polygon_<float> &poly, Edge *bottoms, std::vector<Edge> &outLongEdge1, std::vector<Edge> &outLongEdge2); + +void split_edge_sequence_by_step(const Polygon_<float> &poly, const std::vector<Edge> &longEdge1, const std::vector<Edge> &longEdge2, + float step, std::vector<Pointf> &outInnerPoints1, std::vector<Pointf> &outInnerPoints2); + +std::string print_poly(const Polyf &poly); + +template<typename T> +inline +std::vector<T> vec_cumsum(const std::vector<T> &v) +{ + std::vector<T> ret; + ret.reserve(v.size() + 1); + ret.push_back(0); + for (T val : v) { + ret.push_back(ret.back() + val); + } + return ret; +} + +template<typename RandEng, typename Fn> +inline +void n_choose_k(size_t n, size_t k, RandEng &randEng, Fn fn) +{ + if (k == 0) return; + + // TODO(mranzinger): This algorithm can be replaced with sampling from a geometric + // distribution, which drastically reduces the runtime complexity + for (size_t i = 0; i < n; ++i) { + size_t leftover = n - i; + if (leftover <= k) { + fn(i); + --k; + } else { + float p = std::uniform_real_distribution<float>(0.0f, 1.0f)(randEng); + float probSample = float(k) / float(leftover); + if (p < probSample) { + fn(i); + --k; + } + } + } +} + +template<typename T> +inline T clamp(T val, T minVal, T maxVal) { + return std::max(std::min(val, maxVal), minVal); +} + +inline +Pointf avg_point(const std::vector<Pointf> &points) +{ + return std::accumulate(std::begin(points), std::end(points), Pointf(0,0)) / float(points.size()); +} + +inline +float vector_sin(const Pointf &pt) +{ + // sin = y / len(pt) + return pt.Y / (length(pt) + 1e-8); +} + +inline +float vector_cos(const Pointf &pt) +{ + // cos = x / len(pt) + return pt.X / (length(pt) + 1e-8); +} + +inline +void vector_cos_sin(const Pointf & pt, float &outCos, float &outSin) +{ + float len = length(pt) + 1e-8; + outCos = pt.X / len; + outSin = pt.Y / len; +} + +inline +float point_dist_to_line(const Pointf &l1, const Pointf &l2, const Pointf &pt) +{ + auto d = l2 - l1; + + auto lineLen = length(d); + + if (lineLen > 0) { + float distance = abs( + d.Y * pt.X + - d.X * pt.Y + + l2.X * l1.Y + - l2.Y * l1.X + ) / lineLen; + return distance; + } else { + return length(pt - l1); + } +} + +template<typename T> +T find_mode(std::vector<T> &inputs) { + using std::sort; + using std::begin; + using std::end; + + if (inputs.empty()) { + throw std::runtime_error("Cannot find mode of empty distribution!"); + } + + sort(begin(inputs), end(inputs)); + + T currVal = inputs[0]; + size_t currCount = 1; + + T modeVal = inputs[0]; + size_t modeCount = 1; + + auto commitCurr = [&] () { + if (currCount > modeCount) { + modeCount = currCount; + modeVal = currVal; + } + }; + + for (size_t i = 1; i < inputs.size(); ++i) { + if (inputs[i] == currVal) { + ++currCount; + } else { + // Start of a new value + commitCurr(); + + currCount = 1; + currVal = inputs[i]; + } + } + +
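+ // The loop above only commits a run when a new value starts, so the final run still needs to be committed here.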
commitCurr(); + + return modeVal; +} + +} // namespace graph_detection diff --git a/nemo-retriever-ocr/cpp/half_ops.cu b/nemo-retriever-ocr/cpp/half_ops.cu new file mode 100644 index 0000000000000000000000000000000000000000..48082f2d2353a2e3f5d92248750aa28caa75fb42 --- /dev/null +++ b/nemo-retriever-ocr/cpp/half_ops.cu @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "half_ops.cuh" diff --git a/nemo-retriever-ocr/cpp/half_ops.cuh b/nemo-retriever-ocr/cpp/half_ops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ea04ad6cf1c1b00b83a671f0cfe0aeaecbab7835 --- /dev/null +++ b/nemo-retriever-ocr/cpp/half_ops.cuh @@ -0,0 +1,149 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "cuda_intellisense.cuh" + +#ifndef __CUDACC__ +#pragma message("__CUDACC__ not defined!") +#else +#pragma message("__CUDACC__ defined!") +#endif + +#ifdef __NVCC__ +#define __qr_device__ __device__ +#define __qr_host__ __host__ +#define __qr_inline__ __forceinline__ +#else +#define __qr_device__ +#define __qr_host__ +#define __qr_inline__ inline +#endif + +#ifdef __CUDACC__ +#include +#include +#include + + +__qr_inline__ __device__ __half operator-(__half v) { + return __hneg(v); +} + +__qr_inline__ __device__ __half operator+(__half a, __half b) { + return __hadd(a, b); +} + +__qr_inline__ __device__ __half operator-(__half a, __half b) { + return __hsub(a, b); +} + +__qr_inline__ __device__ __half operator*(__half a, __half b) { + return __hmul(a, b); +} + +__qr_inline__ __device__ __half operator/(__half a, __half b) { + return __hdiv(a, b); +} + +__qr_inline__ __device__ bool operator==(__half a, __half b) { + return __heq(a, b); +} + +__qr_inline__ __device__ bool operator<(__half a, __half b) { + return __hlt(a, b); +} + +__qr_inline__ __device__ bool operator>(__half a, __half b) { + return __hgt(a, b); +} + +__qr_inline__ __device__ __half sqrt(__half v) { + return hsqrt(v); +} + +__qr_inline__ __device__ __half floor(__half v) { + return hfloor(v); +} + +__qr_inline__ __device__ __half ceil(__half v) { + return hceil(v); +} + +__qr_inline__ __device__ __half max(__half a, __half b) { + return a > b ? 
a : b; +} +#endif //__CUDACC__ + +template +struct Convert { + __qr_inline__ static __qr_host__ __qr_device__ constexpr Dest From(Src value) { return static_cast(value); } + __qr_inline__ static __qr_host__ __qr_device__ constexpr Src To(Dest value) { return static_cast(value); } + __qr_inline__ static __qr_host__ __qr_device__ constexpr Dest LeftToRight(Src value) { return static_cast(value); } + __qr_inline__ static __qr_host__ __qr_device__ constexpr Src RightToLeft(Dest value) { return static_cast(value); } +}; + +#ifdef __CUDACC__ +template<> +struct Convert<__half, float> { + __qr_inline__ static __host__ __device__ float From(__half value) { return __half2float(value); } + __qr_inline__ static __host__ __device__ __half To(float value) { return __float2half(value); } + __qr_inline__ static __host__ __device__ float LeftToRight(__half value) { return __half2float(value); } + __qr_inline__ static __host__ __device__ __half RightToLeft(float value) { return __float2half(value); } +}; + +template +struct Convert<__half, Dest> : Convert<__half, float> { + +}; + +namespace at { + +template<> +inline __half* TensorBase::mutable_data_ptr() const { + TORCH_CHECK(scalar_type() == ScalarType::Half, + "expected scalar type Half but found ", + c10::toString(scalar_type())); + return static_cast<__half*>(this->unsafeGetTensorImpl()->mutable_data()); +} + +template<> +inline __half* TensorBase::data_ptr() const { + TORCH_CHECK(scalar_type() == ScalarType::Half, + "expected scalar type Half but found ", + c10::toString(scalar_type())); + return static_cast<__half*>(this->unsafeGetTensorImpl()->mutable_data()); +} + +} + +template +struct remap_half { + typedef T type; +}; + +template<> +struct remap_half { + typedef __half type; +}; + +template +__half to_half(T val) { + return Convert<__half, T>::RightToLeft(val); +} + +template +struct fp_promote { + typedef T type; +}; + +template<> +struct fp_promote<__half> { + typedef float type; +}; + +#endif //__CUDACC__ diff --git a/nemo-retriever-ocr/cpp/local_ips/local_ips.h b/nemo-retriever-ocr/cpp/local_ips/local_ips.h new file mode 100644 index 0000000000000000000000000000000000000000..cd2a836be1fc89c7d41947a592e3a63cee5d0795 --- /dev/null +++ b/nemo-retriever-ocr/cpp/local_ips/local_ips.h @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +torch::Tensor ragged_quad_all_2_all_distance_v2(torch::Tensor embedQuads, torch::Tensor quadsPerExample, + float xFactor, float yFactor, + bool allowSelfDistance); diff --git a/nemo-retriever-ocr/cpp/local_ips/quad_all_2_all_dist_v2.cu b/nemo-retriever-ocr/cpp/local_ips/quad_all_2_all_dist_v2.cu new file mode 100644 index 0000000000000000000000000000000000000000..1e2a17a5eae2989c3b7a17b852c1dec8fdd9a45d --- /dev/null +++ b/nemo-retriever-ocr/cpp/local_ips/quad_all_2_all_dist_v2.cu @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + + +#include + +#include +#include + +#include +#include + +#include "local_ips.h" +#include "../cuda_intellisense.cuh" +#include "../common.h" +#include "../geometry.h" + +using namespace std; +namespace cg = cooperative_groups; + +typedef Point_ Pointf; + +__device__ inline +float square(float val) { return val * val; } + +__global__ +void device_quad_all_2_all_distance_v2(torch::PackedTensorAccessor64 allEmbedQuads, + torch::PackedTensorAccessor64 allRegionCounts, + torch::PackedTensorAccessor64 csWorkPerExample, + torch::PackedTensorAccessor64 outDistances, + float xFactor, float yFactor, + bool allowSelfDistance) +{ + // Note that the blockIdx.x is on purpose here + int64_t workIdx = blockIdx.x * blockDim.y + threadIdx.y; + + if (workIdx >= csWorkPerExample[csWorkPerExample.size(0) - 1]) return; + + auto exIter = thrust::upper_bound(thrust::seq, + csWorkPerExample.data(), csWorkPerExample.data() + csWorkPerExample.size(0), + workIdx); + + const int64_t exIdx = exIter - csWorkPerExample.data(); + + const int64_t workStart = exIdx == 0 ? 0 : csWorkPerExample[exIdx - 1]; + const int64_t workOff = workIdx - workStart; + + const int64_t row = workOff / allRegionCounts[exIdx]; + const int64_t col = workOff % allRegionCounts[exIdx]; + + auto taRowQuad = allEmbedQuads[exIdx][row]; + auto taColQuad = allEmbedQuads[exIdx][col]; + + Quad_ rowQuad(taRowQuad.data()), + colQuad(taColQuad.data()); + + auto p1 = (rowQuad[0] + rowQuad[3]) / 2.0f; + auto p2 = (rowQuad[1] + rowQuad[2]) / 2.0f; + + auto vX = p2 - p1; + auto lenVX = length(vX); + if (lenVX > 0) { + vX = vX / max(lenVX, 1e-8f); + } else { + vX = { 1, 0 }; + } + + Pointf vY{ -vX.Y, vX.X }; + + auto reproj = [&vX, &vY, xFactor, yFactor] (const Pointf &pt) { + auto dX = dot(pt, vX); + if (dX >= 0) { + dX *= xFactor; + } + auto dY = dot(pt, vY); + if (dY >= 0) { + dY *= yFactor; + } + + return Pointf{ dX, dY }; + }; + + auto tile16 = cg::tiled_partition<16>(cg::this_thread_block()); + + // Figure out which vertices this thread is processing + const int64_t rowVertexIdx = tile16.thread_rank() / 4; + const int64_t colVertexIdx = tile16.thread_rank() % 4; + + float dist; + if (row != col) { + Segment_ rowSeg{ rowQuad[rowVertexIdx], rowQuad[(rowVertexIdx + 1) % 4] }; + Segment_ colSeg{ colQuad[colVertexIdx], colQuad[(colVertexIdx + 1) % 4] }; + + Segment_ minSeg = shortest_line_between_segments(rowSeg, colSeg); + + Point_ vSeg = minSeg.B - minSeg.A; + + vSeg = reproj(vSeg); + + dist = length(vSeg); + } else if (allowSelfDistance) { + dist = 0; + } else { + dist = numeric_limits::infinity(); + } + + // Now find the minimum distance across the group + int lane = tile16.thread_rank(); + // Each iteration halves the number of active threads + // Each thread gets the partial min[i] to min[lane+i] + #pragma unroll + for (uint32_t i = 1; i < 16; i <<= 1) { + auto otherDist = tile16.shfl_down(dist, i); + dist = min(dist, otherDist); + } + +#ifndef NDEBUG + float lowestDist = tile16.shfl(dist, 0); + assert(dist >= lowestDist); +#endif + + if (lane == 0) { + outDistances[exIdx][row][col] = dist; + } +} + +torch::Tensor ragged_quad_all_2_all_distance_v2(torch::Tensor embedQuads, torch::Tensor regionCounts, + float xFactor, float yFactor, + bool allowSelfDistance) +{ + if (!embedQuads.is_contiguous()) { + throw std::runtime_error("Expected `embedQuads` to be contiguous!"); + } + + auto outDistances = torch::zeros({ embedQuads.size(0), embedQuads.size(1), embedQuads.size(1) }, + embedQuads.options()); + + if 
(embedQuads.numel() == 0) { + return outDistances; + } + + auto workPerExample = regionCounts * regionCounts; + + auto csWorkPerExample = torch::cumsum(workPerExample, 0); + + int64_t totalWork = csWorkPerExample[-1].item(); + + dim3 blockSize(16, 2); + dim3 gridSize(div_up(totalWork, blockSize.y), 1); + + device_quad_all_2_all_distance_v2 KERNEL_ARG2(gridSize, blockSize) ( + embedQuads.packed_accessor64(), + regionCounts.packed_accessor64(), + csWorkPerExample.packed_accessor64(), + outDistances.packed_accessor64(), + xFactor, yFactor, + allowSelfDistance + ); + + return outDistances; +} diff --git a/nemo-retriever-ocr/cpp/module.cpp b/nemo-retriever-ocr/cpp/module.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eb15853a218599eaf78d96ccbab7f0ed6369f95b --- /dev/null +++ b/nemo-retriever-ocr/cpp/module.cpp @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + + +#include "quad_rectify/quad_rectify.h" +#include "non_maximal_suppression/non_maximal_suppression.h" +#include "geometry_api/geometry_api.h" +#include "beam_decode/beam_decode.h" +#include "better_grid_sample/grid_sample.h" +#include "sparse_select/sparse_select.h" +#include "text_region_grouping/text_region_grouping.h" +#include "local_ips/local_ips.h" + +#include +#include + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("quad_rectify_calc_quad_width", &quad_rectify_calc_quad_width, + "Quad Rectify Calc Quad Width C++", + py::arg("quads"), + py::arg("output_height"), + py::arg("round_factor") = 16, + py::arg("max_width") = 0 + ); + m.def("quad_rectify_forward", &quad_rectify_forward, "Quad Rectify Forward C++", + py::arg("quads"), + py::arg("image_height"), py::arg("image_width"), + py::arg("output_height"), py::arg("output_width"), + py::arg("isotropic") = true + ); + m.def("quad_rectify_backward", &quad_rectify_backward, "Quad Rectify Backward C++", + py::arg("quads"), py::arg("grad_output"), + py::arg("image_height"), py::arg("image_width"), + py::arg("isotropic") = true + ); + m.def("quad_non_maximal_suppression", &quad_non_maximal_suppression, "Quad Non-Maximal Suppression C++", + py::arg("quads"), py::arg("probs"), + py::arg("prob_threshold"), py::arg("iou_threshold"), + py::arg("kernel_height"), py::arg("kernel_width"), + py::arg("max_regions"), + py::arg("verbose") = false + ); + + py::class_(m, "LanguageModel"); + + m.def("beam_decode", &beam_decode, "beam_decode c++", + py::arg("probs"), + py::arg("beam_size") = 100, + py::arg("blank") = 0, + py::arg("min_prob") = 0.001, + py::arg("lang_model") = static_cast(nullptr), + py::arg("lm_weight") = 1, + py::arg("combine_duplicates") = true + ); + + py::class_(m, "TokenMapping"); + + m.def("create_token_mapping", &create_token_mapping, "create token mapping c++", + py::arg("token_mapping") + ); + + m.def("decode_sequences", &decode_sequences, "decode_sequences c++", + py::arg("tokens"), py::arg("language_model"), + py::arg("probs") = nullptr + ); + + m.def("create_sbo_lm", &create_sbo_lm, "create_sbo_lm c++", + py::arg("data_file_path"), + py::arg("token_mapping"), + py::arg("backoff") = 0.4 + ); + + m.def("indirect_grid_sample_forward", &indirect_grid_sample_forward, "indirect_grid_sample::forward c++", + py::arg("input"), py::arg("grid"), py::arg("input_indices"), py::arg("method") + ); + m.def("indirect_grad_sample_backward", &indirect_grad_sample_backward, "indirect_grid_sample::backward c++", + py::arg("grad_output"), 
py::arg("input"), py::arg("grid"), py::arg("input_indices"), py::arg("method") + ); + m.def("region_counts_to_indices", ®ion_counts_to_indices, "region counts to indices", + py::arg("region_counts"), py::arg("num_outputs") + ); + + m.def("rrect_to_quads", &rrect_to_quads, "convert rotated rectangle to quadrangles", + py::arg("rrects"), py::arg("cell_size") + ); + m.def("rrect_to_quads_backward", &rrect_to_quads_backward, "gradient of rrect_to_quads", + py::arg("rrects"), py::arg("grad_output") + ); + + m.def("sparse_select", &sparse_select, "Select sparse tensor(s) given a set of indices", + py::arg("sparse_counts"), py::arg("sparse_tensors"), py::arg("select_indices") + ); + + m.def("text_region_grouping", &text_region_grouping, "Clusters all of the text into lines and phrases", + py::arg("quads"), py::arg("counts"), + py::arg("horizontal_tolerance") = 2.0f, + py::arg("vertical_tolerance") = 0.5f, + py::arg("verbose") = false + ); + + m.def("dense_relations_to_graph", &dense_relations_to_graph, "Converts a dense relational tensor to a graph", + py::arg("relations") + ); + + m.def("ragged_quad_all_2_all_distance_v2", &ragged_quad_all_2_all_distance_v2, "get the all-to-all distances in ragged-batch quad mode", + py::arg("embed_quads"), py::arg("region_counts"), + py::arg("x_factor") = 1.0f, + py::arg("y_factor") = 1.0f, + py::arg("allow_self_distance") = true + ); + + m.def("calc_poly_min_rrect", &calc_poly_min_rrect, "calculate a reasonable bounding rectangle for a given text polygon", + py::arg("vertices") + ); + + m.def("get_rel_continuation_cos", &get_rel_continuation_cos, "c++ get relation cosine between 2 regions", + py::arg("rrect_a"), py::arg("rrect_b") + ); + + m.def("get_poly_bounds_quad", &get_poly_bounds_quad, "c++ get polygon bounds", + py::arg("poly") + ); +} diff --git a/nemo-retriever-ocr/cpp/non_maximal_suppression/cpu_non_maximal_suppression.cpp b/nemo-retriever-ocr/cpp/non_maximal_suppression/cpu_non_maximal_suppression.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffd6d725155993514e0bccf1adf44205e33e602d --- /dev/null +++ b/nemo-retriever-ocr/cpp/non_maximal_suppression/cpu_non_maximal_suppression.cpp @@ -0,0 +1,209 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#include "non_maximal_suppression.h" + +#include +#include "../geometry.h" + +using namespace std; + + +template +void visit_node( + const torch::TensorAccessor &quads, + const torch::TensorAccessor &probs, + const torch::TensorAccessor &adjacency, + MergeQuad_ &mQuad, + unordered_set &visited, + int64_t r, int64_t c, int32_t vIdx) +{ + if (visited.count(vIdx)) { + return; + } + visited.insert(vIdx); + + int32_t *pAdj = adjacency[r][c].data(); + + int32_t adjCt = pAdj[0]; + assert(adjCt > 0); + + mQuad.Append(Quad_(quads[r][c].data()), probs[r][c]); + + int32_t *pOff = pAdj + 2; + int32_t *pEnd = pAdj + adjCt + 1; + + const int32_t W = quads.size(1); + + for (; pOff != pEnd; ++pOff) { + int32_t vIdx2 = *pOff; + int32_t r2 = vIdx2 / W; + int32_t c2 = vIdx2 % W; + + visit_node(quads, probs, adjacency, mQuad, visited, r2, c2, vIdx2); + } +} + +template +std::vector quad_nms_from_adjacency_impl( + const torch::TensorAccessor &quads, + const torch::TensorAccessor &probs, + const torch::TensorAccessor &adjacency, + scalar_t probThreshold, scalar_t iouThreshold, + int64_t maxRegions) +{ + const uint64_t B = quads.size((int)0); + const int64_t H = quads.size((int)1); + const int64_t W = quads.size((int)2); + + typedef MergeQuad_ MQuad; + typedef EmbedQuad_ EFQuad; + + vector> batchQuads{ static_cast< const unsigned int >( B ) }; + vector> allQuads{ static_cast< const unsigned int >( B ) }; + vector>> batchAdjIdxs{ static_cast< const unsigned int >( B ) }; + + #pragma omp parallel num_threads (8) + { + #pragma omp for + for (int64_t b = 0; b < B; ++b) { + unordered_set visited; + + for (int64_t r = 0; r < H; ++r) { + for (int64_t c = 0; c < W; ++c) { + auto currProb = probs[b][r][c]; + + if (currProb < probThreshold) { + continue; + } + + int32_t vIdx = r * W + c; + + // Ensure that this quad hasn't already been merged + if (visited.count(vIdx)) { + continue; + } + + MQuad mQuad{ZeroInitTag{}}; + visit_node(quads[b], probs[b], adjacency[b], mQuad, visited, r, c, vIdx); + + batchQuads[b].push_back(mQuad.Commit()); + } + } + } + + #pragma omp single + { + for (size_t b = 0; b < B; ++b) { + size_t numQuads = batchQuads[b].size(); + batchAdjIdxs[b].resize(numQuads); + for (int64_t n = 0; n < numQuads; ++n) { + #pragma omp task default(none) shared(batchAdjIdxs, batchQuads, iouThreshold) firstprivate(b, numQuads, n) + { + for (int64_t m = n + 1; m < numQuads; ++m) { + vector &adjIdxs = batchAdjIdxs[b][n]; + vector &quads = batchQuads[b]; + auto iou = quads[n].IOU(quads[m]); + + if (iou > iouThreshold) { + adjIdxs.push_back(m); + } + } + } + } + } + + #pragma omp taskwait + } + + #pragma omp for + for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) { + vector> &adjIdxs = batchAdjIdxs[batchIdx]; + vector &quads = batchQuads[batchIdx]; + vector &finalQuads = allQuads[batchIdx]; + + // Step 3: Using depth first search, merge the regions + unordered_set visited; + for (int64_t n = 0; n < quads.size(); ++n) { + EFQuad currQuad; + visit_node(quads, n, adjIdxs, currQuad, visited); + + if (currQuad.NumQuads > 0) { + currQuad.Prepare(); + + finalQuads.push_back(currQuad); + } + } + + // Only sort the part that we want to keep + partial_sort(begin(finalQuads), + begin(finalQuads) + std::min(finalQuads.size(), maxRegions), + end(finalQuads), + [] (auto a, auto b) { + return a.Confidence > b.Confidence; + } + ); + + // Truncate the low confidence regions + if (finalQuads.size() > maxRegions) { + finalQuads.resize(maxRegions); + } + + //cout << "Ex " << batchIdx << " 
quads:" << endl << finalQuads << endl << endl; + } + + } // End parallel + + int64_t numOutQuads = 0; + for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) { + numOutQuads += allQuads[batchIdx].size(); + } + + // Step 4: Convert the quads into tensor representation + auto outQuadTensor = torch::empty({ numOutQuads, 4, 2 }, torch::kFloat32); + auto outConfTensor = torch::empty({ numOutQuads }, torch::kFloat32); + torch::Tensor outCountTensor = torch::empty({ static_cast( allQuads.size() ) }, torch::kInt64); + + auto outQuadAccess = outQuadTensor.accessor(); + auto outConfAccess = outConfTensor.accessor(); + auto outCountAccess = outCountTensor.accessor(); + + int64_t offset = 0; + for (int64_t batchIdx = 0; batchIdx < allQuads.size(); ++batchIdx) { + vector &exQuads = allQuads[batchIdx]; + + outCountAccess[batchIdx] = exQuads.size(); + + for (int64_t qIdx = 0; qIdx < exQuads.size(); ++qIdx, ++offset) { + copy_quad(exQuads[qIdx], outQuadAccess[offset].data()); + outConfAccess[offset] = exQuads[qIdx].Confidence; + } + } + + return { outQuadTensor, outConfTensor, outCountTensor }; +} + +std::vector quad_nms_from_adjacency( + torch::Tensor quads, torch::Tensor probs, torch::Tensor adjacency, + float probThreshold, float iouThreshold, + int64_t maxRegions) +{ + std::vector ret; + + AT_DISPATCH_FLOATING_TYPES( + quads.scalar_type(), + "quad_nms_from_adjacency", + ([&] { + ret = quad_nms_from_adjacency_impl( + quads.accessor(), + probs.accessor(), + adjacency.accessor(), + probThreshold, iouThreshold, + maxRegions + ); + }) + ); + + return ret; +} diff --git a/nemo-retriever-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu b/nemo-retriever-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu new file mode 100644 index 0000000000000000000000000000000000000000..dc015f5a42008ae16f40ca2b1754c27dd3fb38e0 --- /dev/null +++ b/nemo-retriever-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu @@ -0,0 +1,1720 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "non_maximal_suppression.h" + +#include +#include + +#include +#include +#include + +#include + +#include + +#include "../cuda_intellisense.cuh" +#include "../geometry.h" +#include "../common.h" +#include "../scope_timer.h" +#include "strided_quad.h" + +// If this flag is turned on, then a bunch of checks will be inserted to ensure that the same results are produced by +// successive calls to NMS. This means that it makes the library unusable outside of a debug context, so beware! 
+//#define NMS_VERIFY_CORRECTNESS + +namespace cg = cooperative_groups; +namespace ix = torch::indexing; + +inline +void print_tensor_stats2(const std::string &msg, const torch::Tensor& tensor) { + + auto fTensor = tensor.to(torch::kDouble).cpu(); + + std::stringstream ss; + if (tensor.numel() > 1) { + ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << " Max: " << fTensor.max().item() << " Min: " << fTensor.min().item() << " Mean: " << fTensor.mean().item() << " Std: " << fTensor.std().item(); + } + else if (tensor.numel() == 1) { + ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << " Value: " << fTensor.item() << std::endl; + } + else { + ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device() << std::endl; + } + std::cout << ss.str() << std::endl; +} + +inline +void print_tensor_vec_stats2(std::string msg, const std::vector& tensorVec) { + std::cout << msg << " Size: " << tensorVec.size() << std::endl; + std::stringstream ss; + msg = " - "; + for (int i = 0; i < tensorVec.size(); ++i) { + ss << msg << "[" << i << "]:"; + auto tensor = tensorVec[i]; + print_tensor_stats2(ss.str(), tensor); + ss.str(""); + } +} + +std::ostream &operator<<(std::ostream &os, dim3 d) +{ + return os << "(" << d.x << ", " << d.y << ", " << d.z << ")"; +} + +#define ADD_OP2(vector2_t) __device__ \ + vector2_t operator+(const vector2_t &a, const vector2_t &b) { \ + return { a.x + b.x, a.y + b.y }; \ + } +ADD_OP2(float2); +ADD_OP2(double2); +#undef ADD_OP2 + +#define ADD_OP4(vector4_t) __device__ \ + vector4_t operator+(const vector4_t &a, const vector4_t &b) { \ + return { a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w }; \ + } +ADD_OP4(float4); +ADD_OP4(double4); +#undef ADD_OP4 + +template +__device__ +std::array operator+(const std::array &a, const std::array &b) { + std::array ret; + #pragma unroll + for (size_t i = 0; i < Size; ++i) { + ret._Elems[i] = a._Elems[i] + b._Elems[i]; + } + return ret; +} + +#if __CUDA_ARCH__ >= 800 +#define __reduce_add_full_warp(val) __reduce_add_sync(0xFFFFFFFF, val) +#define __reduce_max_full_warp(val) __reduce_max_sync(0xFFFFFFFF, val) +#define __reduce_min_full_warp(val) __reduce_min_sync(0xFFFFFFFF, val) +#else +#define __reduce_add_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::plus()) +#define __reduce_max_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::greater()) +#define __reduce_min_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::less()) +#endif + +template +struct TToVec; +template<> +struct TToVec { typedef float2 type2; typedef float4 type4; }; +template<> +struct TToVec { typedef double2 type2; typedef double4 type4; }; + +template +__device__ +void write_embed_quad(accessor_t &acc, const MergeQuad_ &quad, int64_t storeOff) +{ + constexpr auto EMBED_QUAD_SIZE = sizeof(EmbedQuad_) / sizeof(T); + static_assert(EMBED_QUAD_SIZE == 10, "Unsupported embed quad size!"); + + const T *mergeBuff = reinterpret_cast(&quad); + + const T confidence = quad.Confidence; + const auto i = threadIdx.x; + + if (i >= 10) { + return; + } + + T outVal; + // Coordinates + if (i < 8) { + outVal = mergeBuff[i] / confidence; + // Confidence + } else if (i == 8) { + outVal = confidence / mergeBuff[9]; + // NumQuads + } else { + outVal = mergeBuff[9]; + } + + acc[i][storeOff] = outVal; +} + + +template +__device__ +void 
ordered_print(group_t &group, const char *const fmt, const Args& ...args) +{ + for (uint32_t i = 0; i < group.size(); ++i) { + if (group.thread_rank() == i) { + printf(fmt, args...); + } + group.sync(); + } +} + +template +__global__ +void device_row_collapse(torch::PackedTensorAccessor64 allQuads, + torch::PackedTensorAccessor64 allConfs, + T confThreshold, T iouThreshold, + torch::PackedTensorAccessor64 allOutCounts, + torch::PackedTensorAccessor64 allOutEmbedQuads +#ifdef NMS_VERIFY_CORRECTNESS + , torch::PackedTensorAccessor64 allOutIds +#endif + ) +{ + typedef InPlaceQuad_ Quadf; + static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!"); + + constexpr uint32_t ALL_MASK = 0xFFFFFFFF; + constexpr uint32_t WARP_SIZE = 32; + constexpr T MIN_VALID_AREA = 8; + + const uint32_t B = allQuads.size(0); + const uint32_t H = allQuads.size(1); + + const uint32_t b = blockIdx.z; + const uint32_t r = blockIdx.y * blockDim.y + threadIdx.y; + + if (r >= H) { + return; + } + + #define threadRank threadIdx.x + + auto rawQuads = reinterpret_cast(allQuads[b][r].data()); +#if defined(NDEBUG) + trove::coalesced_ptr quads(rawQuads); +#else + auto quads = rawQuads; +#endif + + auto confs = allConfs[b][r]; + + T conf = confs[threadRank]; + + bool quadValid = conf >= confThreshold; + uint32_t ballot = __ballot_sync(ALL_MASK, quadValid); + + // No valid quads in this window, so we're done! + if (ballot == 0) { + return; + } + + const Quadf currQuad = quads[threadRank]; + + const T qArea = currQuad.Area(); + + quadValid = quadValid && qArea > MIN_VALID_AREA; + ballot = __ballot_sync(ALL_MASK, quadValid); + if (ballot == 0) { + return; + } + if (! quadValid) { + conf = 0; + } + + MergeQuad_ qAccum{ZeroInitTag{}}; + + Quadf prevQuad; + auto pCurrQuad = reinterpret_cast(&currQuad); + auto pPrevQuad = reinterpret_cast(&prevQuad); + #pragma unroll + for (uint32_t i = 0; i < 8; ++i) { + pPrevQuad[i] = __shfl_up_sync(ALL_MASK, pCurrQuad[i], 1); + } + T prevConf = __shfl_up_sync(ALL_MASK, conf, 1); + + if (threadRank == 0) { + prevConf = 0; + } + + bool iouValid = false; + T iou = 0; + if (quadValid) { + qAccum.Append(currQuad, conf); + + if (prevConf >= confThreshold) { + iou = prevQuad.IOU_UpperBound(currQuad); + if (iou >= iouThreshold) { + iouValid = true; + } + } + } + + // This is the start of a span if the current confidence is above threshold, but the quad to the left is either below threshold, + // or the IOU between the quads is below threshold + const bool isStartOfSpan = quadValid && !iouValid; + + uint32_t label = isStartOfSpan; + // All labels start out as 0 or 1, and we'll then do a cumsum over the warp, which gives each thread an assigned label + // We also know that the final thread also contains the number of labels. + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + auto inc = __shfl_up_sync(ALL_MASK, label, offset); + if (threadRank >= offset) { + label += inc; + } + } + + // Before we zero out invalid labels, get the total number of labels + const uint32_t numLabels = __shfl_sync(ALL_MASK, label, WARP_SIZE - 1); + + // Zero out the label if the current quad isn't valid + label = quadValid ? label : 0; + + T* accumPtr = reinterpret_cast(&qAccum); + // Reduce all of the quads s.t. the left-most position in the span contains the full quad. 
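+ // This is effectively a segmented reduction within the warp, pulling in neighbors' partial sums via __shfl_down_sync.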
+ // We use `label` to decide whether to do the accumulation + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const auto otherLabel = __shfl_down_sync(ALL_MASK, label, offset); + + // Regardless of whether the labels match, all threads in the warp must make the shfl_down + // call. So we use factor to modulate whether the given merge is valid + const T factor = otherLabel == label && offset + threadRank < WARP_SIZE ? 1.0f : 0.0f; + + #pragma unroll + for (uint32_t i = 0; i < 10; ++i) { + accumPtr[i] += factor * __shfl_down_sync(ALL_MASK, accumPtr[i], offset); + } + } + + // Elect thread-0 to figure out where to store the results + uint32_t storeOff = 0; + if (threadRank == 0) { + storeOff = atomicAdd(&allOutCounts[b], numLabels); + } + // Broadcast that offset to the whole warp + storeOff = __shfl_sync(ALL_MASK, storeOff, 0); + + auto outEmbedQuads = allOutEmbedQuads[b]; + // Now write out each quad, but collectively + for (uint32_t procLabel = 1; procLabel <= numLabels; ++procLabel) { + // Discover the index of the start of each label span + ballot = __ballot_sync(ALL_MASK, procLabel == label); + // ffs will find the (1-based) index of the least significant bit in ballot. + // This just so happens to be the start of the span for the current label + uint32_t startIdx = __ffs(ballot) - 1; + + const T* inT = reinterpret_cast(&qAccum); + MergeQuad_ outQuad; + T* outT = reinterpret_cast(&outQuad); + #pragma unroll + for (uint32_t i = 0; i < 10; ++i) { + outT[i] = __shfl_sync(ALL_MASK, inT[i], startIdx); + } + + write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1); +#ifdef NMS_VERIFY_CORRECTNESS + if (threadRank == 0) { + allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx; + } +#endif + } + + if (threadRank == 0) { + // Increment the total number of quads by the number encountered on this row + atomicAdd(&allOutCounts[B], numLabels); + } + +#undef threadRank +} + +template +__global__ +void device_a2a_adjacency_sparse(const uint64_t punCounts, + T iouThreshold, + torch::PackedTensorAccessor64 embedQuads, + torch::PackedTensorAccessor64 outIsStart, + torch::PackedTensorAccessor64 outAdjCounts, + torch::PackedTensorAccessor64 outSparseAdj) +{ + const uint32_t b = blockIdx.y; + + const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast(punCounts)[b]; + + const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x; + const int32_t row = jobIdx / quadCt; + const int32_t col = jobIdx % quadCt; + + // Only compute the upper triangular portion of the matrix + if (row >= quadCt || col < row) { + return; + } + + T* exData = IsSingleExample ? 
embedQuads.data() : embedQuads[b].data(); + + const auto qRow = StridedEmbedQuad_{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(), + qCol = StridedEmbedQuad_{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(); + + T pctRow, pctCol, iou; + thrust::tie(pctRow, pctCol, iou) = geometry_region_sizes(qRow, qCol); + + auto warpGroup = cg::tiled_partition<32>(cg::this_thread_block()); + + auto rowGroup = cg::labeled_partition(warpGroup, row); + + const bool isValid = iou >= iouThreshold; + + const uint32_t ballot = rowGroup.ballot(isValid); + const uint32_t numValid = __popc(ballot); + + auto exAdjCounts = outAdjCounts[b].data(); + + int32_t storeOff = 0; + if (numValid > 0 && rowGroup.thread_rank() == 0) { + storeOff = atomicAdd(exAdjCounts + row, numValid); + } + storeOff = rowGroup.shfl(storeOff, 0); + + if (isValid) { + // This will set all of the bits to the left of this one to 1, otherwise 0. + // We can use this to count the number of bits that are set, and are less significant than this one, + // to get the local storage offset + uint32_t lowerMask = (1 << rowGroup.thread_rank()) - 1; + + storeOff += __popc(ballot & lowerMask); + + outSparseAdj[b][row][storeOff] = col; + if (row != col) { + // Because `col` gets merged into `row`, we mark it as inactive for reduction purposes. + // All of the quads that `col` is adjacent to will be absorbed by `row`. + outIsStart[b][col] = false; + + // Also store the transposed relation + storeOff = atomicAdd(exAdjCounts + col, 1); + outSparseAdj[b][col][storeOff] = row; + } + } else if (pctRow > 0.8f || pctCol > 0.8f) { + T anchorHeight = qRow.Height(); + T otherHeight = qCol.Height(); + + T ratio = anchorHeight > otherHeight ? + otherHeight / anchorHeight : + anchorHeight / otherHeight; + if (ratio > 0.9f) { + if (pctRow > 0.8f) { + // Other envelops anchor + outIsStart[b][row] = false; + } + else { + outIsStart[b][col] = false; + } + } + } +} + +template +__global__ +void device_a2a_adjacency_build_grid(const uint64_t punCounts, + torch::PackedTensorAccessor64 embedQuads, + torch::PackedTensorAccessor64 outGridCells, + torch::PackedTensorAccessor64 outQuadCells) +{ + constexpr T MIN_T = std::numeric_limits::min(); + constexpr T MAX_T = std::numeric_limits::max(); + constexpr uint32_t WARP_SIZE = 32; + constexpr uint32_t BLOCK_SIZE = NumWarps * WARP_SIZE; + constexpr uint32_t FULL_WARP = 0xFFFFFFFF; + constexpr uint32_t FIRST_16_THREADS = 0x0FFFF; + constexpr T CELL_SIZE = I_CELL_SIZE; + constexpr T INV_CELL_SIZE = 1 / CELL_SIZE; + + const uint32_t b = blockIdx.z; + + const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast(punCounts)[b]; + const uint32_t quadIdx = blockIdx.y; + + if (!IsSingleExample && quadIdx >= quadCt) { + return; + } + + const uint32_t threadRank = threadIdx.x; + const uint32_t localThreadRank = threadRank & 0x1F; + + auto exQuads = embedQuads[b]; + + const uint32_t numCells[2] = { outGridCells.size(2), outGridCells.size(1) }; + + const uint32_t numRows = outGridCells.size(1); + const uint32_t numCols = outGridCells.size(2); + + // We use flip so that we can compute min and max simultaneously. + // First 4 threads compute the min, next 4 compute the max + T sign = localThreadRank < 8 ? 1.0f : -1.0f; + T myVal = sign * (localThreadRank < 16 ? 
exQuads[localThreadRank & 0x7][quadIdx] : MIN_T); + #pragma unroll + for (uint32_t offset = 2; offset < 8; offset <<= 1) { + T nextVal = __shfl_down_sync(FIRST_16_THREADS, myVal, offset); + myVal = min(myVal, nextVal); + } + const uint32_t cellVal = max(0.0f, sign * INV_CELL_SIZE * myVal); + + uint32_t minCell[2] = { __shfl_sync(FULL_WARP, cellVal, 0), __shfl_sync(FULL_WARP, cellVal, 1) }, + maxCell[2] = { __shfl_sync(FULL_WARP, cellVal, 8), __shfl_sync(FULL_WARP, cellVal, 9) }; + + #pragma unroll + for (uint32_t i = 0; i < 2; ++i) { + maxCell[i] = min(numCells[i] - 1, maxCell[i]); + } + + const uint32_t sizes[2] = { maxCell[0] - minCell[0] + 1, maxCell[1] - minCell[1] + 1 }; + + const uint32_t totalCells = sizes[0] * sizes[1]; + + auto exGridCells = outGridCells[b]; + + for (uint32_t i = threadRank; i < totalCells; i += BLOCK_SIZE) { + uint32_t row = minCell[1] + i / sizes[0]; + uint32_t col = minCell[0] + i % sizes[0]; + + int32_t *pCell = exGridCells[row][col].data(); + + // The first value in the array is the count, and the rest are the quad indices + int32_t storeOff = atomicAdd(pCell, 1) + 1; + pCell[storeOff] = quadIdx; + } + + if (threadRank < 2) { + outQuadCells[b][quadIdx][threadRank] = minCell[threadRank]; + } else if (threadRank < 4) { + outQuadCells[b][quadIdx][threadRank] = maxCell[threadRank - 2]; + } +} + +typedef uint8_t visit_mask_t; + +template +__global__ +void device_a2a_adjacency_with_grid(const uint64_t punCounts, + T iouThreshold, + torch::PackedTensorAccessor64 allEmbedQuads, + torch::PackedTensorAccessor64 allCells, + torch::PackedTensorAccessor64 allQuadExtents, + torch::PackedTensorAccessor64 outIsStart, + torch::PackedTensorAccessor64 outAdjCounts, + torch::PackedTensorAccessor64 outSparseAdj) +{ + constexpr T MIN_T = std::numeric_limits::min(); + constexpr T MAX_T = std::numeric_limits::max(); + constexpr uint32_t WARP_SIZE = 32; + constexpr uint32_t BLOCK_SIZE = NumWarps * WARP_SIZE; + + const uint32_t b = blockIdx.z; + + const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast(punCounts)[b]; + const uint32_t quadIdx = blockIdx.y; + + if (!IsSingleExample && quadIdx >= quadCt) { + return; + } + + const uint32_t threadRank = threadIdx.x; + + auto exQuads = allEmbedQuads[b]; + + __shared__ T s_quadVerts[8]; + __shared__ uint32_t s_quadExtent[4]; + extern __shared__ uint32_t s_alreadyVisited[]; + + if (threadRank < 8) { + s_quadVerts[threadRank] = exQuads[threadRank][quadIdx]; + } else if (threadRank < 12) { + s_quadExtent[threadRank - 8] = reinterpret_cast(allQuadExtents[b][quadIdx].data())[threadRank - 8]; + } + + uint32_t zeroTerm = (quadCt + 31u) >> 5u; // Fast version of div_up(quadCt, 32) + for (uint32_t col = threadRank; col < zeroTerm; col += BLOCK_SIZE) { + s_alreadyVisited[col] = 0; + } + + __syncthreads(); + + auto exCells = allCells[b]; + auto exAdjCounts = reinterpret_cast(outAdjCounts[b].data()); + auto exAdjValues = outSparseAdj[b][quadIdx].data(); + + T *exData = IsSingleExample ? 
allEmbedQuads.data() : allEmbedQuads[b].data(); + + const auto bdsAnchor = Quad_{ s_quadVerts }.Bounds(); + + const uint32_t startCol = s_quadExtent[0], + endCol = s_quadExtent[2]; + for (uint32_t row = s_quadExtent[1], endRow = s_quadExtent[3]; row <= endRow; ++row) { + auto rowCells = exCells[row]; + + for (uint32_t col = startCol; col <= endCol; ++col) { + auto colCells = reinterpret_cast(rowCells[col].data()); + + const uint32_t ct = colCells[0]; + + for (uint32_t i = threadRank + 1; i <= ct; i += BLOCK_SIZE) { + const uint32_t otherIdx = colCells[i]; + + const uint32_t maskIdx = otherIdx >> 5; // Divide by 32, since there are 32 bits per mask slot + const uint32_t maskBit = 1 << (otherIdx & 0x1F); // Set the relevant bit for this mask ID + + const bool alreadyVisited = atomicOr(s_alreadyVisited + maskIdx, maskBit) & maskBit; + + if (!alreadyVisited) { + const auto bdsOther = StridedEmbedQuad_{ exData + otherIdx * allEmbedQuads.stride(2), allEmbedQuads.stride(1) }.Bounds(); + + T pctAnchor, pctOther, iou; + thrust::tie(pctAnchor, pctOther, iou) = geometry_region_sizes(bdsAnchor, bdsOther); + + if (iou >= iouThreshold) { + auto validGroup = cg::coalesced_threads(); + + uint32_t storeOff = 0; + if (validGroup.thread_rank() == 0) { + storeOff = atomicAdd(exAdjCounts + quadIdx, validGroup.size()); + } + storeOff = validGroup.shfl(storeOff, 0) + validGroup.thread_rank(); + + exAdjValues[storeOff] = otherIdx; + + if (otherIdx > quadIdx) { + outIsStart[b][otherIdx] = false; + } + } else if (pctAnchor > 0.8f || pctOther > 0.8f) { + T anchorHeight = bdsAnchor.Height(); + T otherHeight = bdsOther.Height(); + + T ratio = anchorHeight > otherHeight ? + otherHeight / anchorHeight : + anchorHeight / otherHeight; + if (ratio > 0.9f) { + if (pctAnchor > 0.8f) { + // Other envelops anchor + outIsStart[b][quadIdx] = false; + } else { + outIsStart[b][otherIdx] = false; + } + } + } + } + } + } + } +} + +template +__global__ +void device_flatten_graph_iterative(const uint64_t punCounts, + torch::PackedTensorAccessor64 allIsStart, + volatile uint32_t *allAdjCounts, + volatile uint32_t *allAdjValues +#ifdef NMS_VERIFY_CORRECTNESS + , int32_t *maxDepth +#endif + ) +{ + constexpr uint32_t WARP_SIZE = 32; + constexpr uint32_t VISIT_STACK_SIZE = 9; + constexpr uint32_t TERM_VALUE = std::numeric_limits::max(); + + constexpr visit_mask_t VISITED_MASK = 0b001; + constexpr visit_mask_t ADDED_MASK = 0b010; + constexpr visit_mask_t QUEUED_MASK = 0b100; + constexpr visit_mask_t QUEUED_OR_VISITED_MASK = VISITED_MASK | QUEUED_MASK; + + const uint32_t b = blockIdx.z; + const uint32_t anchorRow = blockIdx.y; + + const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast(punCounts)[b]; + + // Only need to check this if there are multiple examples, since in the case of a single example, + // the grid is precisely sized to that quadCt + if constexpr (!IsSingleExample) { + if (anchorRow >= quadCt) { + return; + } + } + + auto isStart = allIsStart[b].data(); + + const uint32_t threadRank = threadIdx.x; + + extern __shared__ visit_mask_t s_visitedMask[]; + +#ifndef NMS_VERIFY_CORRECTNESS + // Only need to process the anchor rows, since they're the only ones + // that will make it through the full NMS operation. + // NOTE: There's a race condition where some rows may be marked as anchor, + // but they'll later be marked non-anchor over the course of this kernel. + // That's fine. It's a bit of extra work, but there's no real way around it. 
+ const bool anchorIsStart = isStart[anchorRow]; + if (!anchorIsStart) { + return; + } +#endif + + uint32_t *pIntVisitedMask = reinterpret_cast(s_visitedMask); + uint32_t zeroTerm = (quadCt + 3) >> 2; // Fast version of div_up(quadCt, 4) + for (uint32_t col = threadRank; col < zeroTerm; col += blockDim.x) { + pIntVisitedMask[col] = 0; + } + + __syncthreads(); + + const uint32_t maxExCount = allIsStart.size(1); + auto adjCounts = allAdjCounts + (b * maxExCount); + auto adjValues = allAdjValues + (b * maxExCount * maxExCount); + + auto adjAnchorValues = adjValues + (anchorRow * maxExCount); + // For the anchor row, set the visited mask to 0b10, which will signify that we haven't visited it yet, + // but that the value is already in the adjacency vector. + // 0bx1 signifies that the value has been visited + for (uint32_t i = threadRank, ct = adjCounts[anchorRow]; i < ct; i += blockDim.x) { + const auto adjCol = adjAnchorValues[i]; + s_visitedMask[adjCol] = ADDED_MASK; + } + + __syncthreads(); + + if (threadRank == 0) { + s_visitedMask[anchorRow] |= QUEUED_MASK; + } + + __syncthreads(); + + // TODO(mranzinger): Is it worth incorporating these other threads? + // It seems like the vast majority of adjacency counts is <32 + if (threadRank >= WARP_SIZE) { + return; + } + + uint32_t visitStack[VISIT_STACK_SIZE]; + visitStack[0] = TERM_VALUE; + visitStack[1] = anchorRow; +#ifndef NDEBUG + for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) { + visitStack[i] = -2; + } +#endif + int32_t visitPtr = 1; + + while (true) { +#ifdef NMS_VERIFY_CORRECTNESS + assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE); +#endif + const uint32_t threadNextCol = visitStack[visitPtr]; + const uint32_t warpNextCol = __reduce_min_full_warp(threadNextCol); + + // Check to see if this thread got chosen. + // If so, decrement the stack counter + if (threadNextCol == warpNextCol) { +#ifndef NDEBUG + // This makes it easier to debug where the pointer is + visitStack[visitPtr] = -2; +#endif + --visitPtr; + } + + // If the maximum value encountered is -1, that means that none of the threads + // had another value to process + if (warpNextCol == TERM_VALUE) { + break; + } + + const uint32_t procRow = warpNextCol; + + __syncthreads(); + + bool isAlreadyVisited = s_visitedMask[procRow] & VISITED_MASK; + + if (isAlreadyVisited) { + continue; + } + + const uint32_t procAdjCount = adjCounts[procRow]; + auto procAdjValues = adjValues + (procRow * maxExCount); + + // Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp. + // The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph + // has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically, + // the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage. + for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) { + const uint32_t adjCol = procAdjValues[i]; + + // This will set the queued flag for this column, if it's not already set. + // It also returns the old state. 
In our case, we only want to add this value to the + // stack iff it hasn't already been visited, and hasn't been queued elsewhere + // NOTE: CUDA doesn't support atomicOr on uint8_t :(, but it's not necessary that + // the operation be absolutely atomic, so the poor man's version is probably okay + const auto oldMask = s_visitedMask[adjCol]; + auto newMask = oldMask; + + bool alreadyAdded = oldMask & ADDED_MASK; + + auto group = cg::coalesced_threads(); + const uint32_t gThreadRank = group.thread_rank(); + uint32_t notAddedBallot = group.ballot(!alreadyAdded); + if (notAddedBallot) { + // Only one warp will ever be adding values to a given row, which means + // that we don't need atomics. However, other warps may be reading data + // from anchorRow, which means that we need to add the values first, + // followed by incrementing the count. This order makes things + // concurrency safe. + const uint32_t globalStoreOff = adjCounts[anchorRow]; + // Gets the count of the bits to the left of this thread + const uint32_t localStoreOff = __popc(notAddedBallot & ((1 << gThreadRank) - 1)); + + if (!alreadyAdded) { + adjAnchorValues[globalStoreOff + localStoreOff] = adjCol; + if (adjCol > anchorRow) { + // Also, ensure that this quad is no longer marked as a starting quad + isStart[adjCol] = false; + } + newMask |= ADDED_MASK; + } + + // Finally, commit the change by incrementing the counter + if (gThreadRank == 0) { + adjCounts[anchorRow] += __popc(notAddedBallot); + } + } + + bool alreadyHandled = oldMask & QUEUED_OR_VISITED_MASK; + + if (!alreadyHandled) { +#ifdef NMS_VERIFY_CORRECTNESS + newMask |= QUEUED_MASK; + ++visitPtr; + assert(visitPtr < VISIT_STACK_SIZE); + atomicMax(maxDepth, visitPtr); + visitStack[visitPtr] = adjCol; +#else + // Prefer potentially inconsistent results over buffer overflow + if (visitPtr < VISIT_STACK_SIZE - 1) { + newMask |= QUEUED_MASK; + ++visitPtr; + visitStack[visitPtr] = adjCol; + } +#endif + } + + if (newMask != oldMask) { + s_visitedMask[adjCol] = newMask; + } + } + + // We actually rely on the `pop_next` function largely to handle recursing down into the next row + __syncthreads(); + } +} + +void add_to_set(const torch::TensorAccessor& adjCounts, + const torch::TensorAccessor& adjValues, + int32_t row, + std::unordered_set& possible) +{ + if (possible.count(row)) { + return; + } + + possible.insert(row); + + const int32_t adjCount = adjCounts[row]; + auto values = adjValues[row].data(); + + for (int32_t i = 0; i < adjCount; ++i) { + const int32_t col = values[i]; + add_to_set(adjCounts, adjValues, col, possible); + } +} + +template +void cpu_flatten_graph(const uint64_t punCounts, + torch::Tensor isStartTensorGPU, + torch::Tensor adjCountsTensorGPU, + torch::Tensor adjValuesTensorGPU) +{ + auto isStartTensor = isStartTensorGPU.cpu(); + auto adjCountsTensor = adjCountsTensorGPU.cpu(); + auto adjValuesTensor = adjValuesTensorGPU.cpu(); + + auto allIsStart = isStartTensor.accessor(); + auto allAdjCounts = adjCountsTensor.accessor(); + auto allAdjValues = adjValuesTensor.accessor(); + + for (int32_t b = 0; b < allAdjCounts.size(0); ++b) { + const int32_t quadCt = IsSingleExample ? 
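// The neighbor-append loop above compacts its writes with a ballot + popcount
// prefix: each lane counts how many lower-ranked lanes also have a value to
// store, which yields a unique slot without per-lane atomics. A standalone
// sketch of that idiom (hypothetical helper, not from the shipped sources;
// assumes a single producer group per output list, as in the kernel above):

__device__ void ballot_compact_append(bool hasValue, uint32_t value,
                                      uint32_t *outCount, uint32_t *outValues)
{
    auto group = cooperative_groups::coalesced_threads();
    const uint32_t ballot = group.ballot(hasValue);
    const uint32_t base = *outCount;
    if (hasValue) {
        // Set bits below this lane's rank = this lane's slot within the batch
        const uint32_t slot = __popc(ballot & ((1u << group.thread_rank()) - 1));
        outValues[base + slot] = value;
    }
    group.sync();
    // A single writer commits the new count after all values have landed
    if (group.thread_rank() == 0) {
        *outCount = base + __popc(ballot);
    }
}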
punCounts : reinterpret_cast(punCounts)[b]; + + for (int32_t row = 0; row < quadCt; ++row) { + std::unordered_set fullAdjSet; + add_to_set(allAdjCounts[b], allAdjValues[b], row, fullAdjSet); + + int32_t &currCt = allAdjCounts[b][row]; + int32_t *currValues = allAdjValues[b][row].data(); + std::unordered_set existingSet{ currValues, currValues + currCt }; + + for (int32_t adjCol : fullAdjSet) { + if (existingSet.count(adjCol)) { + continue; + } + + currValues[currCt] = adjCol; + ++currCt; + + if (adjCol > row) { + allIsStart[b][adjCol] = false; + } + } + } + } + + isStartTensorGPU.copy_(isStartTensor); + adjCountsTensorGPU.copy_(adjCountsTensor); + adjValuesTensorGPU.copy_(adjValuesTensor); +} + + +__global__ +void device_a2a_adj_cleanup(const int32_t *counts, + torch::PackedTensorAccessor64 inOutAdjacency) +{ + const uint32_t b = blockIdx.y; + const uint32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t numQuads = counts[b]; + const uint32_t row = jobIdx / numQuads; + const uint32_t col = jobIdx % numQuads; + + if (row >= numQuads) { + return; + } + + auto adjacency = inOutAdjacency[b]; + + bool rowPivot = adjacency[row][row] > 0; + bool colPivot = adjacency[col][col] > 0; + + if (!rowPivot || !colPivot) { + adjacency[row][col] = 0; + } +} + +template +__global__ +void device_a2a_collapse(const uint64_t punCounts, + torch::PackedTensorAccessor64 allEmbedQuads, + torch::PackedTensorAccessor64 allIsLeadRow, + const int64_t *regionCounts, + torch::PackedTensorAccessor64 allAdjCounts, + torch::PackedTensorAccessor64 allAdjValues, + //torch::PackedTensorAccessor64 allOutPositions, + torch::PackedTensorAccessor64 outQuads, + T *outConf) +{ + constexpr uint32_t WARP_SIZE = 32; + constexpr uint32_t FULL_WARP = 0xFFFFFFFF; + constexpr uint32_t BLOCK_WIDTH = NumWarps * WARP_SIZE; + constexpr size_t MERGE_QUAD_SIZE = sizeof(MergeQuad_) / sizeof(T); + + static_assert(NumWarps < WARP_SIZE, "Only a single warp currently supported!"); + + const uint32_t b = blockIdx.z; + const uint32_t row = blockIdx.y; + + const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast(punCounts)[b]; + + if constexpr (!IsSingleExample) { + if (row >= quadCt) { + return; + } + } + + // Only process the lead rows + const auto isLeadRow = IsSingleExample ? allIsLeadRow.data() : allIsLeadRow[b].data(); + if (!isLeadRow[row]) { + return; + } + + const uint32_t threadRank = threadIdx.x; + const uint32_t localThreadRank = threadRank & 0x1F; + const uint32_t warpIdx = threadRank >> 5; + + __shared__ T s_mergeQuad[MERGE_QUAD_SIZE]; + + if constexpr (NumWarps > 1) { + if (threadRank < MERGE_QUAD_SIZE) { + s_mergeQuad[threadRank] = 0.0f; + } + + __syncthreads(); + } + + T *exData = IsSingleExample ? 
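// add_to_set above computes the reachable set with plain recursion, which is
// fine for the shallow graphs produced here. An equivalent explicit-stack
// version, as a sketch (same accessor types assumed as in add_to_set; needs
// <vector> and <unordered_set>):

inline void add_to_set_iterative(const torch::TensorAccessor<int32_t, 1> &adjCounts,
                                 const torch::TensorAccessor<int32_t, 2> &adjValues,
                                 int32_t row,
                                 std::unordered_set<int32_t> &possible)
{
    std::vector<int32_t> stack{ row };
    while (!stack.empty()) {
        const int32_t curr = stack.back();
        stack.pop_back();
        if (!possible.insert(curr).second) {
            continue; // already reached through another path
        }
        const int32_t ct = adjCounts[curr];
        const int32_t *vals = adjValues[curr].data();
        for (int32_t i = 0; i < ct; ++i) {
            stack.push_back(vals[i]);
        }
    }
}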
allEmbedQuads.data() : allEmbedQuads[b].data(); + + const int32_t adjCount = allAdjCounts[b][row]; + const int32_t *adjIdxs = allAdjValues[b][row].data(); + + MergeQuad_ localMerge{ZeroInitTag{}}; + + for (int32_t i = threadRank; i < adjCount; i += BLOCK_WIDTH) { + const int32_t currQuadIdx = adjIdxs[i]; + const StridedEmbedQuad_ qCurr{ exData + currQuadIdx * allEmbedQuads.stride(2), allEmbedQuads.stride(1) }; + + localMerge.Append(qCurr); + } + + T *mqV = reinterpret_cast(&localMerge); + #pragma unroll + for (uint32_t offset = 1; offset < WARP_SIZE; offset <<= 1) { + T mergeFactor = offset + localThreadRank < 32; + #pragma unroll + for (uint32_t i = 0; i < MERGE_QUAD_SIZE; ++i) { + mqV[i] += mergeFactor * __shfl_down_sync(FULL_WARP, mqV[i], offset); + } + } + #pragma unroll + for (uint32_t i = 0; i < MERGE_QUAD_SIZE; ++i) { + mqV[i] = __shfl_sync(FULL_WARP, mqV[i], 0); + } + + // Only need to do a multi-warp merge if there are enough quads to justify it + if (NumWarps > 1 && adjCount > WARP_SIZE) { + if (localThreadRank < MERGE_QUAD_SIZE) { + atomicAdd(s_mergeQuad + localThreadRank, mqV[localThreadRank]); + } + + __syncthreads(); + + mqV = s_mergeQuad; + } + + // Figure out the output position + uint32_t writePosition = 0; + if constexpr (!IsSingleExample) { + for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) { + writePosition += regionCounts[i]; + } + } + + const int32_t numLongs = row >> 3; // Divide by 8 + const uint8_t *pCurrIsLeadRow = reinterpret_cast(isLeadRow); + const uint64_t *lpCurrIsLeadRow = reinterpret_cast(pCurrIsLeadRow); + + for (int32_t i = threadRank; i < numLongs; i += BLOCK_WIDTH) { + writePosition += __popcll(lpCurrIsLeadRow[i]); + } + for (int32_t i = (numLongs * 8) + threadRank; i < row; i += BLOCK_WIDTH) { + if (pCurrIsLeadRow[i]) { + ++writePosition; + } + } + // Sum all of the individual offsets over the warp + writePosition = __reduce_add_full_warp(writePosition); + // Reduce across warps, if applicable + if constexpr (NumWarps > 1) { + __shared__ uint32_t s_threadWritePositions[NumWarps]; + if (localThreadRank == 0) { + s_threadWritePositions[warpIdx] = writePosition; + } + __syncthreads(); + writePosition = threadRank < NumWarps ? s_threadWritePositions[threadRank] : 0; + writePosition = __reduce_add_full_warp(writePosition); + } + + if (threadRank >= 9) { + return; + } + + const T sumConfidence = mqV[8]; + const T numQuads = mqV[9]; + const T divisor = threadRank < 8 ? sumConfidence : numQuads; + + const T myVal = mqV[threadRank] / divisor; + + auto writeVerts = outQuads[writePosition].data(); + + if (threadRank < 8) { + writeVerts[threadRank] = myVal; + } else { + outConf[writePosition] = myVal; + } +} + +struct CollapseRowsResult { + torch::Tensor ExCounts; + torch::Tensor StridedMergeQuads; + int32_t TotalNumQuads; + // NOTE: This will only be available in Debug builds + torch::Tensor QuadIds; + int32_t ImageWidth; + int32_t ImageHeight; +}; + +template +CollapseRowsResult collapse_rows( + torch::Tensor quads, torch::Tensor probs, scalar_t probThreshold, scalar_t iouThreshold +) +{ + if (! 
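// The write-position computation above is a rank: how many lead rows precede
// this one. The bool mask is scanned eight entries at a time with __popcll
// (each `true` byte contributes exactly one set bit), and the tail is handled
// byte by byte. A serial sketch of the same computation (assumes the mask
// starts 8-byte aligned, which holds for a tensor allocation):

__device__ uint32_t count_lead_rows_before(const bool *isLeadRow, uint32_t row)
{
    const uint64_t *words = reinterpret_cast<const uint64_t *>(isLeadRow);
    const uint32_t numWords = row >> 3; // whole 8-byte words before `row`
    uint32_t count = 0;
    for (uint32_t i = 0; i < numWords; ++i) {
        count += __popcll(words[i]);
    }
    for (uint32_t i = numWords * 8; i < row; ++i) {
        count += isLeadRow[i] ? 1u : 0u;
    }
    return count;
}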
quads.is_contiguous()) { + throw std::runtime_error("Expected `quads` to be contiguous!"); + } + + if ((quads.size(2) % 32) != 0) { + throw std::runtime_error("Expected the width of the `quads` buffer to be a multiple of 32!"); + } + + int32_t imageWidth = quads.size(2) * 4; + int32_t imageHeight = quads.size(1) * 4; + + quads = quads.reshape({ quads.size(0), -1, 32, 4, 2 }); + probs = probs.reshape({ probs.size(0), -1, 32 }); + + if (quads.size(0) != probs.size(0) || quads.size(1) != probs.size(1)) { + throw std::runtime_error("Dimension mismatch between `quads` and `probs`"); + } + + // The final counter is for the total number of quads for the entire batch + auto counts = torch::zeros({ quads.size(0) + 1 }, quads.options().dtype(torch::kInt32)); + + int64_t embedSize = sizeof(EmbedQuad_) / sizeof(scalar_t); + auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options()); + +#ifdef NMS_VERIFY_CORRECTNESS + auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) }, + std::numeric_limits::max(), + counts.options().dtype(torch::kInt32)); +#else + torch::Tensor idsTensor; +#endif + + dim3 blockSize(32, 3, 1); + dim3 gridSize(1, + div_up(quads.size(1), blockSize.y), + quads.size(0)); + + device_row_collapse KERNEL_ARG2(gridSize, blockSize) ( + quads.packed_accessor64(), + probs.packed_accessor64(), + probThreshold, iouThreshold, + counts.packed_accessor64(), + rowMergeTensor.packed_accessor64() +#ifdef NMS_VERIFY_CORRECTNESS + , idsTensor.packed_accessor64() +#endif + ); + +#ifdef NMS_VERIFY_CORRECTNESS + static std::unordered_set s_quadIds; + auto cpuIdsTensor = idsTensor.cpu(); + const int32_t *idsPtr = cpuIdsTensor.data_ptr(); + if (s_quadIds.empty()) { + s_quadIds.insert(idsPtr, idsPtr + idsTensor.numel()); + } else { + std::unordered_set otherIds{ idsPtr, idsPtr + idsTensor.numel() }; + + if (s_quadIds != otherIds) { + throw std::runtime_error("Inconsistent Ids!"); + } + } +#endif + + // The final value in `counts` is actually to total number of quads for the entire batch + int32_t totalQuads = counts[-1].item(); + + counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1); + +#ifdef NMS_VERIFY_CORRECTNESS + int64_t maxExCount; + if (counts.size(0) > 1) { + maxExCount = counts.max().item(); + } else { + maxExCount = totalQuads; + } + + static bool s_sortOrder = false; + + rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount); + idsTensor = idsTensor.slice(1, 0, maxExCount); + auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder); s_sortOrder = !s_sortOrder; + + auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor); + + rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder); + idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order); +#endif + + return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight }; +} + + + +void verify_row(const torch::TensorAccessor &adjCounts, + const torch::TensorAccessor &adjValues, + int32_t row) +{ + // Traverse the graph, and accumulate all set flags across all rows marked + // adjacent by the current row. 
If the merge_up algorithm works correctly, then + // `possible` will contain exactly the same set of values as the current row + std::unordered_set possible; + add_to_set(adjCounts, adjValues, row, possible); + + std::unordered_set thisRow{ row }; + const int32_t thisCount = adjCounts[row]; + auto thisValues = adjValues[row].data(); + thisRow.insert(thisValues, thisValues + thisCount); + + if (thisRow != possible) { + throw std::runtime_error("The merge_up algorithm is not correct!"); + } +} + +struct AdjacencyResult { + // Shape: BxQ + // Specifies whether the given row is a result row + torch::Tensor IsLeadRow; + // Shape: BxQ + // The number of quads that need to be merged with the given quad + torch::Tensor AdjCounts; + // Shape: BxQx + // The indices of the adjacent quads. + torch::Tensor AdjValues; + int64_t MaxExCount; +}; + +template +void cpu_a2a_adjacency_sparse(const uint64_t punCounts, + const T iouThreshold, + torch::Tensor embedQuadsTensor, + torch::Tensor outIsStartTensorGPU, + torch::Tensor outAdjCountsTensorGPU, + torch::Tensor outSparseAdjTensorGPU) +{ + embedQuadsTensor = embedQuadsTensor.cpu(); + auto outIsStartTensor = outIsStartTensorGPU.cpu(); + auto outAdjCountsTensor = outAdjCountsTensorGPU.cpu(); + auto outSparseAdjTensor = outSparseAdjTensorGPU.cpu(); + + auto embedQuads = embedQuadsTensor.accessor(); + auto isStart = outIsStartTensor.accessor(); + auto adjCounts = outAdjCountsTensor.accessor(); + auto adjValues = outSparseAdjTensor.accessor(); + + for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) { + const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast(punCounts)[b]; + + T *exData = embedQuads[b].data(); + + for (int32_t row = 0; row < quadCt; ++row) { + const auto qRow = StridedEmbedQuad_{ exData + row, embedQuads.stride(1) }.Bounds(); + + for (int32_t col = 0; col < quadCt; ++col) { + const auto qCol = StridedEmbedQuad_{ exData + col, embedQuads.stride(1) }.Bounds(); + + T pctRow, pctCol, iou; + thrust::tie(pctRow, pctCol, iou) = geometry_region_sizes(qRow, qCol); + + if (iou >= iouThreshold) { + int32_t &storeIdx = adjCounts[b][row]; + adjValues[b][row][storeIdx] = col; + ++storeIdx; + if (row < col) { + isStart[b][col] = false; + } + } else if (pctRow > 0.8f || pctCol > 0.8f) { + T anchorHeight = qRow.Height(); + T otherHeight = qCol.Height(); + + T ratio = anchorHeight > otherHeight ? 
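// geometry_region_sizes comes from geometry.h (not shown in this diff). From
// its call sites, it returns (coverage of the first region by the intersection,
// coverage of the second region, IoU), which is why `pct > 0.8` is read as
// "one region mostly envelops the other". A host-side, axis-aligned sketch of
// that contract (the real routine also has the quad overloads used elsewhere):

#include <algorithm>
#include <thrust/tuple.h>

template <typename T>
thrust::tuple<T, T, T> region_sizes_aabb_sketch(T aMinX, T aMinY, T aMaxX, T aMaxY,
                                                T bMinX, T bMinY, T bMaxX, T bMaxY)
{
    const T ix = std::max(T(0), std::min(aMaxX, bMaxX) - std::max(aMinX, bMinX));
    const T iy = std::max(T(0), std::min(aMaxY, bMaxY) - std::max(aMinY, bMinY));
    const T inter = ix * iy;
    const T areaA = (aMaxX - aMinX) * (aMaxY - aMinY); // assumes non-degenerate boxes
    const T areaB = (bMaxX - bMinX) * (bMaxY - bMinY);
    const T uni = areaA + areaB - inter;
    return thrust::make_tuple(inter / areaA,
                              inter / areaB,
                              uni > T(0) ? inter / uni : T(0));
}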
+ otherHeight / anchorHeight : + anchorHeight / otherHeight; + if (ratio > 0.9f) { + if (pctRow > 0.8f) { + // Other envelops anchor + isStart[b][row] = false; + } + else { + isStart[b][col] = false; + } + } + } + } + } + } + + outIsStartTensorGPU.copy_(outIsStartTensor); + outAdjCountsTensorGPU.copy_(outAdjCountsTensor); + outSparseAdjTensorGPU.copy_(outSparseAdjTensor); +} + +template +std::string to_flat_string(torch::Tensor tensor) { + tensor = tensor.flatten(); + + auto acc = tensor.accessor(); + + std::ostringstream oss; + oss << "["; + if (acc.size(0) > 0) { + oss << acc[0]; + for (int64_t i = 1; i < acc.size(0); ++i) { + oss << ", " << acc[i]; + } + } + oss << "]"; + return oss.str(); +} + +template +AdjacencyResult compute_all_to_all_adjacency( + const CollapseRowsResult &collapseResult, + scalar_t iouThreshold) +{ + torch::Tensor counts = collapseResult.ExCounts; + + int64_t maxExCount; + if (counts.size(0) > 1) { + maxExCount = counts.max().item(); + } else { + maxExCount = collapseResult.TotalNumQuads; + } + + auto isStartTensor = torch::ones({ counts.size(0), maxExCount }, counts.options().dtype(torch::kBool)); + auto adjCountsTensor = torch::zeros({ counts.size(0), maxExCount }, counts.options().dtype(torch::kInt32)); +#ifndef NMS_VERIFY_CORRECTNESS + auto adjValuesTensor = torch::empty({ counts.size(0), maxExCount, maxExCount }, counts.options().dtype(torch::kInt32)); +#else + auto adjValuesTensor = torch::full({ counts.size(0), maxExCount, maxExCount }, + 5000, + counts.options().dtype(torch::kInt32)); +#endif + + // If the batch is only a single example, instead of hitting global memory for the count, we can + // just encode the count into the pointer instead + uint64_t ptrCounts = reinterpret_cast(counts.data_ptr()); + if (counts.size(0) == 1) { + ptrCounts = maxExCount; + } + +#ifdef NMS_VERIFY_CORRECTNESS + auto cpuAdjValuesTensor = adjValuesTensor.cpu(); + auto cpuAdjCountsTensor = adjCountsTensor.cpu(); + auto cpuIsStartTensor = isStartTensor.cpu(); +#endif + + size_t smemSize; + dim3 gridSize, blockSize; + + /////////////////// + // NOTE(mranzinger): This algorithm uses a fixed sized grid to spatially subdivide the canvas. For virtually all test conditions + // I ran this through, it was slightly slower than the brute force approach that parallelizes better. + // It's possible that there is some number of words present (e.g. >500) where this algorithm becomes + // faster. + // + //constexpr int32_t CELL_SIZE = 100; + //constexpr int64_t NUM_BINS_PER_CELL = 200; + //int32_t numXCells = div_up(collapseResult.ImageWidth, CELL_SIZE); + //int32_t numYCells = div_up(collapseResult.ImageHeight, CELL_SIZE); + //auto gridCellsTensor = torch::zeros({ counts.size(0), numYCells, numXCells, NUM_BINS_PER_CELL }, adjCountsTensor.options()); + //auto quadCellExtentsTensor = torch::empty({ counts.size(0), maxExCount, 4 }, gridCellsTensor.options()); + //smemSize = div_up(static_cast(maxExCount), 32); + + //constexpr uint32_t GRID_NUM_WARPS = 3; + //blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 }; + //gridSize = dim3{ 1, static_cast(maxExCount), static_cast(counts.size(0)) }; + + //auto buildGridFn = counts.size(0) == 1 ? + // device_a2a_adjacency_build_grid : + // device_a2a_adjacency_build_grid; + + //buildGridFn KERNEL_ARG2(gridSize, blockSize) ( + // ptrCounts, + // collapseResult.StridedMergeQuads.packed_accessor64(), + // gridCellsTensor.packed_accessor64(), + // quadCellExtentsTensor.packed_accessor64() + //); + + //auto adjGridFn = counts.size(0) == 1 ? 
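// The kernels take the per-example counts as a single 64-bit word: when the
// batch holds one example the count itself is encoded in the word, otherwise
// the word is a device pointer to one int32 count per example (see the
// `ptrCounts` setup above). A sketch of the decode the kernels perform, with
// the pointee type assumed to match the kInt32 counts tensor:

template <bool IsSingleExample>
__device__ inline uint32_t decode_pun_count(uint64_t punCounts, uint32_t b)
{
    if constexpr (IsSingleExample) {
        // The word is the count itself
        return static_cast<uint32_t>(punCounts);
    } else {
        // The word is a device pointer to per-example counts
        return static_cast<uint32_t>(reinterpret_cast<const int32_t *>(punCounts)[b]);
    }
}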
+ // device_a2a_adjacency_with_grid : + // device_a2a_adjacency_with_grid; + + //adjGridFn KERNEL_ARG3(gridSize, blockSize, smemSize) ( + // ptrCounts, + // iouThreshold, + // collapseResult.StridedMergeQuads.packed_accessor64(), + // gridCellsTensor.packed_accessor64(), + // quadCellExtentsTensor.packed_accessor64(), + // isStartTensor.packed_accessor64(), + // adjCountsTensor.packed_accessor64(), + // adjValuesTensor.packed_accessor64() + //); + /////////////////// + + uint32_t totalWork = maxExCount * maxExCount; + + blockSize = dim3{96, 1}; + gridSize = dim3{div_up(totalWork, blockSize.x), + static_cast(counts.size(0))}; + + auto adjFn = counts.size(0) == 1 ? device_a2a_adjacency_sparse : device_a2a_adjacency_sparse; + + // This algorithm is O(n^2) with n being the current number of quads + adjFn KERNEL_ARG2(gridSize, blockSize) ( + ptrCounts, + iouThreshold, + collapseResult.StridedMergeQuads.packed_accessor64(), + isStartTensor.packed_accessor64(), + adjCountsTensor.packed_accessor64(), + adjValuesTensor.packed_accessor64() + ); + + +#ifdef NMS_VERIFY_CORRECTNESS + cpu_a2a_adjacency_sparse(ptrCounts, iouThreshold, + collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor); + + adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2)); + + assert(torch::all(cpuAdjCountsTensor == adjCountsTensor.cpu()).item()); + assert(torch::all(cpuIsStartTensor == isStartTensor.cpu()).item()); + assert(torch::all(cpuAdjValuesTensor == adjValuesTensor.cpu()).item()); + + std::cout << "\tA2A Is Start Count: " << isStartTensor.sum(torch::kInt32).item() + << ", Most Adjacent: " << adjCountsTensor.max().item() << std::endl; + + auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options()); +#endif + + auto traverseFn = counts.size(0) == 1 ? 
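// div_up is used above for the grid size and below for rounding the shared
// visit mask up to whole uint32 words; it is presumably the usual integer
// ceiling division from a common header. Sketch:

template <typename T>
constexpr T div_up_sketch(T num, T den)
{
    return (num + den - 1) / den; // ceiling division for positive integers
}

// e.g. div_up_sketch(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t)
// rounds the mask size up to a multiple of 4 bytes.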
+ device_flatten_graph_iterative : + device_flatten_graph_iterative; + + blockSize = dim3{ 128, 1, 1 }; + gridSize = dim3{ 1, static_cast(maxExCount), static_cast(counts.size(0)) }; + smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t); + + traverseFn KERNEL_ARG3(gridSize, blockSize, smemSize) ( + ptrCounts, + isStartTensor.packed_accessor64(), + reinterpret_cast(adjCountsTensor.data_ptr()), + reinterpret_cast(adjValuesTensor.data_ptr()) +#ifdef NMS_VERIFY_CORRECTNESS + , maxDepthTensor.data_ptr() +#endif + ); + +#ifdef NMS_VERIFY_CORRECTNESS + cpu_flatten_graph(ptrCounts, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor); + + cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2)); + adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2)); + + torch::Tensor diffStartIdxs = (cpuIsStartTensor != isStartTensor.cpu()).nonzero_numpy()[0]; + + assert(diffStartIdxs.numel() == 0); + + torch::Tensor diffCountIdxs = (cpuAdjCountsTensor != adjCountsTensor.cpu()).nonzero_numpy()[0]; + + assert(diffCountIdxs.numel() == 0); + + auto diffValuesTensor = torch::any(cpuAdjValuesTensor != adjValuesTensor.cpu(), /*dim=*/ 2, /*keepdim=*/ false).flatten().nonzero().flatten(); + + std::cout << "\t\tDiff Indices: " << to_flat_string(diffValuesTensor) << std::endl; + + auto cpuDiffCountsTensor = cpuAdjCountsTensor.flatten().index({ diffValuesTensor }); + auto cpuDiffRowsTensor = cpuAdjValuesTensor.flatten(0, 1).index({ diffValuesTensor }); + auto gpuDiffRowsTensor = adjValuesTensor.cpu().flatten(0, 1).index({ diffValuesTensor }); + + for (int64_t i = 0, ct = cpuDiffRowsTensor.size(0); i < ct; ++i) { + auto z = cpuDiffCountsTensor[i].item(); + auto diffRow = diffValuesTensor[i].item(); + std::cout << "\t\tRow " << diffRow << std::endl; + std::cout << "\t\t\tExpected: " << to_flat_string(cpuDiffRowsTensor[i].slice(0, 0, z + 1)) << std::endl; + std::cout << "\t\t\t GPU: " << to_flat_string(gpuDiffRowsTensor[i].slice(0, 0, z + 1)) << std::endl; + } + + assert(diffValuesTensor.size(0) == 0); + + std::cout << "\tA2A - Flatten - Is Start Count: " << isStartTensor.sum(torch::kInt32).item() + << ", Most Adjacent: " << adjCountsTensor.max().item() + << ", Max Depth: " << maxDepthTensor.item() << std::endl; + + cpuIsStartTensor = isStartTensor.cpu(); + cpuAdjCountsTensor = adjCountsTensor.cpu(); + cpuAdjValuesTensor = adjValuesTensor.cpu(); + auto cpuCounts = counts.cpu(); + auto cpuCollapseIds = collapseResult.QuadIds.cpu(); + + static std::vector> s_knownGroups; + static std::unordered_map> s_groupLookup; + + std::vector> idGroups; + decltype(s_groupLookup) groupLookup; + for (int64_t b = 0; b < counts.size(0); ++b) { + int64_t quadCt = cpuCounts[b].item(); + for (int64_t row = 0; row < quadCt; ++row) { + bool isLeadRow = cpuIsStartTensor[b][row].item(); + auto bCountsTensor = cpuAdjCountsTensor[b]; + auto bValuesTensor = cpuAdjValuesTensor[b]; + auto bCounts = bCountsTensor.accessor(); + auto bValues = bValuesTensor.accessor(); + + auto bIdsTensor = cpuCollapseIds[b]; + auto bIds = bIdsTensor.accessor(); + + std::unordered_set sIds; + for (int32_t i = 0, ct = bCounts[row]; i < ct; ++i) { + int32_t col = bValues[row][i]; + int32_t id = bIds[col]; + sIds.insert(id); + } + + if (sIds.empty()) { + throw std::runtime_error("The ids tensor is empty!"); + } + + groupLookup[bIds[row]] = sIds; + + if (isLeadRow) { + verify_row(bCounts, bValues, row); + idGroups.push_back(move(sIds)); + } + } + } + + if (s_knownGroups.empty()) { + s_knownGroups = 
move(idGroups); + s_groupLookup = move(groupLookup); + } else { + // Make a copy + auto remOrigGroups = s_knownGroups; + auto remOrigGroupLookup = s_groupLookup; + + std::vector quadIds; + for (auto &kv : remOrigGroupLookup) { + quadIds.push_back(kv.first); + } + for (int32_t qId : quadIds) { + assert(groupLookup.count(qId)); + } + assert(groupLookup.size() == remOrigGroupLookup.size()); + + for (int32_t qId : quadIds) { + auto &oldGroup = remOrigGroupLookup[qId]; + auto &newGroup = groupLookup[qId]; + + if (oldGroup == newGroup) { + remOrigGroupLookup.erase(qId); + groupLookup.erase(qId); + } else { + throw std::runtime_error("Group mismatch!"); + } + } + + for (int i = idGroups.size() - 1; i >= 0; --i) { + for (int j = remOrigGroups.size() - 1; j >= 0; --j) { + auto &idGroup = idGroups[i]; + auto &knownGroup = remOrigGroups[j]; + + if (idGroup == knownGroup) { + idGroups.erase(begin(idGroups) + i); + remOrigGroups.erase(begin(remOrigGroups) + j); + break; + } + } + } + + if (!idGroups.empty() || !remOrigGroups.empty()) { + auto group_str = [] (auto &group) { + std::vector vGroup{ std::begin(group), std::end(group) }; + std::sort(std::begin(vGroup), std::end(vGroup)); + + auto id_str = [] (int32_t id) { + std::ostringstream oss; + //oss << "(" << (id / 32) << ", " << (id % 32) << ")"; + oss << id; + return oss.str(); + }; + + std::ostringstream oss; + oss << "[" << id_str(vGroup[0]); + for (size_t i = 1; i < vGroup.size(); ++i) { + oss << ", " << id_str(vGroup[i]); + } + oss << "]"; + return oss.str(); + }; + + std::cout << "\tEncountered a difference in groups!" << std::endl + << "\t\tOrig groups:" << std::endl; + for (auto &group : remOrigGroups) { + std::cout << "\t\t\t" << group_str(group) << std::endl; + } + std::cout << "\t\tNew groups:" << std::endl; + for (auto &group : idGroups) { + std::cout << "\t\t\t" << group_str(group) << std::endl; + } + } + } +#endif + + return { isStartTensor, adjCountsTensor, adjValuesTensor, maxExCount }; +} + + + +template +nms_result_t + all_to_all_collapse( + const CollapseRowsResult &collapseRowsRes, + const AdjacencyResult &adjResult) +{ + auto counts = collapseRowsRes.ExCounts; + auto embedQuads = collapseRowsRes.StridedMergeQuads; + + if (!embedQuads.is_contiguous()) { + throw std::runtime_error("Input embed quads were not contiguous!"); + } + + torch::Tensor isLeadRow; + if (counts.size(0) == 1) { + isLeadRow = adjResult.IsLeadRow; + } else { + // For multiple examples: IsLeadRow will have true values beyond the extent of the number of quads + // However, we know that Counts > 0 only happen within the extent, so the set intersection + // tells us which rows are actually lead + isLeadRow = torch::logical_and(adjResult.IsLeadRow, adjResult.AdjCounts > 0); + } + + auto regionCounts = isLeadRow.sum(/*dim=*/ 1, /*keepdim=*/ false, torch::kInt64); + + const int64_t numOutQuads = counts.size(0) == 1 ? 
regionCounts.item() : regionCounts.sum().item(); + + constexpr int32_t NUM_WARPS = 4; + dim3 blockSize(NUM_WARPS * 32, 1, 1); + dim3 gridSize(1, adjResult.MaxExCount, counts.size(0)); + + // If the batch is only a single example, instead of hitting global memory for the count, we can + // just encode the count into the pointer instead + uint64_t ptrCounts = reinterpret_cast(counts.data_ptr()); + if (counts.size(0) == 1) { + ptrCounts = adjResult.MaxExCount; + } + + torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options()); + torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options()); + + auto collapseFn = counts.size(0) == 1 ? + device_a2a_collapse : + device_a2a_collapse; + + collapseFn KERNEL_ARG2(gridSize, blockSize) ( + ptrCounts, + embedQuads.packed_accessor64(), + isLeadRow.packed_accessor64(), + regionCounts.data_ptr(), + adjResult.AdjCounts.packed_accessor64(), + adjResult.AdjValues.packed_accessor64(), + outQuads.packed_accessor64(), + outConf.data_ptr() + ); + + return { outQuads, outConf, regionCounts }; +} + +template +nms_result_t cuda_quad_non_maximal_suppression_impl( + torch::Tensor quads, torch::Tensor probs, + scalar_t probThreshold, scalar_t iouThreshold, + int64_t maxRegions, bool verbose) +{ + static const bool s_timerEnabled = true; + static const bool s_verboseLevel2 = true; + + // Make sure there's a batch dimension + if (quads.dim() == 4) { + // B,H,W,V,2 + quads = quads.unsqueeze(0); + // B,H,W + probs = probs.unsqueeze(0); + } + + //print_tensor_vec_stats2("NMS Input (quads, probs): ", { quads, probs }); + + double msRowCollapse = -1, + msAdjacency = -1, + msA2ACollapse = -1, + msTotal = -1; + + CollapseRowsResult collapseRows; + AdjacencyResult adjacency; + torch::Tensor retQuads, retConf, regionCounts; + + { + CudaStoreTimer tTotal{msTotal, s_timerEnabled}; + { + CudaStoreTimer t{msRowCollapse, s_timerEnabled && verbose && s_verboseLevel2}; + + // First combine all of the quads in each row + collapseRows = collapse_rows(quads, probs, probThreshold, iouThreshold); + + if (collapseRows.TotalNumQuads == 0) { + return { + torch::empty({ 0, 4, 2 }, quads.options()), + torch::empty({ 0 }, probs.options()), + collapseRows.ExCounts.toType(torch::kInt64) + }; + } + } + { + CudaStoreTimer t{msAdjacency, s_timerEnabled && verbose && s_verboseLevel2}; + adjacency = compute_all_to_all_adjacency(collapseRows, iouThreshold); + } + { + CudaStoreTimer t{msA2ACollapse, s_timerEnabled && verbose && s_verboseLevel2}; + std::tie(retQuads, retConf, regionCounts) = all_to_all_collapse(collapseRows, adjacency); + } + } + +#ifndef NDEBUG + assert(regionCounts.sum().item() == retQuads.size(0)); +#endif + + //print_tensor_vec_stats2(" Full NMS (quads, conf, counts): ", { retQuads, retConf, retCounts }); + + if (s_timerEnabled && verbose) { + std::cout << "NMS Cuda " << retQuads.size(0) + << " - Row Collapse (" << quads.size(0) << ", " << quads.size(1) << ", " << quads.size(2) << ") - (" << collapseRows.TotalNumQuads << "): " << msRowCollapse << "ms" + << ", Adjacency (" << adjacency.AdjCounts.sum(torch::kInt32).item() << "): " << msAdjacency << "ms" + << ", A2A Collapse (" << retQuads.size(0) << "): " << msA2ACollapse << "ms" + << ", Total: " << msTotal << "ms" + << std::endl; + } + + return { retQuads, retConf, regionCounts }; +} + +nms_result_t cuda_quad_non_maximal_suppression( + torch::Tensor quads, torch::Tensor probs, + float probThreshold, float iouThreshold, + int64_t kernelHeight, int64_t kernelWidth, + int64_t maxRegions, bool verbose) 
+{ + nms_result_t ret; + + ret = cuda_quad_non_maximal_suppression_impl( + quads.toType(torch::kFloat32), probs.toType(torch::kFloat32), + probThreshold, iouThreshold, + maxRegions, verbose + ); + + // AT_DISPATCH_FLOATING_TYPES_AND_HALF( + // quads.scalar_type(), + // "cuda_quad_non_maximal_suppression_impl", + // ([&] { + // ret = cuda_quad_non_maximal_suppression_impl( + // move(quads), move(probs), + // probThreshold, iouThreshold, + // maxRegions + // ); + // }) + // ); + + return ret; +} diff --git a/nemo-retriever-ocr/cpp/non_maximal_suppression/nms_common.h b/nemo-retriever-ocr/cpp/non_maximal_suppression/nms_common.h new file mode 100644 index 0000000000000000000000000000000000000000..9312bd428ad9590f84472b364e09cae0e00717ea --- /dev/null +++ b/nemo-retriever-ocr/cpp/non_maximal_suppression/nms_common.h @@ -0,0 +1,227 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "../geometry.h" +#include "../cuda_intellisense.cuh" +#include "strided_quad.h" + + + +std::vector quad_nms_from_adjacency( + torch::Tensor quads, torch::Tensor probs, torch::Tensor adjacency, + float probThreshold, float iouThreshold, + int64_t maxRegions); + +template +struct EmbedQuad_ : public QuadBase_ > { + Point_ Vertices[4]; + T Confidence; + T NumQuads = 0; + + __device__ + EmbedQuad_(T confidence = 0) + { + Reset(); + Confidence = confidence; + } + __device__ + EmbedQuad_(const EmbedQuad_ &other) = default; + + __device__ + void swap(EmbedQuad_ &other) noexcept { + using std::swap; + + for (size_t i = 0; i < 4; ++i) { + swap(Vertices[i], other.Vertices[i]); + } + + SWAP(Confidence, other.Confidence); + SWAP(NumQuads, other.NumQuads); + } + + __device__ + EmbedQuad_(EmbedQuad_ &&other) : EmbedQuad_() { + other.swap(*this); + } + + __device__ + EmbedQuad_ &operator=(EmbedQuad_ other) { + other.swap(*this); + return *this; + } + + __device__ + void Append(const EmbedQuad_ &other) { + Append(other, other.Confidence, other.NumQuads); + } + + template + __device__ + void Append(const QuadBase_ &q, T conf, T numQuads = 1) { + Confidence *= NumQuads; + + if (Confidence > 0) { + for (size_t i = 0; i < 4; ++i) { + Vertices[i] *= Confidence; + } + } + + Confidence += conf * numQuads; + + auto qVertices = static_cast(&q)->Vertices; + for (size_t i = 0; i < 4; ++i) { + Vertices[i] += conf * numQuads * qVertices[i]; + Vertices[i] /= Confidence; + } + + NumQuads += numQuads; + Confidence /= NumQuads; + } + + __device__ + void Prepare() { + // T factor = 1.0 / Confidence; + // for (size_t i = 0; i < 4; ++i) { + // Vertices[i] *= factor; + // } + // Confidence /= numQuads; + } + + __device__ + void Reset() { + for (size_t i = 0; i < 4; ++i) { + Vertices[i] = Point_{0, 0}; + } + Confidence = 0.0f; + NumQuads = 0; + } + + __device__ + const Point_ &operator[](size_t v) const { return Vertices[v]; } + __device__ + Point_ &operator[](size_t v) { return Vertices[v]; } +}; + +struct ZeroInitTag {}; + +template +struct MergeQuad_ : public QuadBase_> { + Point_ Vertices[4]; + T Confidence; + T NumQuads; + + MergeQuad_() = default; + + __device__ + MergeQuad_(ZeroInitTag) : Confidence(0), NumQuads(0) { + for (size_t i = 0; i < 4; ++i) { + Vertices[i] = Point_{0, 0}; + } + } + + template + __device__ + void Append(const QuadBase_ &q, T conf) { + Confidence += conf; + ++NumQuads; + + auto &d = static_cast(q); + for (size_t i = 0; i < 4; ++i) { + Vertices[i] += conf * d[i]; + } + 
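// EmbedQuad_::Append above keeps a confidence-weighted running mean of the
// vertex positions (and a plain running mean of the confidence) by un-scaling
// the current mean, folding in the new weighted contribution, and
// re-normalizing. A scalar analogue of that update, as a sketch:

struct RunningWeightedMeanSketch {
    double totalWeight = 0.0;
    double mean = 0.0;

    void append(double value, double weight) {
        totalWeight += weight;
        // Incremental form of sum(w_i * x_i) / sum(w_i)
        mean += (value - mean) * (weight / totalWeight);
    }
};

// e.g. append(2, 1) followed by append(4, 3) leaves mean == 3.5, i.e. (2*1 + 4*3) / 4.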
} + __device__ + void Append(const EmbedQuad_ &q) { + T qConf = q.NumQuads * q.Confidence; + + Confidence += qConf; + NumQuads += q.NumQuads; + for (size_t i = 0; i < 4; ++i) { + Vertices[i] += qConf * q.Vertices[i]; + } + } + __device__ + void Append(const StridedEmbedQuad_ &q) { + const T numQuads = q.NumQuads(); + const T qConf = numQuads * q.Confidence(); + + Confidence += qConf; + NumQuads += numQuads; + for (size_t i = 0; i < 4; ++i) { + Vertices[i] += qConf * q[i]; + } + } + + __device__ + EmbedQuad_ Commit() { + EmbedQuad_ ret; + for (size_t i = 0; i < 4; ++i) { + ret.Vertices[i] = Vertices[i] / Confidence; + } + ret.Confidence = Confidence / NumQuads; + ret.NumQuads = NumQuads; + + return ret; + } + + __device__ + const Point_ &operator[](size_t v) const { return Vertices[v]; } + __device__ + Point_ &operator[](size_t v) { return Vertices[v]; } +}; + +template +__device__ +inline T triangle_root(T val) +{ + // It's easier to visualize this algorithm for a lower triangular matrix + // What we're trying to find is the `row` of a lower triangular matrix that a given `val` resides in. + // e.g. 0->0, 2->1, 4->2, etc. + // + // 0: 0 + // 1: 1 2 + // 2: 3 4 5 + // 3: 6 7 8 9 + // + // See https://math.stackexchange.com/questions/698961/finding-the-triangular-root-of-a-number for explanation + Intermediate numer = Intermediate(-1) + sqrt(Intermediate(1) + Intermediate(8) * Intermediate(val)); + Intermediate denom = Intermediate(2); + + Intermediate ret = floor(numer / denom); + return T(ret); +} + +template +void visit_node(const std::vector> &allQuads, size_t quadIdx, + const std::vector> &adjIdxs, EmbedQuad_ &currQuad, + std::unordered_set &visited) +{ + if (visited.count(quadIdx) > 0) return; + + const EmbedQuad_ &vQuad = allQuads[quadIdx]; + + currQuad.Append(vQuad); + visited.insert(quadIdx); + + for (size_t childIdx : adjIdxs[quadIdx]) { + visit_node(allQuads, childIdx, adjIdxs, currQuad, visited); + } +} + +template +void copy_quad(const QuadBase_ &srcQuad, scalar_t *pDest) +{ + auto vertices = static_cast(&srcQuad)->Vertices; + for (size_t i = 0; i < 4; ++i) { + const Point_ &v = vertices[i]; + *pDest++ = v.X; + *pDest++ = v.Y; + } +} diff --git a/nemo-retriever-ocr/cpp/non_maximal_suppression/nms_kd_tree.h b/nemo-retriever-ocr/cpp/non_maximal_suppression/nms_kd_tree.h new file mode 100644 index 0000000000000000000000000000000000000000..b93d53c531572e9e53d2001836e85199ea8319b6 --- /dev/null +++ b/nemo-retriever-ocr/cpp/non_maximal_suppression/nms_kd_tree.h @@ -0,0 +1,449 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "../geometry.h" + + +#define MODE_GEOMETRY 0x02ull +#define MODE_CHILDREN 0x00ull + +#define DIM_X 0x0ull +#define DIM_Y 0x1ull + +static const size_t INVALID_IDX = -1; + +template +struct NMS_BoundsWrapper +{ + typedef std::unique_ptr Ptr; + typedef AABB_ bds_t; + + size_t GeoIdx; + const T *Geometry; + bds_t Bounds; + + NMS_BoundsWrapper(size_t geoIdx, const T *geometry) : GeoIdx(geoIdx), Geometry(geometry), Bounds(geometry->Bounds()) { } +}; + +template +class NMS_NodeAllocator; + +template +class NMS_KDTree; + +template +class NMS_BuildCache; + +template +class NMS_KDNode +{ + friend class NMS_KDTree; + +public: + typedef NMS_BoundsWrapper bds_t; + typedef std::unique_ptr UPtr; + typedef typename T::inner_type inner_type; + typedef std::vector geo_vec_t; + typedef std::unique_ptr geo_vec_ptr; + + void Build(geo_vec_ptr geometries, const typename bds_t::bds_t &envelope, + NMS_NodeAllocator &allocator, NMS_BuildCache &buildCache); + + template + void FindIntersections(size_t geoIdx, const typename bds_t::bds_t &bds, const Fn &fn) const; + +private: + inline uintptr_t Dim() const { return reinterpret_cast(m_ptr) & 0x01ull; } + inline uintptr_t Mode() const { return reinterpret_cast(m_ptr) & 0x02ull; } + inline void Children(NMS_KDNode *&children, inner_type &splitPos) const + { + auto vPtr = Geometries(); + splitPos = *reinterpret_cast(vPtr); + children = reinterpret_cast(vPtr + sizeof(inner_type)); + } + inline uint8_t* Geometries() const + { + return reinterpret_cast(reinterpret_cast(m_ptr) & ~0x3ull); + } + inline void SetPtr(uint8_t *vPtr, uintptr_t mode, uintptr_t dim) + { + m_ptr = reinterpret_cast( + reinterpret_cast(vPtr) | mode | dim + ); + } + void AssignGeometries(geo_vec_ptr geometries, NMS_BuildCache &buildCache); + + uint8_t *m_ptr; +}; + +template +class NMS_NodeAllocator +{ +public: + typedef NMS_KDNode node_t; + typedef typename node_t::inner_type inner_type; + + NMS_NodeAllocator(size_t initialGuess = 512); + ~NMS_NodeAllocator(); + + void Get(size_t numNodes, NMS_KDNode *&outNodes, inner_type *&outSplitPos, uint8_t *&outRawPtr); + +private: + std::vector> m_buffers; + size_t m_offset; +}; + +template +class NMS_BuildCache +{ +public: + typedef typename NMS_KDNode::bds_t bds_t; + typedef std::unique_ptr Ptr; + typedef std::vector geo_vec_t; + typedef std::unique_ptr geo_vec_ptr; + + NMS_BuildCache(size_t initialSize); + ~NMS_BuildCache(); + + geo_vec_ptr Get(size_t sizeHint); + bds_t** GetRawBuffer(size_t numGeos, uint8_t *&rawPtr); + + void Release(geo_vec_ptr buff); + +private: + std::stack m_cache; + std::vector> m_rawBuffers; + size_t m_rawOffset; +}; + + +template +class NMS_KDTree +{ + typedef typename T::inner_type inner_type; + typedef NMS_BoundsWrapper bds_t; + typedef NMS_KDNode node_t; + +public: + NMS_KDTree(); + ~NMS_KDTree(); + + void Build(const std::vector &geometries); + + template + void FindIntersections(size_t geoIdx, const Fn &fn) const; + + template + void FindIntersections(const T &geo, const Fn &fn) const; + +private: + bds_t *m_wrappers; + NMS_NodeAllocator m_allocator; + node_t m_root; + typename NMS_BuildCache::Ptr m_buildCache; +}; + +template +NMS_KDTree::NMS_KDTree() + : m_wrappers(nullptr) +{ + m_root.m_ptr = nullptr; +} + +template +NMS_KDTree::~NMS_KDTree() +{ + free(m_wrappers); +} + +template +void NMS_KDTree::Build(const std::vector &geometries) +{ + if (geometries.empty()) { + m_root.m_ptr = nullptr; + return; + } + + // Doing this so 
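// NMS_KDNode above stores two bits of metadata (leaf-vs-internal mode and the
// split dimension) in the low bits of its single pointer member; this relies on
// the payload pointers handed to SetPtr being at least 4-byte aligned. A
// standalone sketch of the tag/untag arithmetic:

#include <cstdint>

inline uint8_t *tag_ptr(uint8_t *p, uintptr_t mode, uintptr_t dim)
{
    // `p` must have its low two bits clear, i.e. be at least 4-byte aligned
    return reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(p) | mode | dim);
}

inline uint8_t *untag_ptr(uint8_t *p)
{
    return reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(p) & ~uintptr_t(0x3));
}

inline uintptr_t tag_dim(const uint8_t *p)  { return reinterpret_cast<uintptr_t>(p) & 0x1; }
inline uintptr_t tag_mode(const uint8_t *p) { return reinterpret_cast<uintptr_t>(p) & 0x2; }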
that we can perform placement-new on the array buffer, and thus + // can only perform a single memory allocation for all geometries at once + m_wrappers = reinterpret_cast(malloc(sizeof(bds_t) * geometries.size())); + + m_buildCache.reset(new NMS_BuildCache(geometries.size())); + + auto bdsGeos = m_buildCache->Get(geometries.size()); + + typename bds_t::bds_t envelope; + + for (size_t i = 0; i < geometries.size(); ++i) { + // Placement new. Constructs the object in the place specified in the first (...) + new (m_wrappers + i) bds_t(i, &geometries[i]); + + bdsGeos->push_back(m_wrappers + i); + if (i == 0) { + envelope = m_wrappers[i].Bounds; + } else { + envelope = envelope.Union(m_wrappers[i].Bounds); + } + } + + + m_root.Build(std::move(bdsGeos), envelope, m_allocator, *m_buildCache); +} + +template +void NMS_KDNode::Build(geo_vec_ptr geometries, const typename bds_t::bds_t &envelope, + NMS_NodeAllocator &allocator, NMS_BuildCache &buildCache) +{ + static const size_t MAX_GEOMETRIES = 8; + + if (geometries->size() <= MAX_GEOMETRIES) { + AssignGeometries(std::move(geometries), buildCache); + } else { + geo_vec_ptr leftGeos = buildCache.Get(geometries->size()), + rightGeos = buildCache.Get(geometries->size()); + + inner_type szX = envelope[2] - envelope[0]; + inner_type szY = envelope[3] - envelope[1]; + + int64_t dim = szX > szY ? 0 : 1; + auto emn = envelope[dim]; + auto emx = envelope[dim + 2]; + + auto pivotPos = (emn + emx) / 2; + for (bds_t *g : *geometries) { + auto mn = g->Bounds[dim]; + auto mx = g->Bounds[dim + 2]; + + if (mn < pivotPos) { + leftGeos->push_back(g); + } + if (mx > pivotPos) { + rightGeos->push_back(g); + } + } + + if (leftGeos->size() == geometries->size() || rightGeos->size() == geometries->size()) { + AssignGeometries(std::move(geometries), buildCache); + buildCache.Release(std::move(leftGeos)); + buildCache.Release(std::move(rightGeos)); + } else { + buildCache.Release(std::move(geometries)); + + inner_type *nodeSplitPos; + uint8_t *nodeRawPtr; + NMS_KDNode *children; + allocator.Get(2, children, nodeSplitPos, nodeRawPtr); + + SetPtr(nodeRawPtr, MODE_CHILDREN, dim); + *nodeSplitPos = pivotPos; + + typename bds_t::bds_t leftEnv{envelope}, rightEnv{envelope}; + // Set the max of the left envelope to the split plane + leftEnv[dim + 2] = pivotPos; + // Set the min of the right envelope to the split plane + rightEnv[dim] = pivotPos; + + children[0].Build(std::move(leftGeos), leftEnv, allocator, buildCache); + children[1].Build(std::move(rightGeos), rightEnv, allocator, buildCache); + } + } +} + +template +void NMS_KDNode::AssignGeometries(geo_vec_ptr geometries, NMS_BuildCache &buildCache) +{ + if (geometries->empty()) { + SetPtr(nullptr, MODE_GEOMETRY, 0); + } else { + uint8_t *vPtr; + bds_t **geoPtr = buildCache.GetRawBuffer(geometries->size(), vPtr); + std::copy(geometries->begin(), geometries->end(), geoPtr); + + SetPtr(vPtr, MODE_GEOMETRY, 0); + } + buildCache.Release(std::move(geometries)); +} + +template +template +void NMS_KDTree::FindIntersections(size_t geoIdx, const Fn &fn) const +{ + if (!m_wrappers) return; + + auto &bds = m_wrappers[geoIdx].Bounds; + + m_root.FindIntersections(geoIdx, bds, fn); +} + +template +template +void NMS_KDTree::FindIntersections(const T &geo, const Fn &fn) const +{ + if (!m_wrappers) return; + + NMS_BoundsWrapper bdsWrapper(INVALID_IDX, &geo); + + m_root.FindIntersections(INVALID_IDX, bdsWrapper.Bounds, fn); +} + +template +template +void NMS_KDNode::FindIntersections(size_t geoIdx, const typename bds_t::bds_t &bds, const Fn 
&fn) const +{ + auto mode = Mode(); + + if (mode == MODE_GEOMETRY) { + auto *vPtr = Geometries(); + + size_t numGeos = *reinterpret_cast(vPtr); + bds_t **geoPtr = reinterpret_cast(vPtr + sizeof(size_t)); + + bds_t **endPtr = geoPtr + numGeos; + for (; geoPtr != endPtr; ++geoPtr) { + const bds_t *child = *geoPtr; + + // Don't compute this against self + if (geoIdx != INVALID_IDX && child->GeoIdx <= geoIdx) continue; + + typename bds_t::bds_t::inner_type pctN, pctM, iou; + std::tie(pctN, pctM, iou) = geometry_region_sizes(bds, child->Bounds); + + if (iou > 0) { + fn(child->GeoIdx, pctN, pctM, iou); + } + } + } else { + auto dim = Dim(); + + auto mn = bds[dim]; + auto mx = bds[dim + 2]; + + NMS_KDNode *children; + inner_type splitPos; + Children(children, splitPos); + + if (mn < splitPos) { + children[0].FindIntersections(geoIdx, bds, fn); + } + if (mx > splitPos) { + children[1].FindIntersections(geoIdx, bds, fn); + } + } +} + +template +NMS_NodeAllocator::NMS_NodeAllocator(size_t initialGuess) + : m_offset(0) +{ + size_t allocSize = initialGuess * (sizeof(inner_type) + 2 * sizeof(node_t)); + auto ptr = reinterpret_cast(malloc(allocSize)); + m_buffers.emplace_back(initialGuess, ptr); +} + +template +NMS_NodeAllocator::~NMS_NodeAllocator() +{ + for (auto &p : m_buffers) { + free(p.second); + } +} + +template +void NMS_NodeAllocator::Get(size_t numNodes, node_t *&outNodes, inner_type *&outSplitPos, uint8_t *&outRawPtr) +{ + auto &currBuff = m_buffers.back(); + + size_t rem = currBuff.first - m_offset; + + size_t reqSize = sizeof(inner_type) + sizeof(node_t) * numNodes; + + if (rem >= reqSize) { + outRawPtr = currBuff.second + m_offset; + outSplitPos = reinterpret_cast(outRawPtr); + outNodes = reinterpret_cast(outRawPtr + sizeof(inner_type)); + m_offset += reqSize; + return; + } + + // Rounds up to the nearest factor of 2 + size_t allocSize = (std::max(currBuff.first * 2, reqSize) + 1) & ~0x01ull; + auto ptr = reinterpret_cast(malloc(allocSize)); + m_buffers.emplace_back(allocSize, ptr); + m_offset = 0; + + Get(numNodes, outNodes, outSplitPos, outRawPtr); +} + +template +NMS_BuildCache::NMS_BuildCache(size_t initialSize) + : m_rawOffset(0) +{ + auto allocSize = sizeof(bds_t*) * initialSize * 2; + auto raw1 = reinterpret_cast(malloc(allocSize)); + m_rawBuffers.emplace_back(allocSize, raw1); +} + +template +NMS_BuildCache::~NMS_BuildCache() +{ + for (auto &p : m_rawBuffers) { + free(p.second); + } +} + +template +typename NMS_BuildCache::geo_vec_ptr NMS_BuildCache::Get(size_t sizeHint) +{ + geo_vec_ptr ret; + if (! 
m_cache.empty()) { + ret = std::move(m_cache.top()); + m_cache.pop(); + ret->clear(); + } else { + ret.reset(new std::vector); + } + + ret->reserve(sizeHint); + return ret; +} + +template +typename NMS_BuildCache::bds_t** NMS_BuildCache::GetRawBuffer(size_t numGeos, uint8_t *&rawPtr) +{ + auto &currBuff = m_rawBuffers.back(); + size_t rem = currBuff.first - m_rawOffset; + + size_t reqSize = sizeof(size_t) + sizeof(bds_t*) * numGeos; + + if (rem >= reqSize) { + rawPtr = currBuff.second + m_rawOffset; + m_rawOffset += reqSize; + reinterpret_cast(rawPtr)[0] = numGeos; + return reinterpret_cast(rawPtr + sizeof(size_t)); + } + + size_t allocSize = (std::max(currBuff.first * 2, reqSize) + 1) & ~0x01ull; + auto ptr = reinterpret_cast(malloc(allocSize)); + m_rawBuffers.emplace_back(allocSize, ptr); + m_rawOffset = 0; + + return GetRawBuffer(numGeos, rawPtr); +} + +template +void NMS_BuildCache::Release(geo_vec_ptr buff) +{ + m_cache.push(std::move(buff)); +} + +#undef MODE_GEOMETRY +#undef MODE_CHILDREN +#undef DIM_X +#undef DIM_Y diff --git a/nemo-retriever-ocr/cpp/non_maximal_suppression/non_maximal_suppression.cpp b/nemo-retriever-ocr/cpp/non_maximal_suppression/non_maximal_suppression.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75399965c40b13dada14a831b851a20eeba594d4 --- /dev/null +++ b/nemo-retriever-ocr/cpp/non_maximal_suppression/non_maximal_suppression.cpp @@ -0,0 +1,390 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "non_maximal_suppression.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../geometry.h" +#include "../common.h" +#include "nms_kd_tree.h" + +using namespace std; +namespace ix = torch::indexing; + +typedef EmbedQuad_ EFQuad; + +nms_result_t quad_non_maximal_suppression_cpu_impl( + torch::Tensor tQuads, torch::Tensor tProbs, + float probThreshold, float iouThreshold, + int64_t kernelHeight, int64_t kernelWidth, + int64_t maxRegions, + bool verbose) +{ + tQuads = tQuads.to(torch::kFloat32).to(torch::kCPU, /*non_blocking=*/ true); + tProbs = tProbs.to(torch::kFloat32).to(torch::kCPU, /*non_blocking=*/ true); + + auto tStart = chrono::high_resolution_clock::now(); + + cudaDeviceSynchronize(); + + auto tData = chrono::high_resolution_clock::now(); + + if (maxRegions == 0) { + maxRegions = numeric_limits::max(); + } + + // B,H,W,4,2 + auto quadsAccess = tQuads.accessor(); + // B,H,W + auto probsAccess = tProbs.accessor(); + + + const int64_t B = probsAccess.size(0); + + vector> allQuads{ (unsigned int)B }; + vector> batchQuads{ (unsigned int)B }; + vector>> batchAdjIdxs{ (unsigned int)B }; + vector> batchVisited{ (unsigned int)B }; + + vector> batchKDTrees{ (unsigned int)B }; + + decltype(tData) tRowSpan, tBuildKD, tAdjacent; + + // Only enable parallelism if release mode + #ifndef NDEBUG + #pragma omp parallel num_threads (8) + #endif + { + // Step 1: Combine quads by row + // Parallelize on both batch and rows + #ifndef NDEBUG + #pragma omp for collapse (2) + #endif + for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) { + for (int64_t row = 0; row < probsAccess.size(1); ++row) { + vector& quads = batchQuads[batchIdx]; + EFQuad currQuad; + + auto commitQuad = [&]() { + if (currQuad.NumQuads > 0) { + #pragma omp critical + { + if (quads.size() < maxRegions) { + quads.push_back(currQuad); + } + } + currQuad.Reset(); + } + }; + + for (int64_t col = 0; col < probsAccess.size(2); ++col) { + 
Quad_ predQuad{ quadsAccess[batchIdx][row][col].data() }; + float predConf = probsAccess[batchIdx][row][col]; + + // If we're currently in a span, then merge + if (predConf >= probThreshold) { + auto iou = currQuad.NumQuads > 0 ? predQuad.IOU_UpperBound(currQuad) : 0; + + // These two regions aren't mergable. Finalize the current quad, and start a new one + if (iou < iouThreshold) { + commitQuad(); + } + + currQuad.Append(predQuad, predConf); + } + // Otherwise, commit it if valid + else { + commitQuad(); + } + } + + // Capture any dangling span + commitQuad(); + } + } + + #ifndef NDEBUG + #pragma omp single + #endif + { + tRowSpan = chrono::high_resolution_clock::now(); + } + + #ifndef NDEBUG + #pragma omp for + #endif + for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) { + batchKDTrees[batchIdx].Build(batchQuads[batchIdx]); + } + + static const int64_t TASK_SIZE = 2; + + // Step 2: At this point, we have the set of row-merged quads, so now we + // apply the real merge algorithm. For this, we start with an adjacency matrix. + // + // OMP note: "single" means that only one of the threads in the parallel group will execute this block. + // We're using tasking here to add a bunch of work to the thread pool that will be processed concurrently. + #ifndef NDEBUG + #pragma omp single + #endif + { + tBuildKD = chrono::high_resolution_clock::now(); + + for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) { + int64_t numQuads = batchQuads[batchIdx].size(); + batchAdjIdxs[batchIdx].resize(numQuads); + + for (int64_t q = 0; q < numQuads; q += TASK_SIZE) { + // This defines a task that will be executed in parallel by the pool + // OMP note: + // "shared" essentially means that we're capturing these variables by reference + // "firstprivate" means that we're capturing these variables by value + #ifndef NDEBUG + #pragma omp task default(none) shared(batchAdjIdxs, batchQuads, batchKDTrees, batchVisited, iouThreshold) firstprivate(batchIdx, numQuads, q) + #endif + { + vector& quads = batchQuads[batchIdx]; + auto& kdTree = batchKDTrees[batchIdx]; + unordered_set& visited = batchVisited[batchIdx]; + + for (int64_t n = q, nend = min(numQuads, q + TASK_SIZE); n < nend; ++n) { + vector& adjIdxs = batchAdjIdxs[batchIdx][n]; + + kdTree.FindIntersections(n, + [n, iouThreshold, &quads, &visited, &adjIdxs](size_t m, float bdsPctN, float bdsPctM, float bdsIOU) { + float pctN, pctM, iou; + tie(pctN, pctM, iou) = geometry_region_sizes(quads[n], quads[m]); + + // Merge + if (iou >= iouThreshold) { + adjIdxs.push_back(m); + // The next two cases are when one region envelops the other. In this case, take the larger region. + // If iou > 0, then they overlap at least somewhat + } + else if (pctN > 0.8 || pctM > 0.8) { + float nHeight = quads[n].Height(); + float mHeight = quads[m].Height(); + + float ratio = nHeight > mHeight ? 
mHeight / nHeight : nHeight / mHeight; + // If the two quads are roughly the same height (within 90% of each other), then eliminate the smaller region + if (ratio > 0.9) { + if (pctN > 0.8) { + // M envelops N + #pragma omp critical + // Marking a node as visited will prevent it from being processed during the adjacency collapse phase + visited.insert(n); + } + else if (pctM > 0.8) { + // N envelops M + #pragma omp critical + visited.insert(m); + } + } + } + } + ); + } + } + } + } + + #ifndef NDEBUG + #pragma omp taskwait + #endif + + tAdjacent = chrono::high_resolution_clock::now(); + } + + #ifndef NDEBUG + #pragma omp for + #endif + for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) { + vector>& adjIdxs = batchAdjIdxs[batchIdx]; + vector& quads = batchQuads[batchIdx]; + vector& finalQuads = allQuads[batchIdx]; + unordered_set& visited = batchVisited[batchIdx]; + + // Step 3: Using depth first search, merge the regions + for (int64_t n = 0; n < quads.size(); ++n) { + EFQuad currQuad; + visit_node(quads, n, adjIdxs, currQuad, visited); + + if (currQuad.NumQuads > 0) { + currQuad.Prepare(); + + finalQuads.push_back(currQuad); + } + } + } + + } // End parallel + + auto tMerge = chrono::high_resolution_clock::now(); + + int64_t numOutQuads = 0; + for (int64_t batchIdx = 0; batchIdx < B; ++batchIdx) { + numOutQuads += allQuads[batchIdx].size(); + } + + // Allocate the output tensors in pinned memory because they will be immediately sent back to the GPU + auto pinnedOpt = torch::TensorOptions().pinned_memory(true); + + // Step 4: Convert the quads into tensor representation + auto outQuadTensor = torch::empty({ numOutQuads, 4, 2 }, pinnedOpt.dtype(torch::kFloat32)); + auto outConfTensor = torch::empty({ numOutQuads }, pinnedOpt.dtype(torch::kFloat32)); + auto outCountTensor = torch::empty({ (int)allQuads.size() }, pinnedOpt.dtype(torch::kInt64)); + + auto outQuadAccess = outQuadTensor.accessor(); + auto outConfAccess = outConfTensor.accessor(); + auto outCountAccess = outCountTensor.accessor(); + + int64_t offset = 0; + for (int64_t batchIdx = 0; batchIdx < allQuads.size(); ++batchIdx) { + vector& exQuads = allQuads[batchIdx]; + + outCountAccess[batchIdx] = exQuads.size(); + + for (int64_t qIdx = 0; qIdx < exQuads.size(); ++qIdx, ++offset) { + copy_quad(exQuads[qIdx], outQuadAccess[offset].data()); + outConfAccess[offset] = exQuads[qIdx].Confidence; + } + } + + if (verbose) { + auto tWrite = chrono::high_resolution_clock::now(); + + typedef chrono::duration tp_t; + tp_t dataElapsed = tData - tStart; + tp_t rowSpanElapsed = tRowSpan - tData; + tp_t buildKDElapsed = tBuildKD - tRowSpan; + tp_t adjacentElapsed = tAdjacent - tBuildKD; + tp_t mergeElapsed = tMerge - tAdjacent; + tp_t writeElapsed = tWrite - tMerge; + tp_t totalElapsed = tWrite - tStart; + + // print_tensor(outCountTensor); + cout << "NMS " << numOutQuads + << " - Wait for data: " << dataElapsed.count() << "ms" + << ", Row Span: " << rowSpanElapsed.count() << "ms" + << ", Build KD: " << buildKDElapsed.count() << "ms" + << ", Adjacency: " << adjacentElapsed.count() << "ms" + << ", Merge: " << mergeElapsed.count() << "ms" + << ", Write: " << writeElapsed.count() << "ms" + << ", Total: " << totalElapsed.count() << "ms" + << endl; + } + + return { outQuadTensor, outConfTensor, outCountTensor }; +} + +nms_result_t quad_non_maximal_suppression( + torch::Tensor tQuads, torch::Tensor tProbs, + float probThreshold, float iouThreshold, + int64_t kernelHeight, int64_t kernelWidth, + int64_t maxRegions, + bool verbose) +{ + auto nmsFn = 
tQuads.is_cuda() ? + cuda_quad_non_maximal_suppression : + quad_non_maximal_suppression_cpu_impl; + + torch::Tensor quads, confidence, regionCounts; + tie(quads, confidence, regionCounts) = nmsFn( + tQuads, tProbs, + probThreshold, iouThreshold, + kernelHeight, kernelWidth, + maxRegions, verbose + ); + +#ifndef NDEBUG + // In debug mode, do cell sorting so that it's easier to see where the quads are + auto cells = get<0>(quads.min(1)).div_(10).floor_(); + auto maxX = cells.index({ ix::Slice(), 0 }).max(); + + cells = maxX * cells.select(1, 1) + cells.select(1, 0); + + // Ensure that we keep them ordered by example + auto regionIdxs = torch::arange(regionCounts.size(0), cells.options()).repeat_interleave(regionCounts); + auto cellMax = cells.max(); + cells += cellMax * regionIdxs; + + auto order = torch::argsort(cells); + + quads = quads.index({ order }); + confidence = confidence.index({ order }); +#endif + + return { quads, confidence, regionCounts }; +} + +vector reduced_quad_non_maximal_suppression( + const vector &rowQuads, float iouThreshold, int64_t imageHeight, int64_t imageWidth) +{ + // auto tStart = chrono::high_resolution_clock::now(); + + vector allQuads; + + TEFQuad currQuad; + + auto commitQuad = [&] () { + if (currQuad.NumQuads > 0) { + allQuads.push_back(move(currQuad)); + } + currQuad.Reset(); + }; + + for (const auto &thisQuad : rowQuads) { + auto iou = currQuad.NumQuads > 0 ? thisQuad.IOU_UpperBound(currQuad) : 0; + + // These two regions aren't mergeable. Finalize the current quad, and start a new one + if (iou < iouThreshold) { + commitQuad(); + } + + currQuad.Append(thisQuad, 1); + } + + // Capture any dangling span + commitQuad(); + + const int64_t numQuads = allQuads.size(); + vector mergeQuads; + vector visited; + visited.resize(numQuads, false); + + NMS_KDTree kdTree; + kdTree.Build(allQuads); + + for (int64_t row = 0; row < numQuads; ++row) { + if (visited[row]) continue; + + TEFQuad &rowQuad = allQuads[row]; + + kdTree.FindIntersections(row, + [row, iouThreshold, &rowQuad, &allQuads, &visited] (size_t col, float pctN, float pctM, float iou) { + if (iou >= iouThreshold && ! visited[col]) { + rowQuad.Append(move(allQuads[col])); + visited[col] = true; + } + } + ); + + mergeQuads.push_back(move(rowQuad)); + } + + // auto tEnd = chrono::high_resolution_clock::now(); + + // chrono::duration totalElapsed = tEnd - tStart; + // cout << "Row NMS " << mergeQuads.size() << " - Time: " << totalElapsed.count() << "ms" << endl; + + return mergeQuads; +} diff --git a/nemo-retriever-ocr/cpp/non_maximal_suppression/non_maximal_suppression.h b/nemo-retriever-ocr/cpp/non_maximal_suppression/non_maximal_suppression.h new file mode 100644 index 0000000000000000000000000000000000000000..a5dff7786c2920483c59572b7f3729836218d151 --- /dev/null +++ b/nemo-retriever-ocr/cpp/non_maximal_suppression/non_maximal_suppression.h @@ -0,0 +1,102 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "nms_common.h" +#include "../geometry.h" + +/* +* \brief Result type for non-maximal suppression. +* +* The results are flattened across the batch, use the third value (region counts) to determine which +* example a quad is associated with. +* +* N - Total number of quads for the entire batch +* B - Batch size +* +* 0 - quads - Nx4x2 +* 1 - confidence - N +* 2 - regionCounts - B (s.t. 
sum(regionCounts) == N) +*/ +typedef std::tuple nms_result_t; + +nms_result_t quad_non_maximal_suppression( + torch::Tensor quads, torch::Tensor probs, + float probThreshold, float iouThreshold, + int64_t kernelHeight, int64_t kernelWidth, + int64_t maxRegions, + bool verbose = false); + + +template +struct TrackedInPlaceQuad_ : InPlaceQuad_ { + Point_ ImgCoords; + + TrackedInPlaceQuad_(Point_ imgCoords) : ImgCoords(std::move(imgCoords)) {} + TrackedInPlaceQuad_(int64_t row, int64_t col) : ImgCoords(col, row) {} +}; + +template +struct TrackedEmbedQuad_ : EmbedQuad_ { + std::vector> ImgCoords; + + TrackedEmbedQuad_(T confidence = 0): EmbedQuad_(confidence) {} + TrackedEmbedQuad_(const TrackedEmbedQuad_ &other) = default; + + void swap(TrackedEmbedQuad_ &other) noexcept { + using std::swap; + + swap(ImgCoords, other.ImgCoords); + + EmbedQuad_::swap(other); + } + + TrackedEmbedQuad_(TrackedEmbedQuad_ &&other) : TrackedEmbedQuad_() { + other.swap(*this); + } + + TrackedEmbedQuad_ &operator=(TrackedEmbedQuad_ other) { + other.swap(*this); + return *this; + } + + void Append(const TrackedInPlaceQuad_ &q, T conf, T numQuads = 1) { + ImgCoords.push_back(q.ImgCoords); + + EmbedQuad_::Append(q, conf, numQuads); + } + + void Append(const TrackedEmbedQuad_ &other) { + ImgCoords.insert(end(ImgCoords), begin(other.ImgCoords), end(other.ImgCoords)); + + EmbedQuad_::Append(other); + } + + void Reset() { + ImgCoords.clear(); + + EmbedQuad_::Reset(); + } +}; + +typedef TrackedInPlaceQuad_ TIPQuad; +typedef TrackedEmbedQuad_ TEFQuad; + + +std::vector reduced_quad_non_maximal_suppression( + const std::vector &rowQuads, float iouThreshold, int64_t imageHeight, int64_t imageWidth); + +std::vector quad_non_maximal_suppression_backward( + torch::Tensor quads, torch::Tensor probs, + torch::Tensor gradOutQuads, torch::Tensor gradOutProbs); + +nms_result_t cuda_quad_non_maximal_suppression( + torch::Tensor quads, torch::Tensor probs, + float probThreshold, float iouThreshold, + int64_t kernelHeight, int64_t kernelWidth, + int64_t maxRegions, + bool verbose); diff --git a/nemo-retriever-ocr/cpp/non_maximal_suppression/strided_quad.h b/nemo-retriever-ocr/cpp/non_maximal_suppression/strided_quad.h new file mode 100644 index 0000000000000000000000000000000000000000..3ee13685223c0f51b8cb41bc7f8e1e87ec4de598 --- /dev/null +++ b/nemo-retriever-ocr/cpp/non_maximal_suppression/strided_quad.h @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../geometry.h" +#include "../cuda_intellisense.cuh" + + +template +struct StridedQuad_ : QuadBase_> +{ + T *DataPtr = nullptr; + int64_t Stride = 0; + + StridedQuad_() = default; + __host__ __device__ + StridedQuad_(T *dataPtr, int64_t stride) + : DataPtr(dataPtr), Stride(stride) {} + + __host__ __device__ + const Point_ operator[](int64_t offset) const { + auto ptOffset = DataPtr + 2 * offset * Stride; + return { + *ptOffset, + ptOffset[Stride] + }; + } + + __host__ __device__ + InPlaceQuad_ ToIPQuad() const + { + InPlaceQuad_ ret; + #pragma unroll + for (int64_t i = 0; i < 4; ++i) { + ret.Vertices[i] = (*this)[i]; + } + return ret; + } +}; + +template +struct StridedEmbedQuad_ : StridedQuad_ +{ + using StridedQuad_::StridedQuad_; + + __host__ __device__ T &Confidence() { return this->DataPtr[8 * this->Stride]; } + __host__ __device__ const T Confidence() const { return this->DataPtr[8 * this->Stride]; } + + __host__ __device__ T &NumQuads() { return this->DataPtr[9 * this->Stride]; } + __host__ __device__ const T NumQuads() const { return this->DataPtr[9 * this->Stride]; } +}; diff --git a/nemo-retriever-ocr/cpp/promote.h b/nemo-retriever-ocr/cpp/promote.h new file mode 100644 index 0000000000000000000000000000000000000000..ea9d28707a94070ca55d7d7cd509ec23f47feb88 --- /dev/null +++ b/nemo-retriever-ocr/cpp/promote.h @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +template +struct Promote { + /* + * The "Promote" object can be used to discover a computation floating type that + * is both efficient on hardward, and also doesn't result in a loss of precision. + * + * Examples: + * Promote::type == float + * Promote::type == double + * Promote::type == float + * + * Additionally, the promote structure can be used to discover the best computation + * type when given heterogeneous input types. + * + * Examples: + * Promote::type == double + * Promote::type == float + */ + typedef float type; +}; + +template<> +struct Promote { typedef double type; }; +template<> +struct Promote { typedef double type; }; +template +struct Promote { typedef double type; }; +template +struct Promote { typedef double type; }; + +template +struct Promote { typedef typename Promote::type>::type type; }; diff --git a/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify.h b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify.h new file mode 100644 index 0000000000000000000000000000000000000000..2f5d1142d25cdb9ebfc2ee813bc7ab412ea3e367 --- /dev/null +++ b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify.h @@ -0,0 +1,98 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "quad_rectify_cpu.h" +#include "quad_rectify_gpu.h" + +inline +torch::Tensor quad_rectify_calc_quad_width(torch::Tensor quads, + int64_t outputHeight, + int64_t roundFactor, + float maxWidth) +{ + if (quads.dim() < 2 || quads.dim() > 3) { + throw std::runtime_error("Invalid quads dimensions."); + } + + if (quads.size(-1) != 2 || quads.size(-2) != 4) { + throw std::runtime_error("The final 2 quad dimensions must be 4x2."); + } + + if (quads.dim() == 2) { + quads = quads.unsqueeze(0); + } + + if (quads.is_cuda()) { + return quad_rectify_gpu_calc_quad_width(quads, outputHeight, roundFactor, maxWidth); + } else { + return quad_rectify_cpu_calc_quad_width(quads, outputHeight, roundFactor, maxWidth); + } +} + +inline +torch::Tensor quad_rectify_forward(torch::Tensor quads, + int64_t imageHeight, + int64_t imageWidth, + int64_t outputHeight, + int64_t outputWidth, + bool isotropic) +{ + if (quads.dim() < 2 || quads.dim() > 3) { + throw std::runtime_error("Invalid quads dimensions."); + } + + if (quads.size(-1) != 2 || quads.size(-2) != 4) { + throw std::runtime_error("The final 2 quad dimensions must be 4x2."); + } + + bool flatten = false; + if (quads.dim() == 2) { + quads = quads.unsqueeze(0); + flatten = true; + } + + torch::Tensor ret; + if (quads.is_cuda()) { + ret = quad_rectify_gpu_forward(quads, imageHeight, imageWidth, outputHeight, outputWidth, isotropic); + } + else { + ret = quad_rectify_cpu_forward(quads, imageHeight, imageWidth, outputHeight, outputWidth, isotropic); + } + + if (flatten) { + ret = ret[0]; + } + + return ret; +} + +inline +torch::Tensor quad_rectify_backward(torch::Tensor quads, torch::Tensor gradOutput, + int64_t imageHeight, int64_t imageWidth, + bool isotropic) +{ + if (quads.is_cuda() != gradOutput.is_cuda()) { + throw std::runtime_error("Either both 'quads' and 'gradOutput' must be cuda, or neither."); + } + + if (quads.dim() != 3 || quads.size(-2) != 4 || quads.size(-1) != 2) { + throw std::runtime_error("Expected quads to be 3 dimensional. Nx4x2."); + } + + if (gradOutput.dim() != 4 || + gradOutput.size(3) != 2) { + throw std::runtime_error("Expected 'gradOutput' to be 4d: Nxxx2."); + } + + if (quads.is_cuda()) { + return quad_rectify_gpu_backward(quads, gradOutput, imageHeight, imageWidth, isotropic); + } + else { + return quad_rectify_cpu_backward(quads, gradOutput, imageHeight, imageWidth, isotropic); + } +} diff --git a/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_cpu.cpp b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb258d40366c8b008081bb618d70a1018798fc4d --- /dev/null +++ b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_cpu.cpp @@ -0,0 +1,182 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + + +#include "quad_rectify_cpu.h" + +#include + +#include "../geometry.h" +#include "quad_rectify_shared.h" + +using namespace std; + +template +void quad_rectify_calc_quad_width_impl(const quads_accessor_t &quads, + output_accessor_t output, + const scalar_t outputHeight, + const scalar_t roundFactor, + const scalar_t maxWidth) +{ + const int64_t numQuads = quads.size(0); + + for (int64_t quadIdx = 0; quadIdx < numQuads; ++quadIdx) { + auto quadWidth = calc_quad_width(quads[quadIdx], outputHeight, roundFactor, maxWidth); + + output[quadIdx] = Convert::LeftToRight(quadWidth); + } +} + +template +void quad_rectify_cpu_forward_impl(const quads_accessor_t &quads, + output_accessor_t output, + const scalar_t imageHeight, + const scalar_t imageWidth, + bool isotropic) +{ + typedef Point_ Point_t; + + const int64_t numQuads = quads.size(0); + const int64_t outputHeight = output.size(1); + const int64_t outputWidth = output.size(2); + + for (int64_t quadIdx = 0; quadIdx < numQuads; ++quadIdx) { + auto currQuad = quads[quadIdx]; + + scalar_t quadWidth = isotropic ? calc_quad_width(currQuad, outputHeight, 1, outputWidth) : scalar_t(outputWidth); + + for (int64_t row = 0; row < outputHeight; ++row) { + for (int64_t col = 0; col < outputWidth; ++col) { + Point_t outputPoint = calc_rect_value(currQuad, + quadWidth, + outputHeight, + col, + row, + imageWidth, + imageHeight); + + auto currOutput = output[quadIdx][row][col]; + currOutput[0] = outputPoint.X; + currOutput[1] = outputPoint.Y; + } + } + } +} + +/*template +void quad_rectify_cpu_backward_impl(torch::Tensor quads, + torch::Tensor gradOutput, + torch::Tensor gradInput) +{ + const int64_t batchSize = gradOutput.size(0); + const int64_t outputHeight = gradOutput.size(1); + const int64_t outputWidth = gradOutput.size(2); + + auto gradInputAccess = gradInput.accessor(); + auto gradOutputAccess = gradOutput.accessor(); + + for (int64_t batchIdx = 0; batchIdx < batchSize; ++batchIdx) { + auto batchInputAccess = gradInputAccess[batchIdx]; + auto batchOutputAccess = gradOutputAccess[batchIdx]; + + for (int64_t rowIdx = 0; rowIdx < outputHeight; ++rowIdx) { + for (int64_t colIdx = 0; colIdx < outputWidth; ++colIdx) { + + const scalar_t fRow = scalar_t(rowIdx) / outputHeight; + const scalar_t fCol = scalar_t(colIdx) / outputWidth; + const scalar_t fRowCol = fRow * fCol; + + for (int64_t dim = 0; dim < 2; ++dim) { + const scalar_t dOut = batchOutputAccess[rowIdx][colIdx][dim]; + + const scalar_t gradIns[] = { + dOut * (fRowCol - fCol - fRow + 1), + dOut * (fCol - fRowCol), + dOut * fRowCol, + dOut * (fRow - fRowCol) + }; + + for (int64_t quadIdx = 0; quadIdx < 4; ++quadIdx) { + batchInputAccess[quadIdx][dim] += 2.0f * gradIns[quadIdx]; + } + } + } + } + } +}*/ + +torch::Tensor quad_rectify_cpu_calc_quad_width(torch::Tensor quads, + int64_t outputHeight, + int64_t roundFactor, + float maxWidth) +{ + auto output = torch::empty({ quads.size(0) }, + quads.options().dtype(torch::kInt64)); + + AT_DISPATCH_FLOATING_TYPES( + quads.scalar_type(), + "quad_rectify_cpu_calc_quad_width", + ([&] { + quad_rectify_calc_quad_width_impl( + quads.accessor(), + output.accessor(), + Convert::RightToLeft(outputHeight), + Convert::RightToLeft(roundFactor), + Convert::RightToLeft(maxWidth) + ); + }) + ); + + return output; +} + +torch::Tensor quad_rectify_cpu_forward(torch::Tensor quads, + int64_t imageHeight, + int64_t imageWidth, + int64_t outputHeight, + int64_t outputWidth, + bool isotropic) +{ + auto output = torch::empty({ 
quads.size(0), outputHeight, outputWidth, 2 }, + quads.options()); + + AT_DISPATCH_FLOATING_TYPES( + quads.scalar_type(), + "quad_rectify_cpu_forward", + ([&] { + quad_rectify_cpu_forward_impl( + quads.accessor(), + output.accessor(), + Convert::RightToLeft(imageHeight), + Convert::RightToLeft(imageWidth), + isotropic + ); + }) + ); + + return output; +} + +torch::Tensor quad_rectify_cpu_backward(torch::Tensor quads, + torch::Tensor gradOutput, + int64_t imageHeight, + int64_t imageWidth, + bool isotropic) +{ + auto gradInput = torch::zeros_like(quads); + + throw std::runtime_error("Calling backward, and it's not implemented!"); + + /*AT_DISPATCH_FLOATING_TYPES_AND_HALF( + quads.scalar_type(), + "quad_rectify_cpu_backward", + ([&] { + quad_rectify_cpu_backward_impl(quads, + gradOutput, + gradInput); + }) + );*/ + + return gradInput; +} diff --git a/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_cpu.h b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..c0ac9684b6bdcc189424f5820625d412868aebbb --- /dev/null +++ b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_cpu.h @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + + +#include + +torch::Tensor quad_rectify_cpu_calc_quad_width(torch::Tensor quads, + int64_t outputHeight, + int64_t roundFactor, + float maxWidth); + +torch::Tensor quad_rectify_cpu_forward(torch::Tensor quads, + int64_t imageHeight, + int64_t imageWidth, + int64_t outputHeight, + int64_t outputWidth, + bool isotropic); + +torch::Tensor quad_rectify_cpu_backward(torch::Tensor quads, + torch::Tensor gradOutput, + int64_t imageHeight, + int64_t imageWidth, + bool isotropic); diff --git a/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_gpu.cu b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..6b56fb35b9b88194afd03423c62aa439d1aefa75 --- /dev/null +++ b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_gpu.cu @@ -0,0 +1,289 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + + +#include "quad_rectify_gpu.h" + +#include +#include +#include + +#include "quad_rectify_shared.h" +#include "../half_ops.cuh" +#include "../geometry.h" + +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x) + +template +__global__ void quad_rectify_device_calc_quad_width(quads_accessor_t quads, + output_accessor_t output, + const scalar_t outputHeight, + const scalar_t roundFactor, + const scalar_t maxWidth) +{ + const unsigned int quadIdx = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int numQuads = quads.size(0); + + if (quadIdx >= numQuads) { + return; + } + + auto currQuad = quads[quadIdx]; + + auto quadWidth = calc_quad_width(currQuad, outputHeight, roundFactor, maxWidth); + + output[quadIdx] = Convert::LeftToRight(quadWidth); +} + +template +__global__ void quad_rectify_device_forward(quads_accessor_t quads, + output_accessor_t outputs, + const scalar_t imageHeight, + const scalar_t imageWidth, + bool isotropic) +{ + typedef Point_ Point_t; + + const unsigned int quadIdx = blockIdx.y * blockDim.y + threadIdx.y; + const unsigned int numQuads = quads.size(0); + + if (quadIdx >= numQuads) { + return; + } + + const unsigned int outputHeight = outputs.size(1); + const unsigned int outputWidth = outputs.size(2); + + const unsigned int offset = blockIdx.x * blockDim.x + threadIdx.x; + + const unsigned int x = offset % outputWidth; + const unsigned int y = offset / outputWidth; + + if (y >= outputHeight) { + return; + } + + auto quad = quads[quadIdx]; + auto output = outputs[quadIdx][y][x]; + + auto scOutputHeight = Convert::RightToLeft(outputHeight); + auto scOutputWidth = Convert::RightToLeft(outputWidth); + auto scOne = Convert::RightToLeft(1); + + scalar_t quadWidth = isotropic ? calc_quad_width(quad, scOutputHeight, scOne, scOutputWidth) : scOutputWidth; + + Point_t outputPoint = calc_rect_value(quad, + quadWidth, + scOutputHeight, + x, + y, + imageWidth, + imageHeight); + + output[0] = outputPoint.X; + output[1] = outputPoint.Y; +} + +template +__global__ void quad_rectify_device_backward(quads_accessor_t quads, + output_accessor_t gradOutput, + quads_accessor_t gradInput, + const scalar_t imageHeight, + const scalar_t imageWidth, + bool isotropic) +{ + const unsigned int numQuads = quads.size(0); + int64_t quadIdx = blockIdx.y * blockDim.y + threadIdx.y; + + int64_t offset = blockIdx.x * blockDim.x + threadIdx.x; + + const int64_t outputHeight = gradOutput.size(1); + const int64_t outputWidth = gradOutput.size(2); + + int64_t x = offset % outputWidth; + int64_t y = offset / outputWidth; + + auto scOutputHeight = Convert::RightToLeft(outputHeight); + auto scOutputWidth = Convert::RightToLeft(outputWidth); + auto scOne = Convert::RightToLeft(1); + const scalar_t scHalf = Convert::RightToLeft(0.5); + + auto currQuad = quads[quadIdx]; + scalar_t quadWidth = isotropic ? 
calc_quad_width(currQuad, scOutputHeight, scOne, scOutputWidth) : scOutputWidth; + + __shared__ scalar_t sharedFloats[32][8]; + + scalar_t scale[2] = { Convert::RightToLeft(2.0f) / imageWidth, + Convert::RightToLeft(2.0f) / imageHeight }; + + bool valid = false; + if (quadIdx < numQuads && y < outputHeight) { + auto fRow = (scalar_t(y) + scHalf) / outputHeight; + auto fCol = (scalar_t(x) + scHalf) / quadWidth; + // auto fRow = scalar_t(y) / (outputHeight - scOne); + // auto fCol = scalar_t(x) / (quadWidth - scOne); + auto fRowCol = fRow * fCol; + + if (fCol <= 1) { + #pragma unroll 2 + for (int64_t i = 0; i < 2; ++i) { + auto currGradOutput = gradOutput[quadIdx][y][x][i] * scale[i]; + + sharedFloats[threadIdx.x][0 + i] = currGradOutput * (fRowCol - fCol - fRow + 1); + sharedFloats[threadIdx.x][2 + i] = currGradOutput * (fCol - fRowCol); + sharedFloats[threadIdx.x][4 + i] = currGradOutput * fRowCol; + sharedFloats[threadIdx.x][6 + i] = currGradOutput * (fRow - fRowCol); + } + valid = true; + } + } + + if (! valid) { + #pragma unroll 8 + for (int64_t i = 0; i < 8; ++i) { + sharedFloats[threadIdx.x][i] = 0; + } + } + + __syncthreads(); + + // Now accumulate over the shared memory + for (unsigned int i = 16; i > 0; i /= 2) { + if (threadIdx.x < i) { + #pragma unroll 8 + for (unsigned int k = 0; k < 8; ++k) { + sharedFloats[threadIdx.x][k] += sharedFloats[threadIdx.x + i][k]; + } + } + __syncthreads(); + } + + auto pGradInput = gradInput[quadIdx].data(); + + // Finally, write the values + if (threadIdx.x == 0) { + #pragma unroll 8 + for (int64_t i = 0; i < 8; ++i) { + atomicAdd(pGradInput + i, sharedFloats[0][i]); + } + } +} + +torch::Tensor quad_rectify_gpu_calc_quad_width(torch::Tensor quads, + int64_t outputHeight, + int64_t roundFactor, + float maxWidth) +{ + CHECK_INPUT(quads); + + const int64_t numQuads = quads.size(0); + + dim3 dimBlock(32); + dim3 dimGrid(div_up(numQuads, dimBlock.x)); + + auto output = torch::empty({ numQuads }, + quads.options().dtype(torch::kInt64)); + + if (numQuads > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + quads.scalar_type(), + "quad_rectify_calc_quad_width", + ([&] { + typedef typename remap_half::type T; + quad_rectify_device_calc_quad_width KERNEL_ARG2(dimGrid, dimBlock) ( + quads.packed_accessor64(), + output.packed_accessor64(), + Convert::RightToLeft(outputHeight), + Convert::RightToLeft(roundFactor), + Convert::RightToLeft(maxWidth) + ); + }) + ); + } + + return output; +} + +torch::Tensor quad_rectify_gpu_forward(torch::Tensor quads, + int64_t imageHeight, + int64_t imageWidth, + int64_t outputHeight, + int64_t outputWidth, + bool isotropic) +{ + CHECK_INPUT(quads); + + const int64_t numQuads = quads.size(0); + const int64_t numCells = outputHeight * outputWidth; + + dim3 dimBlock(32); + dim3 dimGrid(div_up(numCells, dimBlock.x), + numQuads); + + auto output = torch::empty({ numQuads, outputHeight, outputWidth, 2 }, + quads.options()); + + if (numQuads > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + quads.scalar_type(), + "quad_rectify_device_forward", + ([&] { + typedef typename remap_half::type T; + quad_rectify_device_forward KERNEL_ARG2(dimGrid, dimBlock) ( + quads.packed_accessor64(), + output.packed_accessor64(), + Convert::RightToLeft(imageHeight), + Convert::RightToLeft(imageWidth), + isotropic + ); + }) + ); + } + + return output; +} + +torch::Tensor quad_rectify_gpu_backward(torch::Tensor quads, + torch::Tensor gradOutput, + int64_t imageHeight, + int64_t imageWidth, + bool isotropic) +{ + CHECK_INPUT(quads); + CHECK_INPUT(gradOutput); + + 
const int64_t numQuads = quads.size(0); + const int64_t outputHeight = gradOutput.size(1); + const int64_t outputWidth = gradOutput.size(2); + + const int64_t numCells = outputHeight * outputWidth; + + dim3 dimBlock(32); + dim3 dimGrid(div_up(numCells, dimBlock.x), + numQuads); + + auto gradInput = torch::zeros_like(quads); + + if (numQuads > 0) { + AT_DISPATCH_FLOATING_TYPES( + quads.scalar_type(), + "quad_rectify_device_backward", + ([&] { + typedef typename remap_half::type T; + quad_rectify_device_backward KERNEL_ARG2(dimGrid, dimBlock) ( + quads.packed_accessor64(), + gradOutput.packed_accessor64(), + gradInput.packed_accessor64(), + Convert::RightToLeft(imageHeight), + Convert::RightToLeft(imageWidth), + isotropic + ); + }) + ); + } + + return gradInput; +} diff --git a/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_gpu.h b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_gpu.h new file mode 100644 index 0000000000000000000000000000000000000000..2e8e832ada0c918dfc28ad4f106142c9c32755f8 --- /dev/null +++ b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_gpu.h @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + + +#include + +torch::Tensor quad_rectify_gpu_calc_quad_width(torch::Tensor quads, + int64_t outputHeight, + int64_t roundFactor, + float maxWidth); + +torch::Tensor quad_rectify_gpu_forward(torch::Tensor quads, + int64_t imageHeight, + int64_t imageWidth, + int64_t outputHeight, + int64_t outputWidth, + bool isotropic); + +torch::Tensor quad_rectify_gpu_backward(torch::Tensor quads, + torch::Tensor gradOutput, + int64_t imageHeight, + int64_t imageWidth, + bool isotropic); diff --git a/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_shared.h b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_shared.h new file mode 100644 index 0000000000000000000000000000000000000000..ef528edc1598ee382440dfba2a2abd37dcf1583b --- /dev/null +++ b/nemo-retriever-ocr/cpp/quad_rectify/quad_rectify_shared.h @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include "../half_ops.cuh" +#include "../geometry.h" + +// template +// __qr_inline__ __qr_device__ auto dot(const Point_t &a, const Point_t &b) -> decltype(a.X) { +// auto sq = a * b; +// return sq.X + sq.Y; +// } + +// template +// __qr_inline__ __qr_device__ auto dot(const Point_t &a) -> decltype(a.X) { +// return dot(a, a); +// } + +template +__qr_inline__ __qr_device__ scalar_t square(scalar_t v) { + return v * v; +} + +template +__qr_inline__ __qr_device__ auto sub_accessor(const accessor_t &a, const accessor_t &b) -> Point_::type> { + return { a[0] - b[0], a[1] - b[1] }; +} + +template +__qr_device__ scalar_t calc_quad_width(const accessor_t &quad, + scalar_t outputHeight, + scalar_t roundFactor, + scalar_t maxWidth) +{ + using std::max; + using std::ceil; + using std::floor; + typedef Point_ Point_t; + + Point_t vecWidth = sub_accessor(quad[1], quad[0]); + Point_t vecHeight = sub_accessor(quad[3], quad[0]); + Point_t vecHeight2 = sub_accessor(quad[2], quad[1]); + + scalar_t quadWidth = sqrt(dot(vecWidth)); + scalar_t quadHeight = sqrt(dot(vecHeight)); + scalar_t quadHeight2 = sqrt(dot(vecHeight2)); + + const scalar_t sc2 = Convert::To(2); + quadHeight = (quadHeight + quadHeight2) / sc2; + + if (quadHeight < sc2) { + quadHeight = sc2; + } + + scalar_t growthRatio = outputHeight / quadHeight; + quadWidth = growthRatio * quadWidth; + + quadWidth = max(roundFactor, ceil(quadWidth / roundFactor) * roundFactor); + + if (maxWidth > Convert::To(0) && quadWidth > maxWidth) { + quadWidth = maxWidth; + } + + return max(sc2, floor(quadWidth)); +} + +template +__qr_inline__ +__qr_device__ Point_ calc_rect_value(const accessor_t &quad, + const scalar_t quadWidth, + const scalar_t outputHeight, + const unsigned int x, + const unsigned int y, + const scalar_t imageWidth, + const scalar_t imageHeight) +{ + typedef Point_ Point_t; + + const Point_t pts[4] = { + quad[0], quad[1], quad[2], quad[3] + }; + + + const scalar_t scX = Convert::RightToLeft(x); + const scalar_t sc1 = Convert::RightToLeft(1); + const scalar_t scHalf = Convert::RightToLeft(0.5); + + const scalar_t fRow = (Convert::RightToLeft(y) + scHalf) / outputHeight; + const scalar_t fCol = (scX + scHalf) / quadWidth; + + // const scalar_t fRow = Convert::RightToLeft(y) / (outputHeight - sc1); + // const scalar_t fCol = scX / (quadWidth - sc1); + + Point_t outputPoint; + if (scX < quadWidth) { + const Point_t &q0 = pts[0]; + const Point_t A = pts[1] - q0; + const Point_t B = pts[3] - q0; + const Point_t C = pts[2] - pts[1]; + + outputPoint = q0 + + fCol * A + + fRow * B + + (fCol * fRow) * (C - B); + } + else { + outputPoint = { -sc1, -sc1 }; + } + + outputPoint /= Point_t{ imageWidth, imageHeight }; + + // Remap from [0, 1] -> [-1, 1] + outputPoint = (Convert::RightToLeft(2) * outputPoint) - sc1; + + return outputPoint; +} diff --git a/nemo-retriever-ocr/cpp/scope_timer.h b/nemo-retriever-ocr/cpp/scope_timer.h new file mode 100644 index 0000000000000000000000000000000000000000..c3f80803e4b1d8654999e727075972d46c1b1474 --- /dev/null +++ b/nemo-retriever-ocr/cpp/scope_timer.h @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +namespace chrono = std::chrono; + + +class CudaStoreTimer +{ + typedef decltype(chrono::high_resolution_clock::now()) tp_t; +public: + CudaStoreTimer(double &storage, bool enabled=true, bool synchronize=true) + : m_storage(&storage), m_enabled(enabled), m_synchronize(synchronize) + { + m_start = GetTP(); + } + ~CudaStoreTimer() + { + if (! m_enabled) return; + + auto tNow = GetTP(); + chrono::duration dur = tNow - m_start; + *m_storage = dur.count(); + } + +private: + tp_t GetTP() const + { + if (m_enabled && m_synchronize) { + cudaDeviceSynchronize(); + } + return chrono::high_resolution_clock::now(); + } + + double *m_storage; + tp_t m_start; + bool m_enabled; + bool m_synchronize; +}; diff --git a/nemo-retriever-ocr/cpp/sparse_select/sparse_select.cpp b/nemo-retriever-ocr/cpp/sparse_select/sparse_select.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1d8ca52649ab24d87bb0e8a816cf8195472847a --- /dev/null +++ b/nemo-retriever-ocr/cpp/sparse_select/sparse_select.cpp @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "sparse_select.h" + +#include + +#include "../common.h" + +using namespace std; + +std::tuple> sparse_select(torch::Tensor sparseCounts, + const std::vector sparseTensors, + torch::Tensor selectIndices) +{ + bool is_gpu = sparseCounts.is_cuda(); + + auto sparseCountsCPU = sparseCounts.cpu(); + + auto sortedSelect = get<0>(torch::sort(selectIndices)); + + vector retTensors; + for (const torch::Tensor &t : sparseTensors) { + retTensors.push_back(t.index({sortedSelect})); + } + + vector offsets(1 + sparseCountsCPU.size(0)); + + auto sparseCtAccess = sparseCountsCPU.accessor(); + + for (int64_t i = 0; i < sparseCountsCPU.size(0); ++i) { + offsets[i + 1] = sparseCtAccess[i] + offsets[i]; + } + + // cout << "Offsets: " << offsets << endl; + + auto retCounts = torch::zeros_like(sparseCountsCPU); + + auto retCtAccess = retCounts.accessor(); + auto idxAccess = sortedSelect.accessor(); + + for (int64_t i = 0; i < idxAccess.size(0); ++i) { + int64_t idx = idxAccess[i]; + + int64_t batchIdx = std::upper_bound(begin(offsets), end(offsets), idx) - begin(offsets) - 1; + + // cout << "Index: " << idx << ", Batch Index: " << batchIdx << endl; + + retCtAccess[batchIdx] += 1; + } + + if (is_gpu) { + retCounts = retCounts.to(sparseCounts); + } + + return make_tuple(retCounts, retTensors); +} diff --git a/nemo-retriever-ocr/cpp/sparse_select/sparse_select.h b/nemo-retriever-ocr/cpp/sparse_select/sparse_select.h new file mode 100644 index 0000000000000000000000000000000000000000..4dc5f59568c6b964ce4eeed9d6436786e0c833b8 --- /dev/null +++ b/nemo-retriever-ocr/cpp/sparse_select/sparse_select.h @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +std::tuple> sparse_select(torch::Tensor sparseCounts, + const std::vector sparseTensors, + torch::Tensor selectIndices); diff --git a/nemo-retriever-ocr/cpp/text_region_grouping/dense_relations_to_graph.cpp b/nemo-retriever-ocr/cpp/text_region_grouping/dense_relations_to_graph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6780de057500fc5d826635cee312c3b00925fc17 --- /dev/null +++ b/nemo-retriever-ocr/cpp/text_region_grouping/dense_relations_to_graph.cpp @@ -0,0 +1,253 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "text_region_grouping.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +PhraseList rel_list_to_phrases(const relations_list_t &relList) +{ + PhraseList ret; + ret.reserve(relList.size()); + + for (const text_line_t &line : relList) { + TextLine tl; + tl.reserve(line.size()); + + for (const auto &rel : line) { + tl.push_back(get<0>(rel)); + } + + ret.push_back({ move(tl) }); + } + + return ret; +} + +template +relations_list_t rel_chain_to_groups(const rel_to_2_from_map_t &inChain, int64_t numRegions, const T *inProbs); + +template +relations_list_t dense_relations_to_graph_impl(torch::Tensor relationsTensor) +{ + if (relationsTensor.size(0) == 0) { + return relations_list_t{}; + } + + if (relationsTensor.size(0) != relationsTensor.size(1)) { + throw std::runtime_error("The relations tensor must be a square matrix!"); + } + + // Each row `i` of `relationsTensor` is a probability distribution of going from word `i` to word `k` + // If we find the maximum confidence into each word `k`, it tells us the strongest connection + // from `i` to `k`. 
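+	// (Hypothetical example: in a 3-word case where relations[0][2] == 0.9 and relations[1][2] == 0.6,
+	// word 2's strongest incoming connection is recorded as coming from word 0 with probability 0.9.)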
+ // So, `maxRelTensor` tells us the connection strength of the strongest connection coming into word `k`, + // and `fromIdxTensor` tells us the index of word `i` that has this connection + auto relations = relationsTensor.accessor(); + + const int64_t numRegions = relationsTensor.size(0); + torch::Tensor fromIdxsTensor = torch::full({ numRegions }, -1, torch::kInt64); + torch::Tensor fromProbsTensor = torch::zeros({ numRegions }, relationsTensor.options()); + + // Use `data_ptr` here because these tensors are 1-dimensional contiguous arrays, which saves us + // a multiply+add for each access + auto fromIdxs = fromIdxsTensor.data_ptr(); + auto fromProbs = fromProbsTensor.data_ptr(); + + for (int64_t fromIdx = 0; fromIdx < numRegions; ++fromIdx) { + auto fromRel = relations[fromIdx]; + + for (int64_t toIdx = 0; toIdx < numRegions; ++toIdx) { + auto relProb = fromRel[toIdx]; + + if (relProb >= 0.5) { + T &maxProb = fromProbs[toIdx]; + if (fromIdxs[toIdx] == -1 || relProb > maxProb) { + fromIdxs[toIdx] = fromIdx; + maxProb = relProb; + } + // Because each row sums to 1, it's only possible for <= 1 columns to have + // a value above 0.5 + break; + } + } + } + + return rel_chain_to_groups(fromIdxs, numRegions, fromProbs); +} + +relations_list_t dense_relations_to_graph_with_probs(torch::Tensor relationsTensor) +{ + relations_list_t ret; + AT_DISPATCH_FLOATING_TYPES( + relationsTensor.scalar_type(), + "dense_relations_to_graph", + ([&] { + ret = dense_relations_to_graph_impl(relationsTensor); + }) + ); + return ret; +} + +PhraseList dense_relations_to_graph(torch::Tensor relations) +{ + return rel_list_to_phrases(dense_relations_to_graph_with_probs(relations)); +} + +template +relations_list_t sparse_relations_to_graph_impl(torch::Tensor relationsTensor, torch::Tensor neighborIdxsTensor) +{ + if (relationsTensor.size(0) == 0) { + return relations_list_t{}; + } + + auto maxRelsTensor = torch::zeros({ relationsTensor.size(0) }, relationsTensor.options()); + auto fromIdxsTensor = torch::full({ relationsTensor.size(0) }, -1, torch::kInt64); + + auto relations = relationsTensor.accessor(); + auto neighborIdxs = neighborIdxsTensor.accessor(); + auto maxRels = maxRelsTensor.data_ptr(); + auto fromIdxs = fromIdxsTensor.data_ptr(); + + const int64_t N = relationsTensor.size(0); + const int64_t K = relationsTensor.size(1); + + // Refer to `dense_relations_to_graph` for the reasoning behind this. The only difference here + // is the indirection due to sparsity. At the completion of this double loop, + // `maxRelsTensor` and `fromIdxTensor` are of identical form to the dense case. 
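+	// (Hypothetical example: if neighborIdxs[fromIdx] == {0, 3, 8}, then column c == 1 maps to region
+	// index 2 and c == 2 maps to region index 7, because of the -1 offset applied below.)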
+ for (int64_t fromIdx = 0; fromIdx < N; ++fromIdx) { + auto fromNeighborIdxs = neighborIdxs[fromIdx].data(); + auto fromRelations = relations[fromIdx].data(); + + // Skip the null column + for (int64_t c = 1; c < K; ++c) { + // All of these values will be offset by +1 to account for the null column + int64_t toIdx = fromNeighborIdxs[c] - 1; + // The relations tensor already has the null column stripped off + T toProb = fromRelations[c]; + + if (toProb > 0.5f) { + T &bestProb = maxRels[toIdx]; + if (toProb > bestProb) { + bestProb = toProb; + fromIdxs[toIdx] = fromIdx; + } + // Due to the softmax, only one value could ever be >0.5, if any, + // so if we've encountered this value, then we're done with this `fromIdx` + break; + } + } + } + + return rel_chain_to_groups(fromIdxs, N, maxRels); +} + +relations_list_t sparse_relations_to_graph(torch::Tensor relationsTensor, torch::Tensor neighborIdxs) +{ + relations_list_t ret; + + AT_DISPATCH_FLOATING_TYPES( + relationsTensor.scalar_type(), + "sparse_relations_to_graph", + ([&] { + ret = sparse_relations_to_graph_impl(relationsTensor, neighborIdxs); + }) + ); + + return ret; +} + +template +relations_list_t rel_chain_to_groups(const rel_to_2_from_map_t &inChain, const int64_t numRegions, const T *inProbs) +{ + // inChain is a vector over the relations that tells us, for a given position `i`, + // the strongest relation `k` leading into that, if any, otherwise -1. + // So if `inChain[5] == 2`, this means that region `k==2` connects to region `i==5`. + // It's also mandatory that the elements in inChain != -1 form a bijection + // between from/to (e.g. the same from index can't be used twice) + + // Create a mapping that goes from word `fromIdx` to word `toIdx`, which is the + // reverse mapping of inChain + auto outChainTensor = torch::full({ numRegions }, -1, torch::kInt64); + auto outChain = outChainTensor.data_ptr(); + + auto outProbsTensor = torch::ones({ numRegions }, torch::kFloat); + auto outProbs = outProbsTensor.data_ptr(); + + for (int64_t toIdx = 0; toIdx < numRegions; ++toIdx) { + int64_t fromIdx = inChain[toIdx]; + if (fromIdx != -1) { + outChain[fromIdx] = toIdx; + outProbs[fromIdx] = static_cast(inProbs[toIdx]); + } + } + + std::vector processed; processed.resize(numRegions, false); + + text_line_t currChain; currChain.reserve(32); + relations_list_t groups; + + for (int64_t toIdx = 0; toIdx < numRegions; ++toIdx) { + int64_t fromIdx = inChain[toIdx]; + + if (fromIdx == -1 || processed[toIdx]) { + continue; + } + + processed[toIdx] = true; + currChain.clear(); + currChain.emplace_back(toIdx, outProbs[fromIdx]); + + int64_t currIdx = toIdx; + while (true) { + fromIdx = inChain[currIdx]; + // The second check ensures that we don't encounter any cycles + if (fromIdx == -1 || processed[fromIdx]) { + break; + } + + processed[fromIdx] = true; + currChain.emplace_back(fromIdx, outProbs[fromIdx]); + currIdx = fromIdx; + } + + // At this point, `currChain` contains all of the indices from `toIdx` (index 0) backward. 
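+		// (Hypothetical example: with inChain == {-1, 0, 1}, the outer loop first reaches toIdx == 1 and
+		// builds currChain == {1, 0}; the forward walk further below then appends index 2, yielding the
+		// single group {0, 1, 2}.)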
+ // So, we can initialize the group with the reverse iterator to the current chain + text_line_t group{ std::rbegin(currChain), std::rend(currChain) }; + + // However, we also need to harvest all of the indices from `toIdx` forward + int64_t nextIdx = toIdx; + while (true) { + int64_t nextToIdx = outChain[nextIdx]; + // Same as before, second check will break cycles + if (nextToIdx == -1 || processed[nextToIdx]) { + break; + } + + processed[nextToIdx] = true; + group.emplace_back(nextToIdx, static_cast(inProbs[nextToIdx])); + nextIdx = nextToIdx; + } + + groups.push_back(move(group)); + } + + // Now add in the stragglers + for (int64_t wIdx = 0; wIdx < numRegions; ++wIdx) { + if (! processed[wIdx]) { + groups.push_back({ { wIdx, 1.0f } }); + } + } + + return groups; +} diff --git a/nemo-retriever-ocr/cpp/text_region_grouping/text_region_grouping.cpp b/nemo-retriever-ocr/cpp/text_region_grouping/text_region_grouping.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b97dd4af72a30f48b0bbf2dc14868fb7d889ef3 --- /dev/null +++ b/nemo-retriever-ocr/cpp/text_region_grouping/text_region_grouping.cpp @@ -0,0 +1,368 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "text_region_grouping.h" + +#include +#include +#include +#include +#include +#include + +#include "../geometry.h" +#include "../common.h" +#include "../scope_timer.h" +#include "../non_maximal_suppression/nms_kd_tree.h" + +using namespace std; + + +vector> relations_to_clusters(const unordered_map &lineRelations, int64_t numQuads) +{ + unordered_map reverseLookup; + for (auto &kv : lineRelations) { + reverseLookup.emplace(kv.second, kv.first); + } + + vector ret; + + unordered_set visited; + for (auto &kv : lineRelations) { + int64_t root = kv.first; + if (visited.count(root)) continue; + + // Find the root + bool bad = false; + auto rlIter = reverseLookup.find(root); + while (rlIter != reverseLookup.end()) { + root = rlIter->second; + rlIter = reverseLookup.find(root); + if (visited.count(root)) { + bad = true; + break; + } + visited.insert(root); + } + + // It could be bad either because this node was already visited, or if there's a cycle in the graph (somehow) + if (bad) continue; + + // Now walk the chain + TextLine line; + auto iter = lineRelations.end(); + do + { + line.push_back(root); + visited.insert(root); + iter = lineRelations.find(root); + if (iter != lineRelations.end()) { + root = iter->second; + } + } while (iter != lineRelations.end()); + + ret.push_back(move(line)); + } + + // Add in all of the stragglers + for (int64_t i = 0; i < numQuads; ++i) { + if (! 
visited.count(i)) { + TextLine line; + line.push_back(i); + + ret.push_back(move(line)); + } + } + + return ret; +} + +template +inline T default_match(const Quad_ &a, const Quad_ &query, const Quad_ &b) +{ + return std::max(intersection_area(query, b), 0); +} + +template +inline T height_match(const Quad_ &a, const Quad_ &query, const Quad_ &b) +{ + T aHeight = a.Height(); + T bHeight = b.Height(); + + T ratio = aHeight / bHeight; + if (ratio > 1) { + ratio = 1 / ratio; + } + + // Don't combine words that have very different heights + if (ratio < 0.5) { + return 0; + } + + T dfMatch = default_match(a, query, b); + return dfMatch * ratio; +} + +template +vector> cluster_quads(const vector> &vQuads, CtorFn queryConstructor, MatchFn matchFn) +{ + torch::Tensor tAllIxAreas = torch::zeros({ (int)vQuads.size(), (int)vQuads.size() }, torch::kFloat32); + auto accAllIxAreas = tAllIxAreas.accessor(); + + NMS_KDTree> kdTree; + kdTree.Build(vQuads); + + for (int64_t i = 0; i < vQuads.size(); ++i) { + for (int64_t direction = 0; direction < 2; ++direction) { + auto queryPts = queryConstructor(i, direction); + Quad_ queryQuad{ queryPts.data() }; + + kdTree.FindIntersections(queryQuad, + [i, &accAllIxAreas, &vQuads, &queryQuad, &matchFn, direction] + (int64_t k, float pctN, float pctM, float bdsIOU) + { + if (i == k) return; + + auto oI = i, oK = k; + if (direction == 1) { + swap(oI, oK); + } + + float matchVal = matchFn(vQuads[oI], queryQuad, vQuads[oK]); + accAllIxAreas[oI][oK] = max(accAllIxAreas[oI][oK], matchVal); + } + ); + } + } + + torch::Tensor tAllIxIdxs; + tie(tAllIxAreas, tAllIxIdxs) = torch::sort(tAllIxAreas, /*dim=*/1, /*descending=*/true); + + accAllIxAreas = tAllIxAreas.accessor(); + auto accAllIxIdxs = tAllIxIdxs.accessor(); + + stack> idxsToProcess; + for (int64_t i = 0; i < vQuads.size(); ++i) { + idxsToProcess.emplace(i, 0); + } + + unordered_map> ownerLookup; + + while (! idxsToProcess.empty()) { + int64_t i, k; + tie(i, k) = idxsToProcess.top(); + idxsToProcess.pop(); + + for (; k < vQuads.size(); ++k) { + T ixArea = accAllIxAreas[i][k]; + + // There will never be a better match, so just stop processing this quad + if (ixArea == 0) break; + + int64_t oIdx = accAllIxIdxs[i][k]; + auto ownerIter = ownerLookup.find(oIdx); + // There is no owner for this region yet! + if (ownerIter == ownerLookup.end()) { + ownerLookup.emplace(oIdx, make_tuple(i, ixArea, k)); + break; + } else { + int64_t exI, exK; + T exIxArea; + tie(exI, exIxArea, exK) = ownerIter->second; + + // This quad is a better match, so boot the other one and add it to the stack + if (ixArea > exIxArea) { + ownerIter->second = make_tuple(i, ixArea, k); + // Increment the counter for the quad we just booted + idxsToProcess.emplace(exI, exK + 1); + break; + } + + // Otherwise, move to the next best match + } + } + } + + unordered_map bijection; + for (auto &kv : ownerLookup) { + bijection.emplace(get<0>(kv.second), kv.first); + } + + return relations_to_clusters(bijection, vQuads.size()); +} + +template +vector quads_to_lines(const vector> &vQuads, T horizontalTolerance) +{ + auto queryCtor = [&] (int64_t i, int64_t direction) { + const Quad_ &currQuad = vQuads[i]; + + // Direction == 0: Box to the right of the word + // Direction == 1: Box to the left of the word + + Point_ d1 = currQuad[1] - currQuad[0]; + Point_ d2 = currQuad[2] - currQuad[3]; + Point_ dEnd = direction == 0 ? 
(currQuad[2] - currQuad[1]) : (currQuad[3] - currQuad[0]); + + T w1 = length(d1); + T w2 = length(d2); + T endHeight = length(dEnd); + T width = (w1 + w2) / 2; + + d1 /= w1; + d2 /= w2; + dEnd /= endHeight; + + T avgCharWidth = std::max(endHeight * 0.75f, 1.0f); + + Point_ endPt = direction == 0 ? currQuad[1] : currQuad[0]; + + Point_ rp0 = endPt + (T(0.1) * endHeight * dEnd); + Point_ rp1 = endPt + (T(0.9) * endHeight * dEnd); + + if (direction == 1) { + d1 *= -1.0f; + d2 *= -1.0f; + } + + Point_ qp1 = rp0 + (avgCharWidth * horizontalTolerance * d1); + Point_ qp2 = rp1 + (avgCharWidth * horizontalTolerance * d2); + + if (direction == 0) { + // Create an extension of this quad outward horizontally + array, 4> pts{ rp0, qp1, qp2, rp1 }; + + return pts; + } else { + array, 4> pts{ qp1, rp0, rp1, qp2 }; + + return pts; + } + }; + + return cluster_quads(vQuads, queryCtor, height_match); +} + +template +PhraseList lines_to_phrases(const vector> &vQuads, const vector &lines, + T verticalTolerance) +{ + vector, 4>> linesPts; + for (const TextLine &line : lines) { + const Quad_ &leftQuad = vQuads[line.front()]; + const Quad_ &rightQuad = vQuads[line.back()]; + + linesPts.push_back({leftQuad[0], rightQuad[1], rightQuad[2], leftQuad[3]}); + } + + vector> vLines; + for (auto &line : linesPts) { + vLines.emplace_back(line.data()); + } + + auto queryCtor = [&] (int64_t i, int64_t direction) { + const Quad_ &currQuad = vLines[i]; + + Point_ d1 = currQuad[3] - currQuad[0]; + Point_ d2 = currQuad[2] - currQuad[1]; + + if (direction == 0) { + Point_ qp1 = currQuad[3] + (verticalTolerance * d1); + Point_ qp2 = currQuad[2] + (verticalTolerance * d2); + + array, 4> pts{ currQuad[3], currQuad[2], qp2, qp1 }; + + return pts; + } else { + Point_ qp1 = currQuad[0] - (verticalTolerance * d1); + Point_ qp2 = currQuad[1] - (verticalTolerance * d2); + + array, 4> pts{ qp1, qp2, currQuad[1], currQuad[0] }; + + return pts; + } + }; + + vector> phraseClusters = cluster_quads(vLines, queryCtor, height_match); + + PhraseList phrases; + for (const vector &lineIdxs : phraseClusters) { + Phrase phrase; + for (int64_t lineIdx : lineIdxs) { + phrase.push_back(lines[lineIdx]); + } + phrases.push_back(move(phrase)); + } + + return phrases; +} + + +template +PhraseList process_image(torch::Tensor quads, + T horizontalTolerance, T verticalTolerance, bool verbose) +{ + static bool s_timerEnabled = true; + + if (verbose) { + cout << "Text Grouper - Processing Image..." 
<< endl; + } + + auto quadsAccess = quads.accessor(); + + vector> vQuads; + for (int64_t i = 0; i < quadsAccess.size(0); ++i) { + vQuads.emplace_back(quadsAccess[i].data()); + } + + double tQuadsToLines, tLinesToPhrases; + vector lines; + PhraseList phrases; + + { + // Step 1: Construct Lines + CudaStoreTimer t(tQuadsToLines, s_timerEnabled && verbose, false); + lines = quads_to_lines(vQuads, horizontalTolerance); + } + + { + // Step 2: Construct the phrases + CudaStoreTimer t(tLinesToPhrases, s_timerEnabled && verbose, false); + phrases = lines_to_phrases(vQuads, lines, verticalTolerance); + } + + if (s_timerEnabled && verbose) { + cout << "Text Grouper " << quads.size(0) + << " - To Lines: " << tQuadsToLines << "ms" + << ", To Phrases: " << tLinesToPhrases << "ms" + << endl; + } + + return phrases; +} + + +std::vector text_region_grouping(torch::Tensor sparseQuads, torch::Tensor sparseCounts, + float horizontalTolerance, + float verticalTolerance, + bool verbose) +{ + sparseQuads = sparseQuads.to(torch::kFloat32); + sparseCounts = sparseCounts.to(torch::kInt64); + + auto countsAccess = sparseCounts.accessor(); + + vector ret; + + int64_t offset = 0, ct = 0; + for (int64_t i = 0; i < countsAccess.size(0); ++i, offset += ct) { + ct = countsAccess[i]; + + auto currQuads = sparseQuads.slice(0, offset, offset + ct); + + ret.push_back(process_image(currQuads, horizontalTolerance, verticalTolerance, verbose)); + } + + return ret; +} diff --git a/nemo-retriever-ocr/cpp/text_region_grouping/text_region_grouping.h b/nemo-retriever-ocr/cpp/text_region_grouping/text_region_grouping.h new file mode 100644 index 0000000000000000000000000000000000000000..87fea500e23740df3a42e3e5d060f75f29a7cad5 --- /dev/null +++ b/nemo-retriever-ocr/cpp/text_region_grouping/text_region_grouping.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + + +typedef std::vector TextLine; +typedef std::vector Phrase; +typedef std::vector PhraseList; + + +std::vector text_region_grouping(torch::Tensor sparseQuads, torch::Tensor sparseCounts, + float horizontalTolerance = 2.0f, + float verticalTolerance = 1.0f, + bool verbose = false); + +PhraseList dense_relations_to_graph(torch::Tensor relations); + +typedef std::tuple relation_t; +typedef std::vector text_line_t; +typedef std::vector relations_list_t; + +relations_list_t dense_relations_to_graph_with_probs(torch::Tensor relationsTensor); +relations_list_t sparse_relations_to_graph(torch::Tensor relationsTensor, torch::Tensor neighborIdxs); diff --git a/nemo-retriever-ocr/cpp/third_party/clipper/clipper.cpp b/nemo-retriever-ocr/cpp/third_party/clipper/clipper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..22cc0dffe055c88cccac99e518866082a2d2f875 --- /dev/null +++ b/nemo-retriever-ocr/cpp/third_party/clipper/clipper.cpp @@ -0,0 +1,4623 @@ +/******************************************************************************* +* * +* Author : Angus Johnson * +* Version : 6.4.0 * +* Date : 2 July 2015 * +* Website : http://www.angusj.com * +* Copyright : Angus Johnson 2010-2015 * +* * +* License: * +* Use, modification & distribution is subject to Boost Software License Ver 1. 
* +* http://www.boost.org/LICENSE_1_0.txt * +* * +* Attributions: * +* The code in this library is an extension of Bala Vatti's clipping algorithm: * +* "A generic solution to polygon clipping" * +* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. * +* http://portal.acm.org/citation.cfm?id=129906 * +* * +* Computer graphics and geometric modeling: implementation and algorithms * +* By Max K. Agoston * +* Springer; 1 edition (January 4, 2005) * +* http://books.google.com/books?q=vatti+clipping+agoston * +* * +* See also: * +* "Polygon Offsetting by Computing Winding Numbers" * +* Paper no. DETC2005-85513 pp. 565-575 * +* ASME 2005 International Design Engineering Technical Conferences * +* and Computers and Information in Engineering Conference (IDETC/CIE2005) * +* September 24-28, 2005 , Long Beach, California, USA * +* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf * +* * +*******************************************************************************/ + +/******************************************************************************* +* * +* This is a translation of the Delphi Clipper library and the naming style * +* used has retained a Delphi flavour. * +* * +*******************************************************************************/ + +#include "clipper.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ClipperLib { + +static double const pi = 3.141592653589793238; +static double const two_pi = pi *2; +static double const def_arc_tolerance = 0.25; + +enum Direction { dRightToLeft, dLeftToRight }; + +static int const Unassigned = -1; //edge not currently 'owning' a solution +static int const Skip = -2; //edge that would otherwise close a path + +#define HORIZONTAL (-1.0E+40) +#define TOLERANCE (1.0e-20) +#define NEAR_ZERO(val) (((val) > -TOLERANCE) && ((val) < TOLERANCE)) + +struct TEdge { + IntPoint Bot; + IntPoint Curr; //current (updated for every new scanbeam) + IntPoint Top; + double Dx; + PolyType PolyTyp; + EdgeSide Side; //side only refers to current side of solution poly + int WindDelta; //1 or -1 depending on winding direction + int WindCnt; + int WindCnt2; //winding count of the opposite polytype + int OutIdx; + TEdge *Next; + TEdge *Prev; + TEdge *NextInLML; + TEdge *NextInAEL; + TEdge *PrevInAEL; + TEdge *NextInSEL; + TEdge *PrevInSEL; +}; + +struct IntersectNode { + TEdge *Edge1; + TEdge *Edge2; + IntPoint Pt; +}; + +struct LocalMinimum { + cInt Y; + TEdge *LeftBound; + TEdge *RightBound; +}; + +struct OutPt; + +//OutRec: contains a path in the clipping solution. Edges in the AEL will +//carry a pointer to an OutRec when they are part of the clipping solution. 
+struct OutRec { + int Idx; + bool IsHole; + bool IsOpen; + OutRec *FirstLeft; //see comments in clipper.pas + PolyNode *PolyNd; + OutPt *Pts; + OutPt *BottomPt; +}; + +struct OutPt { + int Idx; + IntPoint Pt; + OutPt *Next; + OutPt *Prev; +}; + +struct Join { + OutPt *OutPt1; + OutPt *OutPt2; + IntPoint OffPt; +}; + +struct LocMinSorter +{ + inline bool operator()(const LocalMinimum& locMin1, const LocalMinimum& locMin2) + { + return locMin2.Y < locMin1.Y; + } +}; + +//------------------------------------------------------------------------------ +//------------------------------------------------------------------------------ + +inline cInt Round(double val) +{ + if ((val < 0)) return static_cast(val - 0.5); + else return static_cast(val + 0.5); +} +//------------------------------------------------------------------------------ + +inline cInt Abs(cInt val) +{ + return val < 0 ? -val : val; +} + +//------------------------------------------------------------------------------ +// PolyTree methods ... +//------------------------------------------------------------------------------ + +void PolyTree::Clear() +{ + for (PolyNodes::size_type i = 0; i < AllNodes.size(); ++i) + delete AllNodes[i]; + AllNodes.resize(0); + Childs.resize(0); +} +//------------------------------------------------------------------------------ + +PolyNode* PolyTree::GetFirst() const +{ + if (!Childs.empty()) + return Childs[0]; + else + return 0; +} +//------------------------------------------------------------------------------ + +int PolyTree::Total() const +{ + int result = (int)AllNodes.size(); + //with negative offsets, ignore the hidden outer polygon ... + if (result > 0 && Childs[0] != AllNodes[0]) result--; + return result; +} + +//------------------------------------------------------------------------------ +// PolyNode methods ... 
+//------------------------------------------------------------------------------ + +PolyNode::PolyNode(): Childs(), Parent(0), Index(0), m_IsOpen(false) +{ +} +//------------------------------------------------------------------------------ + +int PolyNode::ChildCount() const +{ + return (int)Childs.size(); +} +//------------------------------------------------------------------------------ + +void PolyNode::AddChild(PolyNode& child) +{ + unsigned cnt = (unsigned)Childs.size(); + Childs.push_back(&child); + child.Parent = this; + child.Index = cnt; +} +//------------------------------------------------------------------------------ + +PolyNode* PolyNode::GetNext() const +{ + if (!Childs.empty()) + return Childs[0]; + else + return GetNextSiblingUp(); +} +//------------------------------------------------------------------------------ + +PolyNode* PolyNode::GetNextSiblingUp() const +{ + if (!Parent) //protects against PolyTree.GetNextSiblingUp() + return 0; + else if (Index == Parent->Childs.size() - 1) + return Parent->GetNextSiblingUp(); + else + return Parent->Childs[Index + 1]; +} +//------------------------------------------------------------------------------ + +bool PolyNode::IsHole() const +{ + bool result = true; + PolyNode* node = Parent; + while (node) + { + result = !result; + node = node->Parent; + } + return result; +} +//------------------------------------------------------------------------------ + +bool PolyNode::IsOpen() const +{ + return m_IsOpen; +} +//------------------------------------------------------------------------------ + +#ifndef use_int32 + +//------------------------------------------------------------------------------ +// Int128 class (enables safe math on signed 64bit integers) +// eg Int128 val1((long64)9223372036854775807); //ie 2^63 -1 +// Int128 val2((long64)9223372036854775807); +// Int128 val3 = val1 * val2; +// val3.AsString => "85070591730234615847396907784232501249" (8.5e+37) +//------------------------------------------------------------------------------ + +class Int128 +{ + public: + ulong64 lo; + long64 hi; + + Int128(long64 _lo = 0) + { + lo = (ulong64)_lo; + if (_lo < 0) hi = -1; else hi = 0; + } + + + Int128(const Int128 &val): lo(val.lo), hi(val.hi){} + + Int128(const long64& _hi, const ulong64& _lo): lo(_lo), hi(_hi){} + + Int128& operator = (const long64 &val) + { + lo = (ulong64)val; + if (val < 0) hi = -1; else hi = 0; + return *this; + } + + bool operator == (const Int128 &val) const + {return (hi == val.hi && lo == val.lo);} + + bool operator != (const Int128 &val) const + { return !(*this == val);} + + bool operator > (const Int128 &val) const + { + if (hi != val.hi) + return hi > val.hi; + else + return lo > val.lo; + } + + bool operator < (const Int128 &val) const + { + if (hi != val.hi) + return hi < val.hi; + else + return lo < val.lo; + } + + bool operator >= (const Int128 &val) const + { return !(*this < val);} + + bool operator <= (const Int128 &val) const + { return !(*this > val);} + + Int128& operator += (const Int128 &rhs) + { + hi += rhs.hi; + lo += rhs.lo; + if (lo < rhs.lo) hi++; + return *this; + } + + Int128 operator + (const Int128 &rhs) const + { + Int128 result(*this); + result+= rhs; + return result; + } + + Int128& operator -= (const Int128 &rhs) + { + *this += -rhs; + return *this; + } + + Int128 operator - (const Int128 &rhs) const + { + Int128 result(*this); + result -= rhs; + return result; + } + + Int128 operator-() const //unary negation + { + if (lo == 0) + return Int128(-hi, 0); + else + return 
Int128(~hi, ~lo + 1); + } + + operator double() const + { + const double shift64 = 18446744073709551616.0; //2^64 + if (hi < 0) + { + if (lo == 0) return (double)hi * shift64; + else return -(double)(~lo + ~hi * shift64); + } + else + return (double)(lo + hi * shift64); + } + +}; +//------------------------------------------------------------------------------ + +Int128 Int128Mul (long64 lhs, long64 rhs) +{ + bool negate = (lhs < 0) != (rhs < 0); + + if (lhs < 0) lhs = -lhs; + ulong64 int1Hi = ulong64(lhs) >> 32; + ulong64 int1Lo = ulong64(lhs & 0xFFFFFFFF); + + if (rhs < 0) rhs = -rhs; + ulong64 int2Hi = ulong64(rhs) >> 32; + ulong64 int2Lo = ulong64(rhs & 0xFFFFFFFF); + + //nb: see comments in clipper.pas + ulong64 a = int1Hi * int2Hi; + ulong64 b = int1Lo * int2Lo; + ulong64 c = int1Hi * int2Lo + int1Lo * int2Hi; + + Int128 tmp; + tmp.hi = long64(a + (c >> 32)); + tmp.lo = long64(c << 32); + tmp.lo += long64(b); + if (tmp.lo < b) tmp.hi++; + if (negate) tmp = -tmp; + return tmp; +}; +#endif + +//------------------------------------------------------------------------------ +// Miscellaneous global functions +//------------------------------------------------------------------------------ + +bool Orientation(const Path &poly) +{ + return Area(poly) >= 0; +} +//------------------------------------------------------------------------------ + +double Area(const Path &poly) +{ + int size = (int)poly.size(); + if (size < 3) return 0; + + double a = 0; + for (int i = 0, j = size -1; i < size; ++i) + { + a += ((double)poly[j].X + poly[i].X) * ((double)poly[j].Y - poly[i].Y); + j = i; + } + return -a * 0.5; +} +//------------------------------------------------------------------------------ + +double Area(const OutPt *op) +{ + const OutPt *startOp = op; + if (!op) return 0; + double a = 0; + do { + a += (double)(op->Prev->Pt.X + op->Pt.X) * (double)(op->Prev->Pt.Y - op->Pt.Y); + op = op->Next; + } while (op != startOp); + return a * 0.5; +} +//------------------------------------------------------------------------------ + +double Area(const OutRec &outRec) +{ + return Area(outRec.Pts); +} +//------------------------------------------------------------------------------ + +bool PointIsVertex(const IntPoint &Pt, OutPt *pp) +{ + OutPt *pp2 = pp; + do + { + if (pp2->Pt == Pt) return true; + pp2 = pp2->Next; + } + while (pp2 != pp); + return false; +} +//------------------------------------------------------------------------------ + +//See "The Point in Polygon Problem for Arbitrary Polygons" by Hormann & Agathos +//http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.88.5498&rep=rep1&type=pdf +int PointInPolygon(const IntPoint &pt, const Path &path) +{ + //returns 0 if false, +1 if true, -1 if pt ON polygon boundary + int result = 0; + size_t cnt = path.size(); + if (cnt < 3) return 0; + IntPoint ip = path[0]; + for(size_t i = 1; i <= cnt; ++i) + { + IntPoint ipNext = (i == cnt ? 
path[0] : path[i]); + if (ipNext.Y == pt.Y) + { + if ((ipNext.X == pt.X) || (ip.Y == pt.Y && + ((ipNext.X > pt.X) == (ip.X < pt.X)))) return -1; + } + if ((ip.Y < pt.Y) != (ipNext.Y < pt.Y)) + { + if (ip.X >= pt.X) + { + if (ipNext.X > pt.X) result = 1 - result; + else + { + double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) - + (double)(ipNext.X - pt.X) * (ip.Y - pt.Y); + if (!d) return -1; + if ((d > 0) == (ipNext.Y > ip.Y)) result = 1 - result; + } + } else + { + if (ipNext.X > pt.X) + { + double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) - + (double)(ipNext.X - pt.X) * (ip.Y - pt.Y); + if (!d) return -1; + if ((d > 0) == (ipNext.Y > ip.Y)) result = 1 - result; + } + } + } + ip = ipNext; + } + return result; +} +//------------------------------------------------------------------------------ + +int PointInPolygon (const IntPoint &pt, OutPt *op) +{ + //returns 0 if false, +1 if true, -1 if pt ON polygon boundary + int result = 0; + OutPt* startOp = op; + for(;;) + { + if (op->Next->Pt.Y == pt.Y) + { + if ((op->Next->Pt.X == pt.X) || (op->Pt.Y == pt.Y && + ((op->Next->Pt.X > pt.X) == (op->Pt.X < pt.X)))) return -1; + } + if ((op->Pt.Y < pt.Y) != (op->Next->Pt.Y < pt.Y)) + { + if (op->Pt.X >= pt.X) + { + if (op->Next->Pt.X > pt.X) result = 1 - result; + else + { + double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) - + (double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y); + if (!d) return -1; + if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y)) result = 1 - result; + } + } else + { + if (op->Next->Pt.X > pt.X) + { + double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) - + (double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y); + if (!d) return -1; + if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y)) result = 1 - result; + } + } + } + op = op->Next; + if (startOp == op) break; + } + return result; +} +//------------------------------------------------------------------------------ + +bool Poly2ContainsPoly1(OutPt *OutPt1, OutPt *OutPt2) +{ + OutPt* op = OutPt1; + do + { + //nb: PointInPolygon returns 0 if false, +1 if true, -1 if pt on polygon + int res = PointInPolygon(op->Pt, OutPt2); + if (res >= 0) return res > 0; + op = op->Next; + } + while (op != OutPt1); + return true; +} +//---------------------------------------------------------------------- + +bool SlopesEqual(const TEdge &e1, const TEdge &e2, bool UseFullInt64Range) +{ +#ifndef use_int32 + if (UseFullInt64Range) + return Int128Mul(e1.Top.Y - e1.Bot.Y, e2.Top.X - e2.Bot.X) == + Int128Mul(e1.Top.X - e1.Bot.X, e2.Top.Y - e2.Bot.Y); + else +#endif + return (e1.Top.Y - e1.Bot.Y) * (e2.Top.X - e2.Bot.X) == + (e1.Top.X - e1.Bot.X) * (e2.Top.Y - e2.Bot.Y); +} +//------------------------------------------------------------------------------ + +bool SlopesEqual(const IntPoint pt1, const IntPoint pt2, + const IntPoint pt3, bool UseFullInt64Range) +{ +#ifndef use_int32 + if (UseFullInt64Range) + return Int128Mul(pt1.Y-pt2.Y, pt2.X-pt3.X) == Int128Mul(pt1.X-pt2.X, pt2.Y-pt3.Y); + else +#endif + return (pt1.Y-pt2.Y)*(pt2.X-pt3.X) == (pt1.X-pt2.X)*(pt2.Y-pt3.Y); +} +//------------------------------------------------------------------------------ + +bool SlopesEqual(const IntPoint pt1, const IntPoint pt2, + const IntPoint pt3, const IntPoint pt4, bool UseFullInt64Range) +{ +#ifndef use_int32 + if (UseFullInt64Range) + return Int128Mul(pt1.Y-pt2.Y, pt3.X-pt4.X) == Int128Mul(pt1.X-pt2.X, pt3.Y-pt4.Y); + else +#endif + return (pt1.Y-pt2.Y)*(pt3.X-pt4.X) == (pt1.X-pt2.X)*(pt3.Y-pt4.Y); +} 
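+//nb: the three SlopesEqual() overloads above test for equal slopes by
+//cross-multiplying deltas rather than dividing, so vertical edges need no
+//special case and no floating point is involved.
+//eg pt1=(0,0), pt2=(2,4), pt3=(3,6): (0-4)*(2-3) == (0-2)*(4-6), ie 4 == 4,
+//so the three points are collinear.
+//The Int128Mul() comparisons are only used once coordinates exceed loRange
+//(see RangeTest below), where the plain 64bit products could overflow.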
+//------------------------------------------------------------------------------ + +inline bool IsHorizontal(TEdge &e) +{ + return e.Dx == HORIZONTAL; +} +//------------------------------------------------------------------------------ + +inline double GetDx(const IntPoint pt1, const IntPoint pt2) +{ + return (pt1.Y == pt2.Y) ? + HORIZONTAL : (double)(pt2.X - pt1.X) / (pt2.Y - pt1.Y); +} +//--------------------------------------------------------------------------- + +inline void SetDx(TEdge &e) +{ + cInt dy = (e.Top.Y - e.Bot.Y); + if (dy == 0) e.Dx = HORIZONTAL; + else e.Dx = (double)(e.Top.X - e.Bot.X) / dy; +} +//--------------------------------------------------------------------------- + +inline void SwapSides(TEdge &Edge1, TEdge &Edge2) +{ + EdgeSide Side = Edge1.Side; + Edge1.Side = Edge2.Side; + Edge2.Side = Side; +} +//------------------------------------------------------------------------------ + +inline void SwapPolyIndexes(TEdge &Edge1, TEdge &Edge2) +{ + int OutIdx = Edge1.OutIdx; + Edge1.OutIdx = Edge2.OutIdx; + Edge2.OutIdx = OutIdx; +} +//------------------------------------------------------------------------------ + +inline cInt TopX(TEdge &edge, const cInt currentY) +{ + return ( currentY == edge.Top.Y ) ? + edge.Top.X : edge.Bot.X + Round(edge.Dx *(currentY - edge.Bot.Y)); +} +//------------------------------------------------------------------------------ + +void IntersectPoint(TEdge &Edge1, TEdge &Edge2, IntPoint &ip) +{ +#ifdef use_xyz + ip.Z = 0; +#endif + + double b1, b2; + if (Edge1.Dx == Edge2.Dx) + { + ip.Y = Edge1.Curr.Y; + ip.X = TopX(Edge1, ip.Y); + return; + } + else if (Edge1.Dx == 0) + { + ip.X = Edge1.Bot.X; + if (IsHorizontal(Edge2)) + ip.Y = Edge2.Bot.Y; + else + { + b2 = Edge2.Bot.Y - (Edge2.Bot.X / Edge2.Dx); + ip.Y = Round(ip.X / Edge2.Dx + b2); + } + } + else if (Edge2.Dx == 0) + { + ip.X = Edge2.Bot.X; + if (IsHorizontal(Edge1)) + ip.Y = Edge1.Bot.Y; + else + { + b1 = Edge1.Bot.Y - (Edge1.Bot.X / Edge1.Dx); + ip.Y = Round(ip.X / Edge1.Dx + b1); + } + } + else + { + b1 = Edge1.Bot.X - Edge1.Bot.Y * Edge1.Dx; + b2 = Edge2.Bot.X - Edge2.Bot.Y * Edge2.Dx; + double q = (b2-b1) / (Edge1.Dx - Edge2.Dx); + ip.Y = Round(q); + if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx)) + ip.X = Round(Edge1.Dx * q + b1); + else + ip.X = Round(Edge2.Dx * q + b2); + } + + if (ip.Y < Edge1.Top.Y || ip.Y < Edge2.Top.Y) + { + if (Edge1.Top.Y > Edge2.Top.Y) + ip.Y = Edge1.Top.Y; + else + ip.Y = Edge2.Top.Y; + if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx)) + ip.X = TopX(Edge1, ip.Y); + else + ip.X = TopX(Edge2, ip.Y); + } + //finally, don't allow 'ip' to be BELOW curr.Y (ie bottom of scanbeam) ... + if (ip.Y > Edge1.Curr.Y) + { + ip.Y = Edge1.Curr.Y; + //use the more vertical edge to derive X ... 
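+    //(ie the edge with the smaller absolute Dx, since it shifts least in X
+    //for a given change in Y and so suffers least from clamping Y here)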
+ if (std::fabs(Edge1.Dx) > std::fabs(Edge2.Dx)) + ip.X = TopX(Edge2, ip.Y); else + ip.X = TopX(Edge1, ip.Y); + } +} +//------------------------------------------------------------------------------ + +void ReversePolyPtLinks(OutPt *pp) +{ + if (!pp) return; + OutPt *pp1, *pp2; + pp1 = pp; + do { + pp2 = pp1->Next; + pp1->Next = pp1->Prev; + pp1->Prev = pp2; + pp1 = pp2; + } while( pp1 != pp ); +} +//------------------------------------------------------------------------------ + +void DisposeOutPts(OutPt*& pp) +{ + if (pp == 0) return; + pp->Prev->Next = 0; + while( pp ) + { + OutPt *tmpPp = pp; + pp = pp->Next; + delete tmpPp; + } +} +//------------------------------------------------------------------------------ + +inline void InitEdge(TEdge* e, TEdge* eNext, TEdge* ePrev, const IntPoint& Pt) +{ + std::memset(e, 0, sizeof(TEdge)); + e->Next = eNext; + e->Prev = ePrev; + e->Curr = Pt; + e->OutIdx = Unassigned; +} +//------------------------------------------------------------------------------ + +void InitEdge2(TEdge& e, PolyType Pt) +{ + if (e.Curr.Y >= e.Next->Curr.Y) + { + e.Bot = e.Curr; + e.Top = e.Next->Curr; + } else + { + e.Top = e.Curr; + e.Bot = e.Next->Curr; + } + SetDx(e); + e.PolyTyp = Pt; +} +//------------------------------------------------------------------------------ + +TEdge* RemoveEdge(TEdge* e) +{ + //removes e from double_linked_list (but without removing from memory) + e->Prev->Next = e->Next; + e->Next->Prev = e->Prev; + TEdge* result = e->Next; + e->Prev = 0; //flag as removed (see ClipperBase.Clear) + return result; +} +//------------------------------------------------------------------------------ + +inline void ReverseHorizontal(TEdge &e) +{ + //swap horizontal edges' Top and Bottom x's so they follow the natural + //progression of the bounds - ie so their xbots will align with the + //adjoining lower edge. [Helpful in the ProcessHorizontal() method.] + std::swap(e.Top.X, e.Bot.X); +#ifdef use_xyz + std::swap(e.Top.Z, e.Bot.Z); +#endif +} +//------------------------------------------------------------------------------ + +void SwapPoints(IntPoint &pt1, IntPoint &pt2) +{ + IntPoint tmp = pt1; + pt1 = pt2; + pt2 = tmp; +} +//------------------------------------------------------------------------------ + +bool GetOverlapSegment(IntPoint pt1a, IntPoint pt1b, IntPoint pt2a, + IntPoint pt2b, IntPoint &pt1, IntPoint &pt2) +{ + //precondition: segments are Collinear. 
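+  //nb: the segments are compared along whichever axis they span furthest, so
+  //near-vertical collinear segments are ordered by Y rather than X; the result
+  //is true only when the shared portion has non-zero length.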
+ if (Abs(pt1a.X - pt1b.X) > Abs(pt1a.Y - pt1b.Y)) + { + if (pt1a.X > pt1b.X) SwapPoints(pt1a, pt1b); + if (pt2a.X > pt2b.X) SwapPoints(pt2a, pt2b); + if (pt1a.X > pt2a.X) pt1 = pt1a; else pt1 = pt2a; + if (pt1b.X < pt2b.X) pt2 = pt1b; else pt2 = pt2b; + return pt1.X < pt2.X; + } else + { + if (pt1a.Y < pt1b.Y) SwapPoints(pt1a, pt1b); + if (pt2a.Y < pt2b.Y) SwapPoints(pt2a, pt2b); + if (pt1a.Y < pt2a.Y) pt1 = pt1a; else pt1 = pt2a; + if (pt1b.Y > pt2b.Y) pt2 = pt1b; else pt2 = pt2b; + return pt1.Y > pt2.Y; + } +} +//------------------------------------------------------------------------------ + +bool FirstIsBottomPt(const OutPt* btmPt1, const OutPt* btmPt2) +{ + OutPt *p = btmPt1->Prev; + while ((p->Pt == btmPt1->Pt) && (p != btmPt1)) p = p->Prev; + double dx1p = std::fabs(GetDx(btmPt1->Pt, p->Pt)); + p = btmPt1->Next; + while ((p->Pt == btmPt1->Pt) && (p != btmPt1)) p = p->Next; + double dx1n = std::fabs(GetDx(btmPt1->Pt, p->Pt)); + + p = btmPt2->Prev; + while ((p->Pt == btmPt2->Pt) && (p != btmPt2)) p = p->Prev; + double dx2p = std::fabs(GetDx(btmPt2->Pt, p->Pt)); + p = btmPt2->Next; + while ((p->Pt == btmPt2->Pt) && (p != btmPt2)) p = p->Next; + double dx2n = std::fabs(GetDx(btmPt2->Pt, p->Pt)); + + if (std::max(dx1p, dx1n) == std::max(dx2p, dx2n) && + std::min(dx1p, dx1n) == std::min(dx2p, dx2n)) + return Area(btmPt1) > 0; //if otherwise identical use orientation + else + return (dx1p >= dx2p && dx1p >= dx2n) || (dx1n >= dx2p && dx1n >= dx2n); +} +//------------------------------------------------------------------------------ + +OutPt* GetBottomPt(OutPt *pp) +{ + OutPt* dups = 0; + OutPt* p = pp->Next; + while (p != pp) + { + if (p->Pt.Y > pp->Pt.Y) + { + pp = p; + dups = 0; + } + else if (p->Pt.Y == pp->Pt.Y && p->Pt.X <= pp->Pt.X) + { + if (p->Pt.X < pp->Pt.X) + { + dups = 0; + pp = p; + } else + { + if (p->Next != pp && p->Prev != pp) dups = p; + } + } + p = p->Next; + } + if (dups) + { + //there appears to be at least 2 vertices at BottomPt so ... + while (dups != p) + { + if (!FirstIsBottomPt(p, dups)) pp = dups; + dups = dups->Next; + while (dups->Pt != pp->Pt) dups = dups->Next; + } + } + return pp; +} +//------------------------------------------------------------------------------ + +bool Pt2IsBetweenPt1AndPt3(const IntPoint pt1, + const IntPoint pt2, const IntPoint pt3) +{ + if ((pt1 == pt3) || (pt1 == pt2) || (pt3 == pt2)) + return false; + else if (pt1.X != pt3.X) + return (pt2.X > pt1.X) == (pt2.X < pt3.X); + else + return (pt2.Y > pt1.Y) == (pt2.Y < pt3.Y); +} +//------------------------------------------------------------------------------ + +bool HorzSegmentsOverlap(cInt seg1a, cInt seg1b, cInt seg2a, cInt seg2b) +{ + if (seg1a > seg1b) std::swap(seg1a, seg1b); + if (seg2a > seg2b) std::swap(seg2a, seg2b); + return (seg1a < seg2b) && (seg2a < seg1b); +} + +//------------------------------------------------------------------------------ +// ClipperBase class methods ... 
+//------------------------------------------------------------------------------ + +ClipperBase::ClipperBase() //constructor +{ + m_CurrentLM = m_MinimaList.begin(); //begin() == end() here + m_UseFullRange = false; +} +//------------------------------------------------------------------------------ + +ClipperBase::~ClipperBase() //destructor +{ + Clear(); +} +//------------------------------------------------------------------------------ + +void RangeTest(const IntPoint& Pt, bool& useFullRange) +{ + if (useFullRange) + { + if (Pt.X > hiRange || Pt.Y > hiRange || -Pt.X > hiRange || -Pt.Y > hiRange) + throw clipperException("Coordinate outside allowed range"); + } + else if (Pt.X > loRange|| Pt.Y > loRange || -Pt.X > loRange || -Pt.Y > loRange) + { + useFullRange = true; + RangeTest(Pt, useFullRange); + } +} +//------------------------------------------------------------------------------ + +TEdge* FindNextLocMin(TEdge* E) +{ + for (;;) + { + while (E->Bot != E->Prev->Bot || E->Curr == E->Top) E = E->Next; + if (!IsHorizontal(*E) && !IsHorizontal(*E->Prev)) break; + while (IsHorizontal(*E->Prev)) E = E->Prev; + TEdge* E2 = E; + while (IsHorizontal(*E)) E = E->Next; + if (E->Top.Y == E->Prev->Bot.Y) continue; //ie just an intermediate horz. + if (E2->Prev->Bot.X < E->Bot.X) E = E2; + break; + } + return E; +} +//------------------------------------------------------------------------------ + +TEdge* ClipperBase::ProcessBound(TEdge* E, bool NextIsForward) +{ + TEdge *Result = E; + TEdge *Horz = 0; + + if (E->OutIdx == Skip) + { + //if edges still remain in the current bound beyond the skip edge then + //create another LocMin and call ProcessBound once more + if (NextIsForward) + { + while (E->Top.Y == E->Next->Bot.Y) E = E->Next; + //don't include top horizontals when parsing a bound a second time, + //they will be contained in the opposite bound ... + while (E != Result && IsHorizontal(*E)) E = E->Prev; + } + else + { + while (E->Top.Y == E->Prev->Bot.Y) E = E->Prev; + while (E != Result && IsHorizontal(*E)) E = E->Next; + } + + if (E == Result) + { + if (NextIsForward) Result = E->Next; + else Result = E->Prev; + } + else + { + //there are more edges in the bound beyond result starting with E + if (NextIsForward) + E = Result->Next; + else + E = Result->Prev; + MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + locMin.LeftBound = 0; + locMin.RightBound = E; + E->WindDelta = 0; + Result = ProcessBound(E, NextIsForward); + m_MinimaList.push_back(locMin); + } + return Result; + } + + TEdge *EStart; + + if (IsHorizontal(*E)) + { + //We need to be careful with open paths because this may not be a + //true local minima (ie E may be following a skip edge). + //Also, consecutive horz. edges may start heading left before going right. 
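+    //So only reverse the starting horizontal when its Bot.X doesn't already
+    //line up with the edge that leads into this bound ...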
+ if (NextIsForward) + EStart = E->Prev; + else + EStart = E->Next; + if (IsHorizontal(*EStart)) //ie an adjoining horizontal skip edge + { + if (EStart->Bot.X != E->Bot.X && EStart->Top.X != E->Bot.X) + ReverseHorizontal(*E); + } + else if (EStart->Bot.X != E->Bot.X) + ReverseHorizontal(*E); + } + + EStart = E; + if (NextIsForward) + { + while (Result->Top.Y == Result->Next->Bot.Y && Result->Next->OutIdx != Skip) + Result = Result->Next; + if (IsHorizontal(*Result) && Result->Next->OutIdx != Skip) + { + //nb: at the top of a bound, horizontals are added to the bound + //only when the preceding edge attaches to the horizontal's left vertex + //unless a Skip edge is encountered when that becomes the top divide + Horz = Result; + while (IsHorizontal(*Horz->Prev)) Horz = Horz->Prev; + if (Horz->Prev->Top.X > Result->Next->Top.X) Result = Horz->Prev; + } + while (E != Result) + { + E->NextInLML = E->Next; + if (IsHorizontal(*E) && E != EStart && + E->Bot.X != E->Prev->Top.X) ReverseHorizontal(*E); + E = E->Next; + } + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Prev->Top.X) + ReverseHorizontal(*E); + Result = Result->Next; //move to the edge just beyond current bound + } else + { + while (Result->Top.Y == Result->Prev->Bot.Y && Result->Prev->OutIdx != Skip) + Result = Result->Prev; + if (IsHorizontal(*Result) && Result->Prev->OutIdx != Skip) + { + Horz = Result; + while (IsHorizontal(*Horz->Next)) Horz = Horz->Next; + if (Horz->Next->Top.X == Result->Prev->Top.X || + Horz->Next->Top.X > Result->Prev->Top.X) Result = Horz->Next; + } + + while (E != Result) + { + E->NextInLML = E->Prev; + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X) + ReverseHorizontal(*E); + E = E->Prev; + } + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X) + ReverseHorizontal(*E); + Result = Result->Prev; //move to the edge just beyond current bound + } + + return Result; +} +//------------------------------------------------------------------------------ + +bool ClipperBase::AddPath(const Path &pg, PolyType PolyTyp, bool Closed) +{ +#ifdef use_lines + if (!Closed && PolyTyp == ptClip) + throw clipperException("AddPath: Open paths must be subject."); +#else + if (!Closed) + throw clipperException("AddPath: Open paths have been disabled."); +#endif + + int highI = (int)pg.size() -1; + if (Closed) while (highI > 0 && (pg[highI] == pg[0])) --highI; + while (highI > 0 && (pg[highI] == pg[highI -1])) --highI; + if ((Closed && highI < 2) || (!Closed && highI < 1)) return false; + + //create a new edge array ... + TEdge *edges = new TEdge [highI +1]; + + bool IsFlat = true; + //1. Basic (first) edge initialization ... + try + { + edges[1].Curr = pg[1]; + RangeTest(pg[0], m_UseFullRange); + RangeTest(pg[highI], m_UseFullRange); + InitEdge(&edges[0], &edges[1], &edges[highI], pg[0]); + InitEdge(&edges[highI], &edges[0], &edges[highI-1], pg[highI]); + for (int i = highI - 1; i >= 1; --i) + { + RangeTest(pg[i], m_UseFullRange); + InitEdge(&edges[i], &edges[i+1], &edges[i-1], pg[i]); + } + } + catch(...) + { + delete [] edges; + throw; //range test fails + } + TEdge *eStart = &edges[0]; + + //2. Remove duplicate vertices, and (when closed) collinear edges ... + TEdge *E = eStart, *eLoopStop = eStart; + for (;;) + { + //nb: allows matching start and end points when not Closed ... 
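+    //eLoopStop is reset after every removal below, so this pass only ends once
+    //a complete loop of the (possibly shrunken) ring finds nothing to remove ...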
+ if (E->Curr == E->Next->Curr && (Closed || E->Next != eStart)) + { + if (E == E->Next) break; + if (E == eStart) eStart = E->Next; + E = RemoveEdge(E); + eLoopStop = E; + continue; + } + if (E->Prev == E->Next) + break; //only two vertices + else if (Closed && + SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr, m_UseFullRange) && + (!m_PreserveCollinear || + !Pt2IsBetweenPt1AndPt3(E->Prev->Curr, E->Curr, E->Next->Curr))) + { + //Collinear edges are allowed for open paths but in closed paths + //the default is to merge adjacent collinear edges into a single edge. + //However, if the PreserveCollinear property is enabled, only overlapping + //collinear edges (ie spikes) will be removed from closed paths. + if (E == eStart) eStart = E->Next; + E = RemoveEdge(E); + E = E->Prev; + eLoopStop = E; + continue; + } + E = E->Next; + if ((E == eLoopStop) || (!Closed && E->Next == eStart)) break; + } + + if ((!Closed && (E == E->Next)) || (Closed && (E->Prev == E->Next))) + { + delete [] edges; + return false; + } + + if (!Closed) + { + m_HasOpenPaths = true; + eStart->Prev->OutIdx = Skip; + } + + //3. Do second stage of edge initialization ... + E = eStart; + do + { + InitEdge2(*E, PolyTyp); + E = E->Next; + if (IsFlat && E->Curr.Y != eStart->Curr.Y) IsFlat = false; + } + while (E != eStart); + + //4. Finally, add edge bounds to LocalMinima list ... + + //Totally flat paths must be handled differently when adding them + //to LocalMinima list to avoid endless loops etc ... + if (IsFlat) + { + if (Closed) + { + delete [] edges; + return false; + } + E->Prev->OutIdx = Skip; + MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + locMin.LeftBound = 0; + locMin.RightBound = E; + locMin.RightBound->Side = esRight; + locMin.RightBound->WindDelta = 0; + for (;;) + { + if (E->Bot.X != E->Prev->Top.X) ReverseHorizontal(*E); + if (E->Next->OutIdx == Skip) break; + E->NextInLML = E->Next; + E = E->Next; + } + m_MinimaList.push_back(locMin); + m_edges.push_back(edges); + return true; + } + + m_edges.push_back(edges); + bool leftBoundIsForward; + TEdge* EMin = 0; + + //workaround to avoid an endless loop in the while loop below when + //open paths have matching start and end points ... + if (E->Prev->Bot == E->Prev->Top) E = E->Next; + + for (;;) + { + E = FindNextLocMin(E); + if (E == EMin) break; + else if (!EMin) EMin = E; + + //E and E.Prev now share a local minima (left aligned if horizontal). + //Compare their slopes to find which starts which bound ... 
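+    //(the edge with the larger Dx becomes the LeftBound; the two bounds are
+    //then given opposite WindDeltas, or zero WindDeltas for open paths)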
+ MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + if (E->Dx < E->Prev->Dx) + { + locMin.LeftBound = E->Prev; + locMin.RightBound = E; + leftBoundIsForward = false; //Q.nextInLML = Q.prev + } else + { + locMin.LeftBound = E; + locMin.RightBound = E->Prev; + leftBoundIsForward = true; //Q.nextInLML = Q.next + } + + if (!Closed) locMin.LeftBound->WindDelta = 0; + else if (locMin.LeftBound->Next == locMin.RightBound) + locMin.LeftBound->WindDelta = -1; + else locMin.LeftBound->WindDelta = 1; + locMin.RightBound->WindDelta = -locMin.LeftBound->WindDelta; + + E = ProcessBound(locMin.LeftBound, leftBoundIsForward); + if (E->OutIdx == Skip) E = ProcessBound(E, leftBoundIsForward); + + TEdge* E2 = ProcessBound(locMin.RightBound, !leftBoundIsForward); + if (E2->OutIdx == Skip) E2 = ProcessBound(E2, !leftBoundIsForward); + + if (locMin.LeftBound->OutIdx == Skip) + locMin.LeftBound = 0; + else if (locMin.RightBound->OutIdx == Skip) + locMin.RightBound = 0; + m_MinimaList.push_back(locMin); + if (!leftBoundIsForward) E = E2; + } + return true; +} +//------------------------------------------------------------------------------ + +bool ClipperBase::AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed) +{ + bool result = false; + for (Paths::size_type i = 0; i < ppg.size(); ++i) + if (AddPath(ppg[i], PolyTyp, Closed)) result = true; + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::Clear() +{ + DisposeLocalMinimaList(); + for (EdgeList::size_type i = 0; i < m_edges.size(); ++i) + { + TEdge* edges = m_edges[i]; + delete [] edges; + } + m_edges.clear(); + m_UseFullRange = false; + m_HasOpenPaths = false; +} +//------------------------------------------------------------------------------ + +void ClipperBase::Reset() +{ + m_CurrentLM = m_MinimaList.begin(); + if (m_CurrentLM == m_MinimaList.end()) return; //ie nothing to process + std::sort(m_MinimaList.begin(), m_MinimaList.end(), LocMinSorter()); + + m_Scanbeam = ScanbeamList(); //clears/resets priority_queue + //reset all edges ... 
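+  //(each local minimum seeds a scanbeam at its Y, and both of its bounds are
+  //re-primed - Curr = Bot, sides reset, OutIdx = Unassigned - so Execute() can
+  //be called repeatedly without re-adding the paths)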
+ for (MinimaList::iterator lm = m_MinimaList.begin(); lm != m_MinimaList.end(); ++lm) + { + InsertScanbeam(lm->Y); + TEdge* e = lm->LeftBound; + if (e) + { + e->Curr = e->Bot; + e->Side = esLeft; + e->OutIdx = Unassigned; + } + + e = lm->RightBound; + if (e) + { + e->Curr = e->Bot; + e->Side = esRight; + e->OutIdx = Unassigned; + } + } + m_ActiveEdges = 0; + m_CurrentLM = m_MinimaList.begin(); +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeLocalMinimaList() +{ + m_MinimaList.clear(); + m_CurrentLM = m_MinimaList.begin(); +} +//------------------------------------------------------------------------------ + +bool ClipperBase::PopLocalMinima(cInt Y, const LocalMinimum *&locMin) +{ + if (m_CurrentLM == m_MinimaList.end() || (*m_CurrentLM).Y != Y) return false; + locMin = &(*m_CurrentLM); + ++m_CurrentLM; + return true; +} +//------------------------------------------------------------------------------ + +IntRect ClipperBase::GetBounds() +{ + IntRect result; + MinimaList::iterator lm = m_MinimaList.begin(); + if (lm == m_MinimaList.end()) + { + result.left = result.top = result.right = result.bottom = 0; + return result; + } + result.left = lm->LeftBound->Bot.X; + result.top = lm->LeftBound->Bot.Y; + result.right = lm->LeftBound->Bot.X; + result.bottom = lm->LeftBound->Bot.Y; + while (lm != m_MinimaList.end()) + { + //todo - needs fixing for open paths + result.bottom = std::max(result.bottom, lm->LeftBound->Bot.Y); + TEdge* e = lm->LeftBound; + for (;;) { + TEdge* bottomE = e; + while (e->NextInLML) + { + if (e->Bot.X < result.left) result.left = e->Bot.X; + if (e->Bot.X > result.right) result.right = e->Bot.X; + e = e->NextInLML; + } + result.left = std::min(result.left, e->Bot.X); + result.right = std::max(result.right, e->Bot.X); + result.left = std::min(result.left, e->Top.X); + result.right = std::max(result.right, e->Top.X); + result.top = std::min(result.top, e->Top.Y); + if (bottomE == lm->LeftBound) e = lm->RightBound; + else break; + } + ++lm; + } + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::InsertScanbeam(const cInt Y) +{ + m_Scanbeam.push(Y); +} +//------------------------------------------------------------------------------ + +bool ClipperBase::PopScanbeam(cInt &Y) +{ + if (m_Scanbeam.empty()) return false; + Y = m_Scanbeam.top(); + m_Scanbeam.pop(); + while (!m_Scanbeam.empty() && Y == m_Scanbeam.top()) { m_Scanbeam.pop(); } // Pop duplicates. 
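+  //(several bounds can start or end at the same Y, so the same scanbeam value
+  //may have been queued more than once)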
+ return true; +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeAllOutRecs(){ + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + DisposeOutRec(i); + m_PolyOuts.clear(); +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeOutRec(PolyOutList::size_type index) +{ + OutRec *outRec = m_PolyOuts[index]; + if (outRec->Pts) DisposeOutPts(outRec->Pts); + delete outRec; + m_PolyOuts[index] = 0; +} +//------------------------------------------------------------------------------ + +void ClipperBase::DeleteFromAEL(TEdge *e) +{ + TEdge* AelPrev = e->PrevInAEL; + TEdge* AelNext = e->NextInAEL; + if (!AelPrev && !AelNext && (e != m_ActiveEdges)) return; //already deleted + if (AelPrev) AelPrev->NextInAEL = AelNext; + else m_ActiveEdges = AelNext; + if (AelNext) AelNext->PrevInAEL = AelPrev; + e->NextInAEL = 0; + e->PrevInAEL = 0; +} +//------------------------------------------------------------------------------ + +OutRec* ClipperBase::CreateOutRec() +{ + OutRec* result = new OutRec; + result->IsHole = false; + result->IsOpen = false; + result->FirstLeft = 0; + result->Pts = 0; + result->BottomPt = 0; + result->PolyNd = 0; + m_PolyOuts.push_back(result); + result->Idx = (int)m_PolyOuts.size() - 1; + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::SwapPositionsInAEL(TEdge *Edge1, TEdge *Edge2) +{ + //check that one or other edge hasn't already been removed from AEL ... + if (Edge1->NextInAEL == Edge1->PrevInAEL || + Edge2->NextInAEL == Edge2->PrevInAEL) return; + + if (Edge1->NextInAEL == Edge2) + { + TEdge* Next = Edge2->NextInAEL; + if (Next) Next->PrevInAEL = Edge1; + TEdge* Prev = Edge1->PrevInAEL; + if (Prev) Prev->NextInAEL = Edge2; + Edge2->PrevInAEL = Prev; + Edge2->NextInAEL = Edge1; + Edge1->PrevInAEL = Edge2; + Edge1->NextInAEL = Next; + } + else if (Edge2->NextInAEL == Edge1) + { + TEdge* Next = Edge1->NextInAEL; + if (Next) Next->PrevInAEL = Edge2; + TEdge* Prev = Edge2->PrevInAEL; + if (Prev) Prev->NextInAEL = Edge1; + Edge1->PrevInAEL = Prev; + Edge1->NextInAEL = Edge2; + Edge2->PrevInAEL = Edge1; + Edge2->NextInAEL = Next; + } + else + { + TEdge* Next = Edge1->NextInAEL; + TEdge* Prev = Edge1->PrevInAEL; + Edge1->NextInAEL = Edge2->NextInAEL; + if (Edge1->NextInAEL) Edge1->NextInAEL->PrevInAEL = Edge1; + Edge1->PrevInAEL = Edge2->PrevInAEL; + if (Edge1->PrevInAEL) Edge1->PrevInAEL->NextInAEL = Edge1; + Edge2->NextInAEL = Next; + if (Edge2->NextInAEL) Edge2->NextInAEL->PrevInAEL = Edge2; + Edge2->PrevInAEL = Prev; + if (Edge2->PrevInAEL) Edge2->PrevInAEL->NextInAEL = Edge2; + } + + if (!Edge1->PrevInAEL) m_ActiveEdges = Edge1; + else if (!Edge2->PrevInAEL) m_ActiveEdges = Edge2; +} +//------------------------------------------------------------------------------ + +void ClipperBase::UpdateEdgeIntoAEL(TEdge *&e) +{ + if (!e->NextInLML) + throw clipperException("UpdateEdgeIntoAEL: invalid call"); + + e->NextInLML->OutIdx = e->OutIdx; + TEdge* AelPrev = e->PrevInAEL; + TEdge* AelNext = e->NextInAEL; + if (AelPrev) AelPrev->NextInAEL = e->NextInLML; + else m_ActiveEdges = e->NextInLML; + if (AelNext) AelNext->PrevInAEL = e->NextInLML; + e->NextInLML->Side = e->Side; + e->NextInLML->WindDelta = e->WindDelta; + e->NextInLML->WindCnt = e->WindCnt; + e->NextInLML->WindCnt2 = e->WindCnt2; + e = e->NextInLML; + e->Curr = e->Bot; + e->PrevInAEL = AelPrev; + e->NextInAEL = AelNext; + if 
(!IsHorizontal(*e)) InsertScanbeam(e->Top.Y); +} +//------------------------------------------------------------------------------ + +bool ClipperBase::LocalMinimaPending() +{ + return (m_CurrentLM != m_MinimaList.end()); +} + +//------------------------------------------------------------------------------ +// TClipper methods ... +//------------------------------------------------------------------------------ + +Clipper::Clipper(int initOptions) : ClipperBase() //constructor +{ + m_ExecuteLocked = false; + m_UseFullRange = false; + m_ReverseOutput = ((initOptions & ioReverseSolution) != 0); + m_StrictSimple = ((initOptions & ioStrictlySimple) != 0); + m_PreserveCollinear = ((initOptions & ioPreserveCollinear) != 0); + m_HasOpenPaths = false; +#ifdef use_xyz + m_ZFill = 0; +#endif +} +//------------------------------------------------------------------------------ + +#ifdef use_xyz +void Clipper::ZFillFunction(ZFillCallback zFillFunc) +{ + m_ZFill = zFillFunc; +} +//------------------------------------------------------------------------------ +#endif + +bool Clipper::Execute(ClipType clipType, Paths &solution, PolyFillType fillType) +{ + return Execute(clipType, solution, fillType, fillType); +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, PolyTree &polytree, PolyFillType fillType) +{ + return Execute(clipType, polytree, fillType, fillType); +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, Paths &solution, + PolyFillType subjFillType, PolyFillType clipFillType) +{ + if( m_ExecuteLocked ) return false; + if (m_HasOpenPaths) + throw clipperException("Error: PolyTree struct is needed for open path clipping."); + m_ExecuteLocked = true; + solution.resize(0); + m_SubjFillType = subjFillType; + m_ClipFillType = clipFillType; + m_ClipType = clipType; + m_UsingPolyTree = false; + bool succeeded = ExecuteInternal(); + if (succeeded) BuildResult(solution); + DisposeAllOutRecs(); + m_ExecuteLocked = false; + return succeeded; +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, PolyTree& polytree, + PolyFillType subjFillType, PolyFillType clipFillType) +{ + if( m_ExecuteLocked ) return false; + m_ExecuteLocked = true; + m_SubjFillType = subjFillType; + m_ClipFillType = clipFillType; + m_ClipType = clipType; + m_UsingPolyTree = true; + bool succeeded = ExecuteInternal(); + if (succeeded) BuildResult2(polytree); + DisposeAllOutRecs(); + m_ExecuteLocked = false; + return succeeded; +} +//------------------------------------------------------------------------------ + +void Clipper::FixHoleLinkage(OutRec &outrec) +{ + //skip OutRecs that (a) contain outermost polygons or + //(b) already have the correct owner/child linkage ... 
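+  //otherwise walk up the FirstLeft chain until an OutRec with the opposite
+  //hole state (and with valid points) is found and adopt it as the owner ...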
+ if (!outrec.FirstLeft || + (outrec.IsHole != outrec.FirstLeft->IsHole && + outrec.FirstLeft->Pts)) return; + + OutRec* orfl = outrec.FirstLeft; + while (orfl && ((orfl->IsHole == outrec.IsHole) || !orfl->Pts)) + orfl = orfl->FirstLeft; + outrec.FirstLeft = orfl; +} +//------------------------------------------------------------------------------ + +bool Clipper::ExecuteInternal() +{ + bool succeeded = true; + try { + Reset(); + m_Maxima = MaximaList(); + m_SortedEdges = 0; + + succeeded = true; + cInt botY, topY; + if (!PopScanbeam(botY)) return false; + InsertLocalMinimaIntoAEL(botY); + while (PopScanbeam(topY) || LocalMinimaPending()) + { + ProcessHorizontals(); + ClearGhostJoins(); + if (!ProcessIntersections(topY)) + { + succeeded = false; + break; + } + ProcessEdgesAtTopOfScanbeam(topY); + botY = topY; + InsertLocalMinimaIntoAEL(botY); + } + } + catch(...) + { + succeeded = false; + } + + if (succeeded) + { + //fix orientations ... + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + OutRec *outRec = m_PolyOuts[i]; + if (!outRec->Pts || outRec->IsOpen) continue; + if ((outRec->IsHole ^ m_ReverseOutput) == (Area(*outRec) > 0)) + ReversePolyPtLinks(outRec->Pts); + } + + if (!m_Joins.empty()) JoinCommonEdges(); + + //unfortunately FixupOutPolygon() must be done after JoinCommonEdges() + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + OutRec *outRec = m_PolyOuts[i]; + if (!outRec->Pts) continue; + if (outRec->IsOpen) + FixupOutPolyline(*outRec); + else + FixupOutPolygon(*outRec); + } + + if (m_StrictSimple) DoSimplePolygons(); + } + + ClearJoins(); + ClearGhostJoins(); + return succeeded; +} +//------------------------------------------------------------------------------ + +void Clipper::SetWindingCount(TEdge &edge) +{ + TEdge *e = edge.PrevInAEL; + //find the edge of the same polytype that immediately preceeds 'edge' in AEL + while (e && ((e->PolyTyp != edge.PolyTyp) || (e->WindDelta == 0))) e = e->PrevInAEL; + if (!e) + { + if (edge.WindDelta == 0) + { + PolyFillType pft = (edge.PolyTyp == ptSubject ? m_SubjFillType : m_ClipFillType); + edge.WindCnt = (pft == pftNegative ? -1 : 1); + } + else + edge.WindCnt = edge.WindDelta; + edge.WindCnt2 = 0; + e = m_ActiveEdges; //ie get ready to calc WindCnt2 + } + else if (edge.WindDelta == 0 && m_ClipType != ctUnion) + { + edge.WindCnt = 1; + edge.WindCnt2 = e->WindCnt2; + e = e->NextInAEL; //ie get ready to calc WindCnt2 + } + else if (IsEvenOddFillType(edge)) + { + //EvenOdd filling ... + if (edge.WindDelta == 0) + { + //are we inside a subj polygon ... + bool Inside = true; + TEdge *e2 = e->PrevInAEL; + while (e2) + { + if (e2->PolyTyp == e->PolyTyp && e2->WindDelta != 0) + Inside = !Inside; + e2 = e2->PrevInAEL; + } + edge.WindCnt = (Inside ? 0 : 1); + } + else + { + edge.WindCnt = edge.WindDelta; + } + edge.WindCnt2 = e->WindCnt2; + e = e->NextInAEL; //ie get ready to calc WindCnt2 + } + else + { + //nonZero, Positive or Negative filling ... + if (e->WindCnt * e->WindDelta < 0) + { + //prev edge is 'decreasing' WindCount (WC) toward zero + //so we're outside the previous polygon ... + if (Abs(e->WindCnt) > 1) + { + //outside prev poly but still inside another. + //when reversing direction of prev poly use the same WC + if (e->WindDelta * edge.WindDelta < 0) edge.WindCnt = e->WindCnt; + //otherwise continue to 'decrease' WC ... + else edge.WindCnt = e->WindCnt + edge.WindDelta; + } + else + //now outside all polys of same polytype so set own WC ... + edge.WindCnt = (edge.WindDelta == 0 ? 
1 : edge.WindDelta); + } else + { + //prev edge is 'increasing' WindCount (WC) away from zero + //so we're inside the previous polygon ... + if (edge.WindDelta == 0) + edge.WindCnt = (e->WindCnt < 0 ? e->WindCnt - 1 : e->WindCnt + 1); + //if wind direction is reversing prev then use same WC + else if (e->WindDelta * edge.WindDelta < 0) edge.WindCnt = e->WindCnt; + //otherwise add to WC ... + else edge.WindCnt = e->WindCnt + edge.WindDelta; + } + edge.WindCnt2 = e->WindCnt2; + e = e->NextInAEL; //ie get ready to calc WindCnt2 + } + + //update WindCnt2 ... + if (IsEvenOddAltFillType(edge)) + { + //EvenOdd filling ... + while (e != &edge) + { + if (e->WindDelta != 0) + edge.WindCnt2 = (edge.WindCnt2 == 0 ? 1 : 0); + e = e->NextInAEL; + } + } else + { + //nonZero, Positive or Negative filling ... + while ( e != &edge ) + { + edge.WindCnt2 += e->WindDelta; + e = e->NextInAEL; + } + } +} +//------------------------------------------------------------------------------ + +bool Clipper::IsEvenOddFillType(const TEdge& edge) const +{ + if (edge.PolyTyp == ptSubject) + return m_SubjFillType == pftEvenOdd; else + return m_ClipFillType == pftEvenOdd; +} +//------------------------------------------------------------------------------ + +bool Clipper::IsEvenOddAltFillType(const TEdge& edge) const +{ + if (edge.PolyTyp == ptSubject) + return m_ClipFillType == pftEvenOdd; else + return m_SubjFillType == pftEvenOdd; +} +//------------------------------------------------------------------------------ + +bool Clipper::IsContributing(const TEdge& edge) const +{ + PolyFillType pft, pft2; + if (edge.PolyTyp == ptSubject) + { + pft = m_SubjFillType; + pft2 = m_ClipFillType; + } else + { + pft = m_ClipFillType; + pft2 = m_SubjFillType; + } + + switch(pft) + { + case pftEvenOdd: + //return false if a subj line has been flagged as inside a subj polygon + if (edge.WindDelta == 0 && edge.WindCnt != 1) return false; + break; + case pftNonZero: + if (Abs(edge.WindCnt) != 1) return false; + break; + case pftPositive: + if (edge.WindCnt != 1) return false; + break; + default: //pftNegative + if (edge.WindCnt != -1) return false; + } + + switch(m_ClipType) + { + case ctIntersection: + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 != 0); + case pftPositive: + return (edge.WindCnt2 > 0); + default: + return (edge.WindCnt2 < 0); + } + break; + case ctUnion: + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + break; + case ctDifference: + if (edge.PolyTyp == ptSubject) + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + else + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 != 0); + case pftPositive: + return (edge.WindCnt2 > 0); + default: + return (edge.WindCnt2 < 0); + } + break; + case ctXor: + if (edge.WindDelta == 0) //XOr always contributing unless open + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + else + return true; + break; + default: + return true; + } +} +//------------------------------------------------------------------------------ + +OutPt* Clipper::AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt) +{ + OutPt* result; + TEdge *e, 
*prevE; + if (IsHorizontal(*e2) || ( e1->Dx > e2->Dx )) + { + result = AddOutPt(e1, Pt); + e2->OutIdx = e1->OutIdx; + e1->Side = esLeft; + e2->Side = esRight; + e = e1; + if (e->PrevInAEL == e2) + prevE = e2->PrevInAEL; + else + prevE = e->PrevInAEL; + } else + { + result = AddOutPt(e2, Pt); + e1->OutIdx = e2->OutIdx; + e1->Side = esRight; + e2->Side = esLeft; + e = e2; + if (e->PrevInAEL == e1) + prevE = e1->PrevInAEL; + else + prevE = e->PrevInAEL; + } + + if (prevE && prevE->OutIdx >= 0) + { + cInt xPrev = TopX(*prevE, Pt.Y); + cInt xE = TopX(*e, Pt.Y); + if (xPrev == xE && (e->WindDelta != 0) && (prevE->WindDelta != 0) && + SlopesEqual(IntPoint(xPrev, Pt.Y), prevE->Top, IntPoint(xE, Pt.Y), e->Top, m_UseFullRange)) + { + OutPt* outPt = AddOutPt(prevE, Pt); + AddJoin(result, outPt, e->Top); + } + } + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt) +{ + AddOutPt( e1, Pt ); + if (e2->WindDelta == 0) AddOutPt(e2, Pt); + if( e1->OutIdx == e2->OutIdx ) + { + e1->OutIdx = Unassigned; + e2->OutIdx = Unassigned; + } + else if (e1->OutIdx < e2->OutIdx) + AppendPolygon(e1, e2); + else + AppendPolygon(e2, e1); +} +//------------------------------------------------------------------------------ + +void Clipper::AddEdgeToSEL(TEdge *edge) +{ + //SEL pointers in PEdge are reused to build a list of horizontal edges. + //However, we don't need to worry about order with horizontal edge processing. + if( !m_SortedEdges ) + { + m_SortedEdges = edge; + edge->PrevInSEL = 0; + edge->NextInSEL = 0; + } + else + { + edge->NextInSEL = m_SortedEdges; + edge->PrevInSEL = 0; + m_SortedEdges->PrevInSEL = edge; + m_SortedEdges = edge; + } +} +//------------------------------------------------------------------------------ + +bool Clipper::PopEdgeFromSEL(TEdge *&edge) +{ + if (!m_SortedEdges) return false; + edge = m_SortedEdges; + DeleteFromSEL(m_SortedEdges); + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::CopyAELToSEL() +{ + TEdge* e = m_ActiveEdges; + m_SortedEdges = e; + while ( e ) + { + e->PrevInSEL = e->PrevInAEL; + e->NextInSEL = e->NextInAEL; + e = e->NextInAEL; + } +} +//------------------------------------------------------------------------------ + +void Clipper::AddJoin(OutPt *op1, OutPt *op2, const IntPoint OffPt) +{ + Join* j = new Join; + j->OutPt1 = op1; + j->OutPt2 = op2; + j->OffPt = OffPt; + m_Joins.push_back(j); +} +//------------------------------------------------------------------------------ + +void Clipper::ClearJoins() +{ + for (JoinList::size_type i = 0; i < m_Joins.size(); i++) + delete m_Joins[i]; + m_Joins.resize(0); +} +//------------------------------------------------------------------------------ + +void Clipper::ClearGhostJoins() +{ + for (JoinList::size_type i = 0; i < m_GhostJoins.size(); i++) + delete m_GhostJoins[i]; + m_GhostJoins.resize(0); +} +//------------------------------------------------------------------------------ + +void Clipper::AddGhostJoin(OutPt *op, const IntPoint OffPt) +{ + Join* j = new Join; + j->OutPt1 = op; + j->OutPt2 = 0; + j->OffPt = OffPt; + m_GhostJoins.push_back(j); +} +//------------------------------------------------------------------------------ + +void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) +{ + const LocalMinimum *lm; + while (PopLocalMinima(botY, lm)) + { + TEdge* lb = lm->LeftBound; + TEdge* rb = lm->RightBound; + + OutPt *Op1 = 0; + if 
(!lb) + { + //nb: don't insert LB into either AEL or SEL + InsertEdgeIntoAEL(rb, 0); + SetWindingCount(*rb); + if (IsContributing(*rb)) + Op1 = AddOutPt(rb, rb->Bot); + } + else if (!rb) + { + InsertEdgeIntoAEL(lb, 0); + SetWindingCount(*lb); + if (IsContributing(*lb)) + Op1 = AddOutPt(lb, lb->Bot); + InsertScanbeam(lb->Top.Y); + } + else + { + InsertEdgeIntoAEL(lb, 0); + InsertEdgeIntoAEL(rb, lb); + SetWindingCount( *lb ); + rb->WindCnt = lb->WindCnt; + rb->WindCnt2 = lb->WindCnt2; + if (IsContributing(*lb)) + Op1 = AddLocalMinPoly(lb, rb, lb->Bot); + InsertScanbeam(lb->Top.Y); + } + + if (rb) + { + if (IsHorizontal(*rb)) + { + AddEdgeToSEL(rb); + if (rb->NextInLML) + InsertScanbeam(rb->NextInLML->Top.Y); + } + else InsertScanbeam( rb->Top.Y ); + } + + if (!lb || !rb) continue; + + //if any output polygons share an edge, they'll need joining later ... + if (Op1 && IsHorizontal(*rb) && + m_GhostJoins.size() > 0 && (rb->WindDelta != 0)) + { + for (JoinList::size_type i = 0; i < m_GhostJoins.size(); ++i) + { + Join* jr = m_GhostJoins[i]; + //if the horizontal Rb and a 'ghost' horizontal overlap, then convert + //the 'ghost' join to a real join ready for later ... + if (HorzSegmentsOverlap(jr->OutPt1->Pt.X, jr->OffPt.X, rb->Bot.X, rb->Top.X)) + AddJoin(jr->OutPt1, Op1, jr->OffPt); + } + } + + if (lb->OutIdx >= 0 && lb->PrevInAEL && + lb->PrevInAEL->Curr.X == lb->Bot.X && + lb->PrevInAEL->OutIdx >= 0 && + SlopesEqual(lb->PrevInAEL->Bot, lb->PrevInAEL->Top, lb->Curr, lb->Top, m_UseFullRange) && + (lb->WindDelta != 0) && (lb->PrevInAEL->WindDelta != 0)) + { + OutPt *Op2 = AddOutPt(lb->PrevInAEL, lb->Bot); + AddJoin(Op1, Op2, lb->Top); + } + + if(lb->NextInAEL != rb) + { + + if (rb->OutIdx >= 0 && rb->PrevInAEL->OutIdx >= 0 && + SlopesEqual(rb->PrevInAEL->Curr, rb->PrevInAEL->Top, rb->Curr, rb->Top, m_UseFullRange) && + (rb->WindDelta != 0) && (rb->PrevInAEL->WindDelta != 0)) + { + OutPt *Op2 = AddOutPt(rb->PrevInAEL, rb->Bot); + AddJoin(Op1, Op2, rb->Top); + } + + TEdge* e = lb->NextInAEL; + if (e) + { + while( e != rb ) + { + //nb: For calculating winding counts etc, IntersectEdges() assumes + //that param1 will be to the Right of param2 ABOVE the intersection ... + IntersectEdges(rb , e , lb->Curr); //order important here + e = e->NextInAEL; + } + } + } + + } +} +//------------------------------------------------------------------------------ + +void Clipper::DeleteFromSEL(TEdge *e) +{ + TEdge* SelPrev = e->PrevInSEL; + TEdge* SelNext = e->NextInSEL; + if( !SelPrev && !SelNext && (e != m_SortedEdges) ) return; //already deleted + if( SelPrev ) SelPrev->NextInSEL = SelNext; + else m_SortedEdges = SelNext; + if( SelNext ) SelNext->PrevInSEL = SelPrev; + e->NextInSEL = 0; + e->PrevInSEL = 0; +} +//------------------------------------------------------------------------------ + +#ifdef use_xyz +void Clipper::SetZ(IntPoint& pt, TEdge& e1, TEdge& e2) +{ + if (pt.Z != 0 || !m_ZFill) return; + else if (pt == e1.Bot) pt.Z = e1.Bot.Z; + else if (pt == e1.Top) pt.Z = e1.Top.Z; + else if (pt == e2.Bot) pt.Z = e2.Bot.Z; + else if (pt == e2.Top) pt.Z = e2.Top.Z; + else (*m_ZFill)(e1.Bot, e1.Top, e2.Bot, e2.Top, pt); +} +//------------------------------------------------------------------------------ +#endif + +void Clipper::IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &Pt) +{ + bool e1Contributing = ( e1->OutIdx >= 0 ); + bool e2Contributing = ( e2->OutIdx >= 0 ); + +#ifdef use_xyz + SetZ(Pt, *e1, *e2); +#endif + +#ifdef use_lines + //if either edge is on an OPEN path ... 
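+  //(open path edges have WindDelta == 0; they never alter winding counts and
+  //only emit output points, so they're handled here and we return early)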
+ if (e1->WindDelta == 0 || e2->WindDelta == 0) + { + //ignore subject-subject open path intersections UNLESS they + //are both open paths, AND they are both 'contributing maximas' ... + if (e1->WindDelta == 0 && e2->WindDelta == 0) return; + + //if intersecting a subj line with a subj poly ... + else if (e1->PolyTyp == e2->PolyTyp && + e1->WindDelta != e2->WindDelta && m_ClipType == ctUnion) + { + if (e1->WindDelta == 0) + { + if (e2Contributing) + { + AddOutPt(e1, Pt); + if (e1Contributing) e1->OutIdx = Unassigned; + } + } + else + { + if (e1Contributing) + { + AddOutPt(e2, Pt); + if (e2Contributing) e2->OutIdx = Unassigned; + } + } + } + else if (e1->PolyTyp != e2->PolyTyp) + { + //toggle subj open path OutIdx on/off when Abs(clip.WndCnt) == 1 ... + if ((e1->WindDelta == 0) && abs(e2->WindCnt) == 1 && + (m_ClipType != ctUnion || e2->WindCnt2 == 0)) + { + AddOutPt(e1, Pt); + if (e1Contributing) e1->OutIdx = Unassigned; + } + else if ((e2->WindDelta == 0) && (abs(e1->WindCnt) == 1) && + (m_ClipType != ctUnion || e1->WindCnt2 == 0)) + { + AddOutPt(e2, Pt); + if (e2Contributing) e2->OutIdx = Unassigned; + } + } + return; + } +#endif + + //update winding counts... + //assumes that e1 will be to the Right of e2 ABOVE the intersection + if ( e1->PolyTyp == e2->PolyTyp ) + { + if ( IsEvenOddFillType( *e1) ) + { + int oldE1WindCnt = e1->WindCnt; + e1->WindCnt = e2->WindCnt; + e2->WindCnt = oldE1WindCnt; + } else + { + if (e1->WindCnt + e2->WindDelta == 0 ) e1->WindCnt = -e1->WindCnt; + else e1->WindCnt += e2->WindDelta; + if ( e2->WindCnt - e1->WindDelta == 0 ) e2->WindCnt = -e2->WindCnt; + else e2->WindCnt -= e1->WindDelta; + } + } else + { + if (!IsEvenOddFillType(*e2)) e1->WindCnt2 += e2->WindDelta; + else e1->WindCnt2 = ( e1->WindCnt2 == 0 ) ? 1 : 0; + if (!IsEvenOddFillType(*e1)) e2->WindCnt2 -= e1->WindDelta; + else e2->WindCnt2 = ( e2->WindCnt2 == 0 ) ? 1 : 0; + } + + PolyFillType e1FillType, e2FillType, e1FillType2, e2FillType2; + if (e1->PolyTyp == ptSubject) + { + e1FillType = m_SubjFillType; + e1FillType2 = m_ClipFillType; + } else + { + e1FillType = m_ClipFillType; + e1FillType2 = m_SubjFillType; + } + if (e2->PolyTyp == ptSubject) + { + e2FillType = m_SubjFillType; + e2FillType2 = m_ClipFillType; + } else + { + e2FillType = m_ClipFillType; + e2FillType2 = m_SubjFillType; + } + + cInt e1Wc, e2Wc; + switch (e1FillType) + { + case pftPositive: e1Wc = e1->WindCnt; break; + case pftNegative: e1Wc = -e1->WindCnt; break; + default: e1Wc = Abs(e1->WindCnt); + } + switch(e2FillType) + { + case pftPositive: e2Wc = e2->WindCnt; break; + case pftNegative: e2Wc = -e2->WindCnt; break; + default: e2Wc = Abs(e2->WindCnt); + } + + if ( e1Contributing && e2Contributing ) + { + if ((e1Wc != 0 && e1Wc != 1) || (e2Wc != 0 && e2Wc != 1) || + (e1->PolyTyp != e2->PolyTyp && m_ClipType != ctXor) ) + { + AddLocalMaxPoly(e1, e2, Pt); + } + else + { + AddOutPt(e1, Pt); + AddOutPt(e2, Pt); + SwapSides( *e1 , *e2 ); + SwapPolyIndexes( *e1 , *e2 ); + } + } + else if ( e1Contributing ) + { + if (e2Wc == 0 || e2Wc == 1) + { + AddOutPt(e1, Pt); + SwapSides(*e1, *e2); + SwapPolyIndexes(*e1, *e2); + } + } + else if ( e2Contributing ) + { + if (e1Wc == 0 || e1Wc == 1) + { + AddOutPt(e2, Pt); + SwapSides(*e1, *e2); + SwapPolyIndexes(*e1, *e2); + } + } + else if ( (e1Wc == 0 || e1Wc == 1) && (e2Wc == 0 || e2Wc == 1)) + { + //neither edge is currently contributing ... 
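+    //so this may be a local minimum of the output; the winding count of the
+    //*other* polygon type (WindCnt2), filtered through the clip operation
+    //below, decides whether a new output polygon starts here ...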
+ + cInt e1Wc2, e2Wc2; + switch (e1FillType2) + { + case pftPositive: e1Wc2 = e1->WindCnt2; break; + case pftNegative : e1Wc2 = -e1->WindCnt2; break; + default: e1Wc2 = Abs(e1->WindCnt2); + } + switch (e2FillType2) + { + case pftPositive: e2Wc2 = e2->WindCnt2; break; + case pftNegative: e2Wc2 = -e2->WindCnt2; break; + default: e2Wc2 = Abs(e2->WindCnt2); + } + + if (e1->PolyTyp != e2->PolyTyp) + { + AddLocalMinPoly(e1, e2, Pt); + } + else if (e1Wc == 1 && e2Wc == 1) + switch( m_ClipType ) { + case ctIntersection: + if (e1Wc2 > 0 && e2Wc2 > 0) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctUnion: + if ( e1Wc2 <= 0 && e2Wc2 <= 0 ) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctDifference: + if (((e1->PolyTyp == ptClip) && (e1Wc2 > 0) && (e2Wc2 > 0)) || + ((e1->PolyTyp == ptSubject) && (e1Wc2 <= 0) && (e2Wc2 <= 0))) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctXor: + AddLocalMinPoly(e1, e2, Pt); + } + else + SwapSides( *e1, *e2 ); + } +} +//------------------------------------------------------------------------------ + +void Clipper::SetHoleState(TEdge *e, OutRec *outrec) +{ + TEdge *e2 = e->PrevInAEL; + TEdge *eTmp = 0; + while (e2) + { + if (e2->OutIdx >= 0 && e2->WindDelta != 0) + { + if (!eTmp) eTmp = e2; + else if (eTmp->OutIdx == e2->OutIdx) eTmp = 0; + } + e2 = e2->PrevInAEL; + } + if (!eTmp) + { + outrec->FirstLeft = 0; + outrec->IsHole = false; + } + else + { + outrec->FirstLeft = m_PolyOuts[eTmp->OutIdx]; + outrec->IsHole = !outrec->FirstLeft->IsHole; + } +} +//------------------------------------------------------------------------------ + +OutRec* GetLowermostRec(OutRec *outRec1, OutRec *outRec2) +{ + //work out which polygon fragment has the correct hole state ... + if (!outRec1->BottomPt) + outRec1->BottomPt = GetBottomPt(outRec1->Pts); + if (!outRec2->BottomPt) + outRec2->BottomPt = GetBottomPt(outRec2->Pts); + OutPt *OutPt1 = outRec1->BottomPt; + OutPt *OutPt2 = outRec2->BottomPt; + if (OutPt1->Pt.Y > OutPt2->Pt.Y) return outRec1; + else if (OutPt1->Pt.Y < OutPt2->Pt.Y) return outRec2; + else if (OutPt1->Pt.X < OutPt2->Pt.X) return outRec1; + else if (OutPt1->Pt.X > OutPt2->Pt.X) return outRec2; + else if (OutPt1->Next == OutPt1) return outRec2; + else if (OutPt2->Next == OutPt2) return outRec1; + else if (FirstIsBottomPt(OutPt1, OutPt2)) return outRec1; + else return outRec2; +} +//------------------------------------------------------------------------------ + +bool OutRec1RightOfOutRec2(OutRec* outRec1, OutRec* outRec2) +{ + do + { + outRec1 = outRec1->FirstLeft; + if (outRec1 == outRec2) return true; + } while (outRec1); + return false; +} +//------------------------------------------------------------------------------ + +OutRec* Clipper::GetOutRec(int Idx) +{ + OutRec* outrec = m_PolyOuts[Idx]; + while (outrec != m_PolyOuts[outrec->Idx]) + outrec = m_PolyOuts[outrec->Idx]; + return outrec; +} +//------------------------------------------------------------------------------ + +void Clipper::AppendPolygon(TEdge *e1, TEdge *e2) +{ + //get the start and ends of both output polygons ... + OutRec *outRec1 = m_PolyOuts[e1->OutIdx]; + OutRec *outRec2 = m_PolyOuts[e2->OutIdx]; + + OutRec *holeStateRec; + if (OutRec1RightOfOutRec2(outRec1, outRec2)) + holeStateRec = outRec2; + else if (OutRec1RightOfOutRec2(outRec2, outRec1)) + holeStateRec = outRec1; + else + holeStateRec = GetLowermostRec(outRec1, outRec2); + + //get the start and ends of both output polygons and + //join e2 poly onto e1 poly and delete pointers to e2 ... 
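+  //(Pts is the 'left-most' OutPt of each ring and Pts->Prev the 'right-most';
+  //which ends get spliced below depends on the Side each edge has been building)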
+ + OutPt* p1_lft = outRec1->Pts; + OutPt* p1_rt = p1_lft->Prev; + OutPt* p2_lft = outRec2->Pts; + OutPt* p2_rt = p2_lft->Prev; + + //join e2 poly onto e1 poly and delete pointers to e2 ... + if( e1->Side == esLeft ) + { + if( e2->Side == esLeft ) + { + //z y x a b c + ReversePolyPtLinks(p2_lft); + p2_lft->Next = p1_lft; + p1_lft->Prev = p2_lft; + p1_rt->Next = p2_rt; + p2_rt->Prev = p1_rt; + outRec1->Pts = p2_rt; + } else + { + //x y z a b c + p2_rt->Next = p1_lft; + p1_lft->Prev = p2_rt; + p2_lft->Prev = p1_rt; + p1_rt->Next = p2_lft; + outRec1->Pts = p2_lft; + } + } else + { + if( e2->Side == esRight ) + { + //a b c z y x + ReversePolyPtLinks(p2_lft); + p1_rt->Next = p2_rt; + p2_rt->Prev = p1_rt; + p2_lft->Next = p1_lft; + p1_lft->Prev = p2_lft; + } else + { + //a b c x y z + p1_rt->Next = p2_lft; + p2_lft->Prev = p1_rt; + p1_lft->Prev = p2_rt; + p2_rt->Next = p1_lft; + } + } + + outRec1->BottomPt = 0; + if (holeStateRec == outRec2) + { + if (outRec2->FirstLeft != outRec1) + outRec1->FirstLeft = outRec2->FirstLeft; + outRec1->IsHole = outRec2->IsHole; + } + outRec2->Pts = 0; + outRec2->BottomPt = 0; + outRec2->FirstLeft = outRec1; + + int OKIdx = e1->OutIdx; + int ObsoleteIdx = e2->OutIdx; + + e1->OutIdx = Unassigned; //nb: safe because we only get here via AddLocalMaxPoly + e2->OutIdx = Unassigned; + + TEdge* e = m_ActiveEdges; + while( e ) + { + if( e->OutIdx == ObsoleteIdx ) + { + e->OutIdx = OKIdx; + e->Side = e1->Side; + break; + } + e = e->NextInAEL; + } + + outRec2->Idx = outRec1->Idx; +} +//------------------------------------------------------------------------------ + +OutPt* Clipper::AddOutPt(TEdge *e, const IntPoint &pt) +{ + if( e->OutIdx < 0 ) + { + OutRec *outRec = CreateOutRec(); + outRec->IsOpen = (e->WindDelta == 0); + OutPt* newOp = new OutPt; + outRec->Pts = newOp; + newOp->Idx = outRec->Idx; + newOp->Pt = pt; + newOp->Next = newOp; + newOp->Prev = newOp; + if (!outRec->IsOpen) + SetHoleState(e, outRec); + e->OutIdx = outRec->Idx; + return newOp; + } else + { + OutRec *outRec = m_PolyOuts[e->OutIdx]; + //OutRec.Pts is the 'Left-most' point & OutRec.Pts.Prev is the 'Right-most' + OutPt* op = outRec->Pts; + + bool ToFront = (e->Side == esLeft); + if (ToFront && (pt == op->Pt)) return op; + else if (!ToFront && (pt == op->Prev->Pt)) return op->Prev; + + OutPt* newOp = new OutPt; + newOp->Idx = outRec->Idx; + newOp->Pt = pt; + newOp->Next = op; + newOp->Prev = op->Prev; + newOp->Prev->Next = newOp; + op->Prev = newOp; + if (ToFront) outRec->Pts = newOp; + return newOp; + } +} +//------------------------------------------------------------------------------ + +OutPt* Clipper::GetLastOutPt(TEdge *e) +{ + OutRec *outRec = m_PolyOuts[e->OutIdx]; + if (e->Side == esLeft) + return outRec->Pts; + else + return outRec->Pts->Prev; +} +//------------------------------------------------------------------------------ + +void Clipper::ProcessHorizontals() +{ + TEdge* horzEdge; + while (PopEdgeFromSEL(horzEdge)) + ProcessHorizontal(horzEdge); +} +//------------------------------------------------------------------------------ + +inline bool IsMinima(TEdge *e) +{ + return e && (e->Prev->NextInLML != e) && (e->Next->NextInLML != e); +} +//------------------------------------------------------------------------------ + +inline bool IsMaxima(TEdge *e, const cInt Y) +{ + return e && e->Top.Y == Y && !e->NextInLML; +} +//------------------------------------------------------------------------------ + +inline bool IsIntermediate(TEdge *e, const cInt Y) +{ + return e->Top.Y == Y && 
e->NextInLML; +} +//------------------------------------------------------------------------------ + +TEdge *GetMaximaPair(TEdge *e) +{ + if ((e->Next->Top == e->Top) && !e->Next->NextInLML) + return e->Next; + else if ((e->Prev->Top == e->Top) && !e->Prev->NextInLML) + return e->Prev; + else return 0; +} +//------------------------------------------------------------------------------ + +TEdge *GetMaximaPairEx(TEdge *e) +{ + //as GetMaximaPair() but returns 0 if MaxPair isn't in AEL (unless it's horizontal) + TEdge* result = GetMaximaPair(e); + if (result && (result->OutIdx == Skip || + (result->NextInAEL == result->PrevInAEL && !IsHorizontal(*result)))) return 0; + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::SwapPositionsInSEL(TEdge *Edge1, TEdge *Edge2) +{ + if( !( Edge1->NextInSEL ) && !( Edge1->PrevInSEL ) ) return; + if( !( Edge2->NextInSEL ) && !( Edge2->PrevInSEL ) ) return; + + if( Edge1->NextInSEL == Edge2 ) + { + TEdge* Next = Edge2->NextInSEL; + if( Next ) Next->PrevInSEL = Edge1; + TEdge* Prev = Edge1->PrevInSEL; + if( Prev ) Prev->NextInSEL = Edge2; + Edge2->PrevInSEL = Prev; + Edge2->NextInSEL = Edge1; + Edge1->PrevInSEL = Edge2; + Edge1->NextInSEL = Next; + } + else if( Edge2->NextInSEL == Edge1 ) + { + TEdge* Next = Edge1->NextInSEL; + if( Next ) Next->PrevInSEL = Edge2; + TEdge* Prev = Edge2->PrevInSEL; + if( Prev ) Prev->NextInSEL = Edge1; + Edge1->PrevInSEL = Prev; + Edge1->NextInSEL = Edge2; + Edge2->PrevInSEL = Edge1; + Edge2->NextInSEL = Next; + } + else + { + TEdge* Next = Edge1->NextInSEL; + TEdge* Prev = Edge1->PrevInSEL; + Edge1->NextInSEL = Edge2->NextInSEL; + if( Edge1->NextInSEL ) Edge1->NextInSEL->PrevInSEL = Edge1; + Edge1->PrevInSEL = Edge2->PrevInSEL; + if( Edge1->PrevInSEL ) Edge1->PrevInSEL->NextInSEL = Edge1; + Edge2->NextInSEL = Next; + if( Edge2->NextInSEL ) Edge2->NextInSEL->PrevInSEL = Edge2; + Edge2->PrevInSEL = Prev; + if( Edge2->PrevInSEL ) Edge2->PrevInSEL->NextInSEL = Edge2; + } + + if( !Edge1->PrevInSEL ) m_SortedEdges = Edge1; + else if( !Edge2->PrevInSEL ) m_SortedEdges = Edge2; +} +//------------------------------------------------------------------------------ + +TEdge* GetNextInAEL(TEdge *e, Direction dir) +{ + return dir == dLeftToRight ? e->NextInAEL : e->PrevInAEL; +} +//------------------------------------------------------------------------------ + +void GetHorzDirection(TEdge& HorzEdge, Direction& Dir, cInt& Left, cInt& Right) +{ + if (HorzEdge.Bot.X < HorzEdge.Top.X) + { + Left = HorzEdge.Bot.X; + Right = HorzEdge.Top.X; + Dir = dLeftToRight; + } else + { + Left = HorzEdge.Top.X; + Right = HorzEdge.Bot.X; + Dir = dRightToLeft; + } +} +//------------------------------------------------------------------------ + +/******************************************************************************* +* Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or * +* Bottom of a scanbeam) are processed as if layered. The order in which HEs * +* are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#] * +* (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs), * +* and with other non-horizontal edges [*]. Once these intersections are * +* processed, intermediate HEs then 'promote' the Edge above (NextInLML) into * +* the AEL. These 'promoted' edges may in turn intersect [%] with other HEs. 
* +*******************************************************************************/ + +void Clipper::ProcessHorizontal(TEdge *horzEdge) +{ + Direction dir; + cInt horzLeft, horzRight; + bool IsOpen = (horzEdge->WindDelta == 0); + + GetHorzDirection(*horzEdge, dir, horzLeft, horzRight); + + TEdge* eLastHorz = horzEdge, *eMaxPair = 0; + while (eLastHorz->NextInLML && IsHorizontal(*eLastHorz->NextInLML)) + eLastHorz = eLastHorz->NextInLML; + if (!eLastHorz->NextInLML) + eMaxPair = GetMaximaPair(eLastHorz); + + MaximaList::const_iterator maxIt; + MaximaList::const_reverse_iterator maxRit; + if (m_Maxima.size() > 0) + { + //get the first maxima in range (X) ... + if (dir == dLeftToRight) + { + maxIt = m_Maxima.begin(); + while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X) maxIt++; + if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X) + maxIt = m_Maxima.end(); + } + else + { + maxRit = m_Maxima.rbegin(); + while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X) maxRit++; + if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X) + maxRit = m_Maxima.rend(); + } + } + + OutPt* op1 = 0; + + for (;;) //loop through consec. horizontal edges + { + + bool IsLastHorz = (horzEdge == eLastHorz); + TEdge* e = GetNextInAEL(horzEdge, dir); + while(e) + { + + //this code block inserts extra coords into horizontal edges (in output + //polygons) whereever maxima touch these horizontal edges. This helps + //'simplifying' polygons (ie if the Simplify property is set). + if (m_Maxima.size() > 0) + { + if (dir == dLeftToRight) + { + while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) + { + if (horzEdge->OutIdx >= 0 && !IsOpen) + AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y)); + maxIt++; + } + } + else + { + while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) + { + if (horzEdge->OutIdx >= 0 && !IsOpen) + AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y)); + maxRit++; + } + } + }; + + if ((dir == dLeftToRight && e->Curr.X > horzRight) || + (dir == dRightToLeft && e->Curr.X < horzLeft)) break; + + //Also break if we've got to the end of an intermediate horizontal edge ... + //nb: Smaller Dx's are to the right of larger Dx's ABOVE the horizontal. + if (e->Curr.X == horzEdge->Top.X && horzEdge->NextInLML && + e->Dx < horzEdge->NextInLML->Dx) break; + + if (horzEdge->OutIdx >= 0 && !IsOpen) //note: may be done multiple times + { + op1 = AddOutPt(horzEdge, e->Curr); + TEdge* eNextHorz = m_SortedEdges; + while (eNextHorz) + { + if (eNextHorz->OutIdx >= 0 && + HorzSegmentsOverlap(horzEdge->Bot.X, + horzEdge->Top.X, eNextHorz->Bot.X, eNextHorz->Top.X)) + { + OutPt* op2 = GetLastOutPt(eNextHorz); + AddJoin(op2, op1, eNextHorz->Top); + } + eNextHorz = eNextHorz->NextInSEL; + } + AddGhostJoin(op1, horzEdge->Bot); + } + + //OK, so far we're still in range of the horizontal Edge but make sure + //we're at the last of consec. horizontals when matching with eMaxPair + if(e == eMaxPair && IsLastHorz) + { + if (horzEdge->OutIdx >= 0) + AddLocalMaxPoly(horzEdge, eMaxPair, horzEdge->Top); + DeleteFromAEL(horzEdge); + DeleteFromAEL(eMaxPair); + return; + } + + if(dir == dLeftToRight) + { + IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y); + IntersectEdges(horzEdge, e, Pt); + } + else + { + IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y); + IntersectEdges( e, horzEdge, Pt); + } + TEdge* eNext = GetNextInAEL(e, dir); + SwapPositionsInAEL( horzEdge, e ); + e = eNext; + } //end while(e) + + //Break out of loop if HorzEdge.NextInLML is not also horizontal ... 
+ if (!horzEdge->NextInLML || !IsHorizontal(*horzEdge->NextInLML)) break; + + UpdateEdgeIntoAEL(horzEdge); + if (horzEdge->OutIdx >= 0) AddOutPt(horzEdge, horzEdge->Bot); + GetHorzDirection(*horzEdge, dir, horzLeft, horzRight); + + } //end for (;;) + + if (horzEdge->OutIdx >= 0 && !op1) + { + op1 = GetLastOutPt(horzEdge); + TEdge* eNextHorz = m_SortedEdges; + while (eNextHorz) + { + if (eNextHorz->OutIdx >= 0 && + HorzSegmentsOverlap(horzEdge->Bot.X, + horzEdge->Top.X, eNextHorz->Bot.X, eNextHorz->Top.X)) + { + OutPt* op2 = GetLastOutPt(eNextHorz); + AddJoin(op2, op1, eNextHorz->Top); + } + eNextHorz = eNextHorz->NextInSEL; + } + AddGhostJoin(op1, horzEdge->Top); + } + + if (horzEdge->NextInLML) + { + if(horzEdge->OutIdx >= 0) + { + op1 = AddOutPt( horzEdge, horzEdge->Top); + UpdateEdgeIntoAEL(horzEdge); + if (horzEdge->WindDelta == 0) return; + //nb: HorzEdge is no longer horizontal here + TEdge* ePrev = horzEdge->PrevInAEL; + TEdge* eNext = horzEdge->NextInAEL; + if (ePrev && ePrev->Curr.X == horzEdge->Bot.X && + ePrev->Curr.Y == horzEdge->Bot.Y && ePrev->WindDelta != 0 && + (ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y && + SlopesEqual(*horzEdge, *ePrev, m_UseFullRange))) + { + OutPt* op2 = AddOutPt(ePrev, horzEdge->Bot); + AddJoin(op1, op2, horzEdge->Top); + } + else if (eNext && eNext->Curr.X == horzEdge->Bot.X && + eNext->Curr.Y == horzEdge->Bot.Y && eNext->WindDelta != 0 && + eNext->OutIdx >= 0 && eNext->Curr.Y > eNext->Top.Y && + SlopesEqual(*horzEdge, *eNext, m_UseFullRange)) + { + OutPt* op2 = AddOutPt(eNext, horzEdge->Bot); + AddJoin(op1, op2, horzEdge->Top); + } + } + else + UpdateEdgeIntoAEL(horzEdge); + } + else + { + if (horzEdge->OutIdx >= 0) AddOutPt(horzEdge, horzEdge->Top); + DeleteFromAEL(horzEdge); + } +} +//------------------------------------------------------------------------------ + +bool Clipper::ProcessIntersections(const cInt topY) +{ + if( !m_ActiveEdges ) return true; + try { + BuildIntersectList(topY); + size_t IlSize = m_IntersectList.size(); + if (IlSize == 0) return true; + if (IlSize == 1 || FixupIntersectionOrder()) ProcessIntersectList(); + else return false; + } + catch(...) + { + m_SortedEdges = 0; + DisposeIntersectNodes(); + throw clipperException("ProcessIntersections error"); + } + m_SortedEdges = 0; + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::DisposeIntersectNodes() +{ + for (size_t i = 0; i < m_IntersectList.size(); ++i ) + delete m_IntersectList[i]; + m_IntersectList.clear(); +} +//------------------------------------------------------------------------------ + +void Clipper::BuildIntersectList(const cInt topY) +{ + if ( !m_ActiveEdges ) return; + + //prepare for sorting ... + TEdge* e = m_ActiveEdges; + m_SortedEdges = e; + while( e ) + { + e->PrevInSEL = e->PrevInAEL; + e->NextInSEL = e->NextInAEL; + e->Curr.X = TopX( *e, topY ); + e = e->NextInAEL; + } + + //bubblesort ... 
+ bool isModified; + do + { + isModified = false; + e = m_SortedEdges; + while( e->NextInSEL ) + { + TEdge *eNext = e->NextInSEL; + IntPoint Pt; + if(e->Curr.X > eNext->Curr.X) + { + IntersectPoint(*e, *eNext, Pt); + if (Pt.Y < topY) Pt = IntPoint(TopX(*e, topY), topY); + IntersectNode * newNode = new IntersectNode; + newNode->Edge1 = e; + newNode->Edge2 = eNext; + newNode->Pt = Pt; + m_IntersectList.push_back(newNode); + + SwapPositionsInSEL(e, eNext); + isModified = true; + } + else + e = eNext; + } + if( e->PrevInSEL ) e->PrevInSEL->NextInSEL = 0; + else break; + } + while ( isModified ); + m_SortedEdges = 0; //important +} +//------------------------------------------------------------------------------ + + +void Clipper::ProcessIntersectList() +{ + for (size_t i = 0; i < m_IntersectList.size(); ++i) + { + IntersectNode* iNode = m_IntersectList[i]; + { + IntersectEdges( iNode->Edge1, iNode->Edge2, iNode->Pt); + SwapPositionsInAEL( iNode->Edge1 , iNode->Edge2 ); + } + delete iNode; + } + m_IntersectList.clear(); +} +//------------------------------------------------------------------------------ + +bool IntersectListSort(IntersectNode* node1, IntersectNode* node2) +{ + return node2->Pt.Y < node1->Pt.Y; +} +//------------------------------------------------------------------------------ + +inline bool EdgesAdjacent(const IntersectNode &inode) +{ + return (inode.Edge1->NextInSEL == inode.Edge2) || + (inode.Edge1->PrevInSEL == inode.Edge2); +} +//------------------------------------------------------------------------------ + +bool Clipper::FixupIntersectionOrder() +{ + //pre-condition: intersections are sorted Bottom-most first. + //Now it's crucial that intersections are made only between adjacent edges, + //so to ensure this the order of intersections may need adjusting ... 
+ CopyAELToSEL(); + std::sort(m_IntersectList.begin(), m_IntersectList.end(), IntersectListSort); + size_t cnt = m_IntersectList.size(); + for (size_t i = 0; i < cnt; ++i) + { + if (!EdgesAdjacent(*m_IntersectList[i])) + { + size_t j = i + 1; + while (j < cnt && !EdgesAdjacent(*m_IntersectList[j])) j++; + if (j == cnt) return false; + std::swap(m_IntersectList[i], m_IntersectList[j]); + } + SwapPositionsInSEL(m_IntersectList[i]->Edge1, m_IntersectList[i]->Edge2); + } + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::DoMaxima(TEdge *e) +{ + TEdge* eMaxPair = GetMaximaPairEx(e); + if (!eMaxPair) + { + if (e->OutIdx >= 0) + AddOutPt(e, e->Top); + DeleteFromAEL(e); + return; + } + + TEdge* eNext = e->NextInAEL; + while(eNext && eNext != eMaxPair) + { + IntersectEdges(e, eNext, e->Top); + SwapPositionsInAEL(e, eNext); + eNext = e->NextInAEL; + } + + if(e->OutIdx == Unassigned && eMaxPair->OutIdx == Unassigned) + { + DeleteFromAEL(e); + DeleteFromAEL(eMaxPair); + } + else if( e->OutIdx >= 0 && eMaxPair->OutIdx >= 0 ) + { + if (e->OutIdx >= 0) AddLocalMaxPoly(e, eMaxPair, e->Top); + DeleteFromAEL(e); + DeleteFromAEL(eMaxPair); + } +#ifdef use_lines + else if (e->WindDelta == 0) + { + if (e->OutIdx >= 0) + { + AddOutPt(e, e->Top); + e->OutIdx = Unassigned; + } + DeleteFromAEL(e); + + if (eMaxPair->OutIdx >= 0) + { + AddOutPt(eMaxPair, e->Top); + eMaxPair->OutIdx = Unassigned; + } + DeleteFromAEL(eMaxPair); + } +#endif + else throw clipperException("DoMaxima error"); +} +//------------------------------------------------------------------------------ + +void Clipper::ProcessEdgesAtTopOfScanbeam(const cInt topY) +{ + TEdge* e = m_ActiveEdges; + while( e ) + { + //1. process maxima, treating them as if they're 'bent' horizontal edges, + // but exclude maxima with horizontal edges. nb: e can't be a horizontal. + bool IsMaximaEdge = IsMaxima(e, topY); + + if(IsMaximaEdge) + { + TEdge* eMaxPair = GetMaximaPairEx(e); + IsMaximaEdge = (!eMaxPair || !IsHorizontal(*eMaxPair)); + } + + if(IsMaximaEdge) + { + if (m_StrictSimple) m_Maxima.push_back(e->Top.X); + TEdge* ePrev = e->PrevInAEL; + DoMaxima(e); + if( !ePrev ) e = m_ActiveEdges; + else e = ePrev->NextInAEL; + } + else + { + //2. promote horizontal edges, otherwise update Curr.X and Curr.Y ... + if (IsIntermediate(e, topY) && IsHorizontal(*e->NextInLML)) + { + UpdateEdgeIntoAEL(e); + if (e->OutIdx >= 0) + AddOutPt(e, e->Bot); + AddEdgeToSEL(e); + } + else + { + e->Curr.X = TopX( *e, topY ); + e->Curr.Y = topY; + } + + //When StrictlySimple and 'e' is being touched by another edge, then + //make sure both edges have a vertex here ... + if (m_StrictSimple) + { + TEdge* ePrev = e->PrevInAEL; + if ((e->OutIdx >= 0) && (e->WindDelta != 0) && ePrev && (ePrev->OutIdx >= 0) && + (ePrev->Curr.X == e->Curr.X) && (ePrev->WindDelta != 0)) + { + IntPoint pt = e->Curr; +#ifdef use_xyz + SetZ(pt, *ePrev, *e); +#endif + OutPt* op = AddOutPt(ePrev, pt); + OutPt* op2 = AddOutPt(e, pt); + AddJoin(op, op2, pt); //StrictlySimple (type-3) join + } + } + + e = e->NextInAEL; + } + } + + //3. Process horizontals at the Top of the scanbeam ... + m_Maxima.sort(); + ProcessHorizontals(); + m_Maxima.clear(); + + //4. Promote intermediate vertices ... + e = m_ActiveEdges; + while(e) + { + if(IsIntermediate(e, topY)) + { + OutPt* op = 0; + if( e->OutIdx >= 0 ) + op = AddOutPt(e, e->Top); + UpdateEdgeIntoAEL(e); + + //if output polygons share an edge, they'll need joining later ... 
+ TEdge* ePrev = e->PrevInAEL; + TEdge* eNext = e->NextInAEL; + if (ePrev && ePrev->Curr.X == e->Bot.X && + ePrev->Curr.Y == e->Bot.Y && op && + ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y && + SlopesEqual(e->Curr, e->Top, ePrev->Curr, ePrev->Top, m_UseFullRange) && + (e->WindDelta != 0) && (ePrev->WindDelta != 0)) + { + OutPt* op2 = AddOutPt(ePrev, e->Bot); + AddJoin(op, op2, e->Top); + } + else if (eNext && eNext->Curr.X == e->Bot.X && + eNext->Curr.Y == e->Bot.Y && op && + eNext->OutIdx >= 0 && eNext->Curr.Y > eNext->Top.Y && + SlopesEqual(e->Curr, e->Top, eNext->Curr, eNext->Top, m_UseFullRange) && + (e->WindDelta != 0) && (eNext->WindDelta != 0)) + { + OutPt* op2 = AddOutPt(eNext, e->Bot); + AddJoin(op, op2, e->Top); + } + } + e = e->NextInAEL; + } +} +//------------------------------------------------------------------------------ + +void Clipper::FixupOutPolyline(OutRec &outrec) +{ + OutPt *pp = outrec.Pts; + OutPt *lastPP = pp->Prev; + while (pp != lastPP) + { + pp = pp->Next; + if (pp->Pt == pp->Prev->Pt) + { + if (pp == lastPP) lastPP = pp->Prev; + OutPt *tmpPP = pp->Prev; + tmpPP->Next = pp->Next; + pp->Next->Prev = tmpPP; + delete pp; + pp = tmpPP; + } + } + + if (pp == pp->Prev) + { + DisposeOutPts(pp); + outrec.Pts = 0; + return; + } +} +//------------------------------------------------------------------------------ + +void Clipper::FixupOutPolygon(OutRec &outrec) +{ + //FixupOutPolygon() - removes duplicate points and simplifies consecutive + //parallel edges by removing the middle vertex. + OutPt *lastOK = 0; + outrec.BottomPt = 0; + OutPt *pp = outrec.Pts; + bool preserveCol = m_PreserveCollinear || m_StrictSimple; + + for (;;) + { + if (pp->Prev == pp || pp->Prev == pp->Next) + { + DisposeOutPts(pp); + outrec.Pts = 0; + return; + } + + //test for duplicate points and collinear edges ... + if ((pp->Pt == pp->Next->Pt) || (pp->Pt == pp->Prev->Pt) || + (SlopesEqual(pp->Prev->Pt, pp->Pt, pp->Next->Pt, m_UseFullRange) && + (!preserveCol || !Pt2IsBetweenPt1AndPt3(pp->Prev->Pt, pp->Pt, pp->Next->Pt)))) + { + lastOK = 0; + OutPt *tmp = pp; + pp->Prev->Next = pp->Next; + pp->Next->Prev = pp->Prev; + pp = pp->Prev; + delete tmp; + } + else if (pp == lastOK) break; + else + { + if (!lastOK) lastOK = pp; + pp = pp->Next; + } + } + outrec.Pts = pp; +} +//------------------------------------------------------------------------------ + +int PointCount(OutPt *Pts) +{ + if (!Pts) return 0; + int result = 0; + OutPt* p = Pts; + do + { + result++; + p = p->Next; + } + while (p != Pts); + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::BuildResult(Paths &polys) +{ + polys.reserve(m_PolyOuts.size()); + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + if (!m_PolyOuts[i]->Pts) continue; + Path pg; + OutPt* p = m_PolyOuts[i]->Pts->Prev; + int cnt = PointCount(p); + if (cnt < 2) continue; + pg.reserve(cnt); + for (int i = 0; i < cnt; ++i) + { + pg.push_back(p->Pt); + p = p->Prev; + } + polys.push_back(pg); + } +} +//------------------------------------------------------------------------------ + +void Clipper::BuildResult2(PolyTree& polytree) +{ + polytree.Clear(); + polytree.AllNodes.reserve(m_PolyOuts.size()); + //add each output polygon/contour to polytree ... 
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++) + { + OutRec* outRec = m_PolyOuts[i]; + int cnt = PointCount(outRec->Pts); + if ((outRec->IsOpen && cnt < 2) || (!outRec->IsOpen && cnt < 3)) continue; + FixHoleLinkage(*outRec); + PolyNode* pn = new PolyNode(); + //nb: polytree takes ownership of all the PolyNodes + polytree.AllNodes.push_back(pn); + outRec->PolyNd = pn; + pn->Parent = 0; + pn->Index = 0; + pn->Contour.reserve(cnt); + OutPt *op = outRec->Pts->Prev; + for (int j = 0; j < cnt; j++) + { + pn->Contour.push_back(op->Pt); + op = op->Prev; + } + } + + //fixup PolyNode links etc ... + polytree.Childs.reserve(m_PolyOuts.size()); + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++) + { + OutRec* outRec = m_PolyOuts[i]; + if (!outRec->PolyNd) continue; + if (outRec->IsOpen) + { + outRec->PolyNd->m_IsOpen = true; + polytree.AddChild(*outRec->PolyNd); + } + else if (outRec->FirstLeft && outRec->FirstLeft->PolyNd) + outRec->FirstLeft->PolyNd->AddChild(*outRec->PolyNd); + else + polytree.AddChild(*outRec->PolyNd); + } +} +//------------------------------------------------------------------------------ + +void SwapIntersectNodes(IntersectNode &int1, IntersectNode &int2) +{ + //just swap the contents (because fIntersectNodes is a single-linked-list) + IntersectNode inode = int1; //gets a copy of Int1 + int1.Edge1 = int2.Edge1; + int1.Edge2 = int2.Edge2; + int1.Pt = int2.Pt; + int2.Edge1 = inode.Edge1; + int2.Edge2 = inode.Edge2; + int2.Pt = inode.Pt; +} +//------------------------------------------------------------------------------ + +inline bool E2InsertsBeforeE1(TEdge &e1, TEdge &e2) +{ + if (e2.Curr.X == e1.Curr.X) + { + if (e2.Top.Y > e1.Top.Y) + return e2.Top.X < TopX(e1, e2.Top.Y); + else return e1.Top.X > TopX(e2, e1.Top.Y); + } + else return e2.Curr.X < e1.Curr.X; +} +//------------------------------------------------------------------------------ + +bool GetOverlap(const cInt a1, const cInt a2, const cInt b1, const cInt b2, + cInt& Left, cInt& Right) +{ + if (a1 < a2) + { + if (b1 < b2) {Left = std::max(a1,b1); Right = std::min(a2,b2);} + else {Left = std::max(a1,b2); Right = std::min(a2,b1);} + } + else + { + if (b1 < b2) {Left = std::max(a2,b1); Right = std::min(a1,b2);} + else {Left = std::max(a2,b2); Right = std::min(a1,b1);} + } + return Left < Right; +} +//------------------------------------------------------------------------------ + +inline void UpdateOutPtIdxs(OutRec& outrec) +{ + OutPt* op = outrec.Pts; + do + { + op->Idx = outrec.Idx; + op = op->Prev; + } + while(op != outrec.Pts); +} +//------------------------------------------------------------------------------ + +void Clipper::InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge) +{ + if(!m_ActiveEdges) + { + edge->PrevInAEL = 0; + edge->NextInAEL = 0; + m_ActiveEdges = edge; + } + else if(!startEdge && E2InsertsBeforeE1(*m_ActiveEdges, *edge)) + { + edge->PrevInAEL = 0; + edge->NextInAEL = m_ActiveEdges; + m_ActiveEdges->PrevInAEL = edge; + m_ActiveEdges = edge; + } + else + { + if(!startEdge) startEdge = m_ActiveEdges; + while(startEdge->NextInAEL && + !E2InsertsBeforeE1(*startEdge->NextInAEL , *edge)) + startEdge = startEdge->NextInAEL; + edge->NextInAEL = startEdge->NextInAEL; + if(startEdge->NextInAEL) startEdge->NextInAEL->PrevInAEL = edge; + edge->PrevInAEL = startEdge; + startEdge->NextInAEL = edge; + } +} +//---------------------------------------------------------------------- + +OutPt* DupOutPt(OutPt* outPt, bool InsertAfter) +{ + OutPt* result = new OutPt; + result->Pt = 
outPt->Pt; + result->Idx = outPt->Idx; + if (InsertAfter) + { + result->Next = outPt->Next; + result->Prev = outPt; + outPt->Next->Prev = result; + outPt->Next = result; + } + else + { + result->Prev = outPt->Prev; + result->Next = outPt; + outPt->Prev->Next = result; + outPt->Prev = result; + } + return result; +} +//------------------------------------------------------------------------------ + +bool JoinHorz(OutPt* op1, OutPt* op1b, OutPt* op2, OutPt* op2b, + const IntPoint Pt, bool DiscardLeft) +{ + Direction Dir1 = (op1->Pt.X > op1b->Pt.X ? dRightToLeft : dLeftToRight); + Direction Dir2 = (op2->Pt.X > op2b->Pt.X ? dRightToLeft : dLeftToRight); + if (Dir1 == Dir2) return false; + + //When DiscardLeft, we want Op1b to be on the Left of Op1, otherwise we + //want Op1b to be on the Right. (And likewise with Op2 and Op2b.) + //So, to facilitate this while inserting Op1b and Op2b ... + //when DiscardLeft, make sure we're AT or RIGHT of Pt before adding Op1b, + //otherwise make sure we're AT or LEFT of Pt. (Likewise with Op2b.) + if (Dir1 == dLeftToRight) + { + while (op1->Next->Pt.X <= Pt.X && + op1->Next->Pt.X >= op1->Pt.X && op1->Next->Pt.Y == Pt.Y) + op1 = op1->Next; + if (DiscardLeft && (op1->Pt.X != Pt.X)) op1 = op1->Next; + op1b = DupOutPt(op1, !DiscardLeft); + if (op1b->Pt != Pt) + { + op1 = op1b; + op1->Pt = Pt; + op1b = DupOutPt(op1, !DiscardLeft); + } + } + else + { + while (op1->Next->Pt.X >= Pt.X && + op1->Next->Pt.X <= op1->Pt.X && op1->Next->Pt.Y == Pt.Y) + op1 = op1->Next; + if (!DiscardLeft && (op1->Pt.X != Pt.X)) op1 = op1->Next; + op1b = DupOutPt(op1, DiscardLeft); + if (op1b->Pt != Pt) + { + op1 = op1b; + op1->Pt = Pt; + op1b = DupOutPt(op1, DiscardLeft); + } + } + + if (Dir2 == dLeftToRight) + { + while (op2->Next->Pt.X <= Pt.X && + op2->Next->Pt.X >= op2->Pt.X && op2->Next->Pt.Y == Pt.Y) + op2 = op2->Next; + if (DiscardLeft && (op2->Pt.X != Pt.X)) op2 = op2->Next; + op2b = DupOutPt(op2, !DiscardLeft); + if (op2b->Pt != Pt) + { + op2 = op2b; + op2->Pt = Pt; + op2b = DupOutPt(op2, !DiscardLeft); + }; + } else + { + while (op2->Next->Pt.X >= Pt.X && + op2->Next->Pt.X <= op2->Pt.X && op2->Next->Pt.Y == Pt.Y) + op2 = op2->Next; + if (!DiscardLeft && (op2->Pt.X != Pt.X)) op2 = op2->Next; + op2b = DupOutPt(op2, DiscardLeft); + if (op2b->Pt != Pt) + { + op2 = op2b; + op2->Pt = Pt; + op2b = DupOutPt(op2, DiscardLeft); + }; + }; + + if ((Dir1 == dLeftToRight) == DiscardLeft) + { + op1->Prev = op2; + op2->Next = op1; + op1b->Next = op2b; + op2b->Prev = op1b; + } + else + { + op1->Next = op2; + op2->Prev = op1; + op1b->Prev = op2b; + op2b->Next = op1b; + } + return true; +} +//------------------------------------------------------------------------------ + +bool Clipper::JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2) +{ + OutPt *op1 = j->OutPt1, *op1b; + OutPt *op2 = j->OutPt2, *op2b; + + //There are 3 kinds of joins for output polygons ... + //1. Horizontal joins where Join.OutPt1 & Join.OutPt2 are vertices anywhere + //along (horizontal) collinear edges (& Join.OffPt is on the same horizontal). + //2. Non-horizontal joins where Join.OutPt1 & Join.OutPt2 are at the same + //location at the Bottom of the overlapping segment (& Join.OffPt is above). + //3. StrictSimple joins where edges touch but are not collinear and where + //Join.OutPt1, Join.OutPt2 & Join.OffPt all share the same point. + bool isHorizontal = (j->OutPt1->Pt.Y == j->OffPt.Y); + + if (isHorizontal && (j->OffPt == j->OutPt1->Pt) && + (j->OffPt == j->OutPt2->Pt)) + { + //Strictly Simple join ... 
+ if (outRec1 != outRec2) return false; + op1b = j->OutPt1->Next; + while (op1b != op1 && (op1b->Pt == j->OffPt)) + op1b = op1b->Next; + bool reverse1 = (op1b->Pt.Y > j->OffPt.Y); + op2b = j->OutPt2->Next; + while (op2b != op2 && (op2b->Pt == j->OffPt)) + op2b = op2b->Next; + bool reverse2 = (op2b->Pt.Y > j->OffPt.Y); + if (reverse1 == reverse2) return false; + if (reverse1) + { + op1b = DupOutPt(op1, false); + op2b = DupOutPt(op2, true); + op1->Prev = op2; + op2->Next = op1; + op1b->Next = op2b; + op2b->Prev = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } else + { + op1b = DupOutPt(op1, true); + op2b = DupOutPt(op2, false); + op1->Next = op2; + op2->Prev = op1; + op1b->Prev = op2b; + op2b->Next = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } + } + else if (isHorizontal) + { + //treat horizontal joins differently to non-horizontal joins since with + //them we're not yet sure where the overlapping is. OutPt1.Pt & OutPt2.Pt + //may be anywhere along the horizontal edge. + op1b = op1; + while (op1->Prev->Pt.Y == op1->Pt.Y && op1->Prev != op1b && op1->Prev != op2) + op1 = op1->Prev; + while (op1b->Next->Pt.Y == op1b->Pt.Y && op1b->Next != op1 && op1b->Next != op2) + op1b = op1b->Next; + if (op1b->Next == op1 || op1b->Next == op2) return false; //a flat 'polygon' + + op2b = op2; + while (op2->Prev->Pt.Y == op2->Pt.Y && op2->Prev != op2b && op2->Prev != op1b) + op2 = op2->Prev; + while (op2b->Next->Pt.Y == op2b->Pt.Y && op2b->Next != op2 && op2b->Next != op1) + op2b = op2b->Next; + if (op2b->Next == op2 || op2b->Next == op1) return false; //a flat 'polygon' + + cInt Left, Right; + //Op1 --> Op1b & Op2 --> Op2b are the extremites of the horizontal edges + if (!GetOverlap(op1->Pt.X, op1b->Pt.X, op2->Pt.X, op2b->Pt.X, Left, Right)) + return false; + + //DiscardLeftSide: when overlapping edges are joined, a spike will created + //which needs to be cleaned up. However, we don't want Op1 or Op2 caught up + //on the discard Side as either may still be needed for other joins ... + IntPoint Pt; + bool DiscardLeftSide; + if (op1->Pt.X >= Left && op1->Pt.X <= Right) + { + Pt = op1->Pt; DiscardLeftSide = (op1->Pt.X > op1b->Pt.X); + } + else if (op2->Pt.X >= Left&& op2->Pt.X <= Right) + { + Pt = op2->Pt; DiscardLeftSide = (op2->Pt.X > op2b->Pt.X); + } + else if (op1b->Pt.X >= Left && op1b->Pt.X <= Right) + { + Pt = op1b->Pt; DiscardLeftSide = op1b->Pt.X > op1->Pt.X; + } + else + { + Pt = op2b->Pt; DiscardLeftSide = (op2b->Pt.X > op2->Pt.X); + } + j->OutPt1 = op1; j->OutPt2 = op2; + return JoinHorz(op1, op1b, op2, op2b, Pt, DiscardLeftSide); + } else + { + //nb: For non-horizontal joins ... + // 1. Jr.OutPt1.Pt.Y == Jr.OutPt2.Pt.Y + // 2. Jr.OutPt1.Pt > Jr.OffPt.Y + + //make sure the polygons are correctly oriented ... 
+ op1b = op1->Next; + while ((op1b->Pt == op1->Pt) && (op1b != op1)) op1b = op1b->Next; + bool Reverse1 = ((op1b->Pt.Y > op1->Pt.Y) || + !SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange)); + if (Reverse1) + { + op1b = op1->Prev; + while ((op1b->Pt == op1->Pt) && (op1b != op1)) op1b = op1b->Prev; + if ((op1b->Pt.Y > op1->Pt.Y) || + !SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange)) return false; + }; + op2b = op2->Next; + while ((op2b->Pt == op2->Pt) && (op2b != op2))op2b = op2b->Next; + bool Reverse2 = ((op2b->Pt.Y > op2->Pt.Y) || + !SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange)); + if (Reverse2) + { + op2b = op2->Prev; + while ((op2b->Pt == op2->Pt) && (op2b != op2)) op2b = op2b->Prev; + if ((op2b->Pt.Y > op2->Pt.Y) || + !SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange)) return false; + } + + if ((op1b == op1) || (op2b == op2) || (op1b == op2b) || + ((outRec1 == outRec2) && (Reverse1 == Reverse2))) return false; + + if (Reverse1) + { + op1b = DupOutPt(op1, false); + op2b = DupOutPt(op2, true); + op1->Prev = op2; + op2->Next = op1; + op1b->Next = op2b; + op2b->Prev = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } else + { + op1b = DupOutPt(op1, true); + op2b = DupOutPt(op2, false); + op1->Next = op2; + op2->Prev = op1; + op1b->Prev = op2b; + op2b->Next = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } + } +} +//---------------------------------------------------------------------- + +static OutRec* ParseFirstLeft(OutRec* FirstLeft) +{ + while (FirstLeft && !FirstLeft->Pts) + FirstLeft = FirstLeft->FirstLeft; + return FirstLeft; +} +//------------------------------------------------------------------------------ + +void Clipper::FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec) +{ + //tests if NewOutRec contains the polygon before reassigning FirstLeft + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + OutRec* outRec = m_PolyOuts[i]; + OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (outRec->Pts && firstLeft == OldOutRec) + { + if (Poly2ContainsPoly1(outRec->Pts, NewOutRec->Pts)) + outRec->FirstLeft = NewOutRec; + } + } +} +//---------------------------------------------------------------------- + +void Clipper::FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec) +{ + //A polygon has split into two such that one is now the inner of the other. + //It's possible that these polygons now wrap around other polygons, so check + //every polygon that's also contained by OuterOutRec's FirstLeft container + //(including 0) to see if they've become inner to the new inner polygon ... 
+ OutRec* orfl = OuterOutRec->FirstLeft; + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + OutRec* outRec = m_PolyOuts[i]; + + if (!outRec->Pts || outRec == OuterOutRec || outRec == InnerOutRec) + continue; + OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (firstLeft != orfl && firstLeft != InnerOutRec && firstLeft != OuterOutRec) + continue; + if (Poly2ContainsPoly1(outRec->Pts, InnerOutRec->Pts)) + outRec->FirstLeft = InnerOutRec; + else if (Poly2ContainsPoly1(outRec->Pts, OuterOutRec->Pts)) + outRec->FirstLeft = OuterOutRec; + else if (outRec->FirstLeft == InnerOutRec || outRec->FirstLeft == OuterOutRec) + outRec->FirstLeft = orfl; + } +} +//---------------------------------------------------------------------- +void Clipper::FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec) +{ + //reassigns FirstLeft WITHOUT testing if NewOutRec contains the polygon + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + OutRec* outRec = m_PolyOuts[i]; + OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (outRec->Pts && outRec->FirstLeft == OldOutRec) + outRec->FirstLeft = NewOutRec; + } +} +//---------------------------------------------------------------------- + +void Clipper::JoinCommonEdges() +{ + for (JoinList::size_type i = 0; i < m_Joins.size(); i++) + { + Join* join = m_Joins[i]; + + OutRec *outRec1 = GetOutRec(join->OutPt1->Idx); + OutRec *outRec2 = GetOutRec(join->OutPt2->Idx); + + if (!outRec1->Pts || !outRec2->Pts) continue; + if (outRec1->IsOpen || outRec2->IsOpen) continue; + + //get the polygon fragment with the correct hole state (FirstLeft) + //before calling JoinPoints() ... + OutRec *holeStateRec; + if (outRec1 == outRec2) holeStateRec = outRec1; + else if (OutRec1RightOfOutRec2(outRec1, outRec2)) holeStateRec = outRec2; + else if (OutRec1RightOfOutRec2(outRec2, outRec1)) holeStateRec = outRec1; + else holeStateRec = GetLowermostRec(outRec1, outRec2); + + if (!JoinPoints(join, outRec1, outRec2)) continue; + + if (outRec1 == outRec2) + { + //instead of joining two polygons, we've just created a new one by + //splitting one polygon into two. + outRec1->Pts = join->OutPt1; + outRec1->BottomPt = 0; + outRec2 = CreateOutRec(); + outRec2->Pts = join->OutPt2; + + //update all OutRec2.Pts Idx's ... + UpdateOutPtIdxs(*outRec2); + + if (Poly2ContainsPoly1(outRec2->Pts, outRec1->Pts)) + { + //outRec1 contains outRec2 ... + outRec2->IsHole = !outRec1->IsHole; + outRec2->FirstLeft = outRec1; + + if (m_UsingPolyTree) FixupFirstLefts2(outRec2, outRec1); + + if ((outRec2->IsHole ^ m_ReverseOutput) == (Area(*outRec2) > 0)) + ReversePolyPtLinks(outRec2->Pts); + + } else if (Poly2ContainsPoly1(outRec1->Pts, outRec2->Pts)) + { + //outRec2 contains outRec1 ... + outRec2->IsHole = outRec1->IsHole; + outRec1->IsHole = !outRec2->IsHole; + outRec2->FirstLeft = outRec1->FirstLeft; + outRec1->FirstLeft = outRec2; + + if (m_UsingPolyTree) FixupFirstLefts2(outRec1, outRec2); + + if ((outRec1->IsHole ^ m_ReverseOutput) == (Area(*outRec1) > 0)) + ReversePolyPtLinks(outRec1->Pts); + } + else + { + //the 2 polygons are completely separate ... + outRec2->IsHole = outRec1->IsHole; + outRec2->FirstLeft = outRec1->FirstLeft; + + //fixup FirstLeft pointers that may need reassigning to OutRec2 + if (m_UsingPolyTree) FixupFirstLefts1(outRec1, outRec2); + } + + } else + { + //joined 2 polygons together ... 
+ + outRec2->Pts = 0; + outRec2->BottomPt = 0; + outRec2->Idx = outRec1->Idx; + + outRec1->IsHole = holeStateRec->IsHole; + if (holeStateRec == outRec2) + outRec1->FirstLeft = outRec2->FirstLeft; + outRec2->FirstLeft = outRec1; + + if (m_UsingPolyTree) FixupFirstLefts3(outRec2, outRec1); + } + } +} + +//------------------------------------------------------------------------------ +// ClipperOffset support functions ... +//------------------------------------------------------------------------------ + +DoublePoint GetUnitNormal(const IntPoint &pt1, const IntPoint &pt2) +{ + if(pt2.X == pt1.X && pt2.Y == pt1.Y) + return DoublePoint(0, 0); + + double Dx = (double)(pt2.X - pt1.X); + double dy = (double)(pt2.Y - pt1.Y); + double f = 1 *1.0/ std::sqrt( Dx*Dx + dy*dy ); + Dx *= f; + dy *= f; + return DoublePoint(dy, -Dx); +} + +//------------------------------------------------------------------------------ +// ClipperOffset class +//------------------------------------------------------------------------------ + +ClipperOffset::ClipperOffset(double miterLimit, double arcTolerance) +{ + this->MiterLimit = miterLimit; + this->ArcTolerance = arcTolerance; + m_lowest.X = -1; +} +//------------------------------------------------------------------------------ + +ClipperOffset::~ClipperOffset() +{ + Clear(); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::Clear() +{ + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) + delete m_polyNodes.Childs[i]; + m_polyNodes.Childs.clear(); + m_lowest.X = -1; +} +//------------------------------------------------------------------------------ + +void ClipperOffset::AddPath(const Path& path, JoinType joinType, EndType endType) +{ + int highI = (int)path.size() - 1; + if (highI < 0) return; + PolyNode* newNode = new PolyNode(); + newNode->m_jointype = joinType; + newNode->m_endtype = endType; + + //strip duplicate points from path and also get index to the lowest point ... 
+ if (endType == etClosedLine || endType == etClosedPolygon) + while (highI > 0 && path[0] == path[highI]) highI--; + newNode->Contour.reserve(highI + 1); + newNode->Contour.push_back(path[0]); + int j = 0, k = 0; + for (int i = 1; i <= highI; i++) + if (newNode->Contour[j] != path[i]) + { + j++; + newNode->Contour.push_back(path[i]); + if (path[i].Y > newNode->Contour[k].Y || + (path[i].Y == newNode->Contour[k].Y && + path[i].X < newNode->Contour[k].X)) k = j; + } + if (endType == etClosedPolygon && j < 2) + { + delete newNode; + return; + } + m_polyNodes.AddChild(*newNode); + + //if this path's lowest pt is lower than all the others then update m_lowest + if (endType != etClosedPolygon) return; + if (m_lowest.X < 0) + m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k); + else + { + IntPoint ip = m_polyNodes.Childs[(int)m_lowest.X]->Contour[(int)m_lowest.Y]; + if (newNode->Contour[k].Y > ip.Y || + (newNode->Contour[k].Y == ip.Y && + newNode->Contour[k].X < ip.X)) + m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::AddPaths(const Paths& paths, JoinType joinType, EndType endType) +{ + for (Paths::size_type i = 0; i < paths.size(); ++i) + AddPath(paths[i], joinType, endType); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::FixOrientations() +{ + //fixup orientations of all closed paths if the orientation of the + //closed path with the lowermost vertex is wrong ... + if (m_lowest.X >= 0 && + !Orientation(m_polyNodes.Childs[(int)m_lowest.X]->Contour)) + { + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) + { + PolyNode& node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedPolygon || + (node.m_endtype == etClosedLine && Orientation(node.Contour))) + ReversePath(node.Contour); + } + } else + { + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) + { + PolyNode& node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedLine && !Orientation(node.Contour)) + ReversePath(node.Contour); + } + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::Execute(Paths& solution, double delta) +{ + solution.clear(); + FixOrientations(); + DoOffset(delta); + + //now clean up 'corners' ... + Clipper clpr; + clpr.AddPaths(m_destPolys, ptSubject, true); + if (delta > 0) + { + clpr.Execute(ctUnion, solution, pftPositive, pftPositive); + } + else + { + IntRect r = clpr.GetBounds(); + Path outer(4); + outer[0] = IntPoint(r.left - 10, r.bottom + 10); + outer[1] = IntPoint(r.right + 10, r.bottom + 10); + outer[2] = IntPoint(r.right + 10, r.top - 10); + outer[3] = IntPoint(r.left - 10, r.top - 10); + + clpr.AddPath(outer, ptSubject, true); + clpr.ReverseSolution(true); + clpr.Execute(ctUnion, solution, pftNegative, pftNegative); + if (solution.size() > 0) solution.erase(solution.begin()); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::Execute(PolyTree& solution, double delta) +{ + solution.Clear(); + FixOrientations(); + DoOffset(delta); + + //now clean up 'corners' ... 
+ Clipper clpr; + clpr.AddPaths(m_destPolys, ptSubject, true); + if (delta > 0) + { + clpr.Execute(ctUnion, solution, pftPositive, pftPositive); + } + else + { + IntRect r = clpr.GetBounds(); + Path outer(4); + outer[0] = IntPoint(r.left - 10, r.bottom + 10); + outer[1] = IntPoint(r.right + 10, r.bottom + 10); + outer[2] = IntPoint(r.right + 10, r.top - 10); + outer[3] = IntPoint(r.left - 10, r.top - 10); + + clpr.AddPath(outer, ptSubject, true); + clpr.ReverseSolution(true); + clpr.Execute(ctUnion, solution, pftNegative, pftNegative); + //remove the outer PolyNode rectangle ... + if (solution.ChildCount() == 1 && solution.Childs[0]->ChildCount() > 0) + { + PolyNode* outerNode = solution.Childs[0]; + solution.Childs.reserve(outerNode->ChildCount()); + solution.Childs[0] = outerNode->Childs[0]; + solution.Childs[0]->Parent = outerNode->Parent; + for (int i = 1; i < outerNode->ChildCount(); ++i) + solution.AddChild(*outerNode->Childs[i]); + } + else + solution.Clear(); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoOffset(double delta) +{ + m_destPolys.clear(); + m_delta = delta; + + //if Zero offset, just copy any CLOSED polygons to m_p and return ... + if (NEAR_ZERO(delta)) + { + m_destPolys.reserve(m_polyNodes.ChildCount()); + for (int i = 0; i < m_polyNodes.ChildCount(); i++) + { + PolyNode& node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedPolygon) + m_destPolys.push_back(node.Contour); + } + return; + } + + //see offset_triginometry3.svg in the documentation folder ... + if (MiterLimit > 2) m_miterLim = 2/(MiterLimit * MiterLimit); + else m_miterLim = 0.5; + + double y; + if (ArcTolerance <= 0.0) y = def_arc_tolerance; + else if (ArcTolerance > std::fabs(delta) * def_arc_tolerance) + y = std::fabs(delta) * def_arc_tolerance; + else y = ArcTolerance; + //see offset_triginometry2.svg in the documentation folder ... + double steps = pi / std::acos(1 - y / std::fabs(delta)); + if (steps > std::fabs(delta) * pi) + steps = std::fabs(delta) * pi; //ie excessive precision check + m_sin = std::sin(two_pi / steps); + m_cos = std::cos(two_pi / steps); + m_StepsPerRad = steps / two_pi; + if (delta < 0.0) m_sin = -m_sin; + + m_destPolys.reserve(m_polyNodes.ChildCount() * 2); + for (int i = 0; i < m_polyNodes.ChildCount(); i++) + { + PolyNode& node = *m_polyNodes.Childs[i]; + m_srcPoly = node.Contour; + + int len = (int)m_srcPoly.size(); + if (len == 0 || (delta <= 0 && (len < 3 || node.m_endtype != etClosedPolygon))) + continue; + + m_destPoly.clear(); + if (len == 1) + { + if (node.m_jointype == jtRound) + { + double X = 1.0, Y = 0.0; + for (cInt j = 1; j <= steps; j++) + { + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[0].X + X * delta), + Round(m_srcPoly[0].Y + Y * delta))); + double X2 = X; + X = X * m_cos - m_sin * Y; + Y = X2 * m_sin + Y * m_cos; + } + } + else + { + double X = -1.0, Y = -1.0; + for (int j = 0; j < 4; ++j) + { + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[0].X + X * delta), + Round(m_srcPoly[0].Y + Y * delta))); + if (X < 0) X = 1; + else if (Y < 0) Y = 1; + else X = -1; + } + } + m_destPolys.push_back(m_destPoly); + continue; + } + //build m_normals ... 
+ m_normals.clear(); + m_normals.reserve(len); + for (int j = 0; j < len - 1; ++j) + m_normals.push_back(GetUnitNormal(m_srcPoly[j], m_srcPoly[j + 1])); + if (node.m_endtype == etClosedLine || node.m_endtype == etClosedPolygon) + m_normals.push_back(GetUnitNormal(m_srcPoly[len - 1], m_srcPoly[0])); + else + m_normals.push_back(DoublePoint(m_normals[len - 2])); + + if (node.m_endtype == etClosedPolygon) + { + int k = len - 1; + for (int j = 0; j < len; ++j) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + } + else if (node.m_endtype == etClosedLine) + { + int k = len - 1; + for (int j = 0; j < len; ++j) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + m_destPoly.clear(); + //re-build m_normals ... + DoublePoint n = m_normals[len -1]; + for (int j = len - 1; j > 0; j--) + m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y); + m_normals[0] = DoublePoint(-n.X, -n.Y); + k = 0; + for (int j = len - 1; j >= 0; j--) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + } + else + { + int k = 0; + for (int j = 1; j < len - 1; ++j) + OffsetPoint(j, k, node.m_jointype); + + IntPoint pt1; + if (node.m_endtype == etOpenButt) + { + int j = len - 1; + pt1 = IntPoint((cInt)Round(m_srcPoly[j].X + m_normals[j].X * + delta), (cInt)Round(m_srcPoly[j].Y + m_normals[j].Y * delta)); + m_destPoly.push_back(pt1); + pt1 = IntPoint((cInt)Round(m_srcPoly[j].X - m_normals[j].X * + delta), (cInt)Round(m_srcPoly[j].Y - m_normals[j].Y * delta)); + m_destPoly.push_back(pt1); + } + else + { + int j = len - 1; + k = len - 2; + m_sinA = 0; + m_normals[j] = DoublePoint(-m_normals[j].X, -m_normals[j].Y); + if (node.m_endtype == etOpenSquare) + DoSquare(j, k); + else + DoRound(j, k); + } + + //re-build m_normals ... + for (int j = len - 1; j > 0; j--) + m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y); + m_normals[0] = DoublePoint(-m_normals[1].X, -m_normals[1].Y); + + k = len - 1; + for (int j = k - 1; j > 0; --j) OffsetPoint(j, k, node.m_jointype); + + if (node.m_endtype == etOpenButt) + { + pt1 = IntPoint((cInt)Round(m_srcPoly[0].X - m_normals[0].X * delta), + (cInt)Round(m_srcPoly[0].Y - m_normals[0].Y * delta)); + m_destPoly.push_back(pt1); + pt1 = IntPoint((cInt)Round(m_srcPoly[0].X + m_normals[0].X * delta), + (cInt)Round(m_srcPoly[0].Y + m_normals[0].Y * delta)); + m_destPoly.push_back(pt1); + } + else + { + k = 1; + m_sinA = 0; + if (node.m_endtype == etOpenSquare) + DoSquare(0, 1); + else + DoRound(0, 1); + } + m_destPolys.push_back(m_destPoly); + } + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::OffsetPoint(int j, int& k, JoinType jointype) +{ + //cross product ... + m_sinA = (m_normals[k].X * m_normals[j].Y - m_normals[j].X * m_normals[k].Y); + if (std::fabs(m_sinA * m_delta) < 1.0) + { + //dot product ... 
+ double cosA = (m_normals[k].X * m_normals[j].X + m_normals[j].Y * m_normals[k].Y ); + if (cosA > 0) // angle => 0 degrees + { + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta))); + return; + } + //else angle => 180 degrees + } + else if (m_sinA > 1.0) m_sinA = 1.0; + else if (m_sinA < -1.0) m_sinA = -1.0; + + if (m_sinA * m_delta < 0) + { + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta))); + m_destPoly.push_back(m_srcPoly[j]); + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[j].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta))); + } + else + switch (jointype) + { + case jtMiter: + { + double r = 1 + (m_normals[j].X * m_normals[k].X + + m_normals[j].Y * m_normals[k].Y); + if (r >= m_miterLim) DoMiter(j, k, r); else DoSquare(j, k); + break; + } + case jtSquare: DoSquare(j, k); break; + case jtRound: DoRound(j, k); break; + } + k = j; +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoSquare(int j, int k) +{ + double dx = std::tan(std::atan2(m_sinA, + m_normals[k].X * m_normals[j].X + m_normals[k].Y * m_normals[j].Y) / 4); + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + m_delta * (m_normals[k].X - m_normals[k].Y * dx)), + Round(m_srcPoly[j].Y + m_delta * (m_normals[k].Y + m_normals[k].X * dx)))); + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + m_delta * (m_normals[j].X + m_normals[j].Y * dx)), + Round(m_srcPoly[j].Y + m_delta * (m_normals[j].Y - m_normals[j].X * dx)))); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoMiter(int j, int k, double r) +{ + double q = m_delta / r; + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + (m_normals[k].X + m_normals[j].X) * q), + Round(m_srcPoly[j].Y + (m_normals[k].Y + m_normals[j].Y) * q))); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoRound(int j, int k) +{ + double a = std::atan2(m_sinA, + m_normals[k].X * m_normals[j].X + m_normals[k].Y * m_normals[j].Y); + int steps = std::max((int)Round(m_StepsPerRad * std::fabs(a)), 1); + + double X = m_normals[k].X, Y = m_normals[k].Y, X2; + for (int i = 0; i < steps; ++i) + { + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + X * m_delta), + Round(m_srcPoly[j].Y + Y * m_delta))); + X2 = X; + X = X * m_cos - m_sin * Y; + Y = X2 * m_sin + Y * m_cos; + } + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + m_normals[j].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta))); +} + +//------------------------------------------------------------------------------ +// Miscellaneous public functions +//------------------------------------------------------------------------------ + +void Clipper::DoSimplePolygons() +{ + PolyOutList::size_type i = 0; + while (i < m_PolyOuts.size()) + { + OutRec* outrec = m_PolyOuts[i++]; + OutPt* op = outrec->Pts; + if (!op || outrec->IsOpen) continue; + do //for each Pt in Polygon until duplicate found do ... + { + OutPt* op2 = op->Next; + while (op2 != outrec->Pts) + { + if ((op->Pt == op2->Pt) && op2->Next != op && op2->Prev != op) + { + //split the polygon into two ... 
+          OutPt* op3 = op->Prev;
+          OutPt* op4 = op2->Prev;
+          op->Prev = op4;
+          op4->Next = op;
+          op2->Prev = op3;
+          op3->Next = op2;
+
+          outrec->Pts = op;
+          OutRec* outrec2 = CreateOutRec();
+          outrec2->Pts = op2;
+          UpdateOutPtIdxs(*outrec2);
+          if (Poly2ContainsPoly1(outrec2->Pts, outrec->Pts))
+          {
+            //OutRec2 is contained by OutRec1 ...
+            outrec2->IsHole = !outrec->IsHole;
+            outrec2->FirstLeft = outrec;
+            if (m_UsingPolyTree) FixupFirstLefts2(outrec2, outrec);
+          }
+          else
+            if (Poly2ContainsPoly1(outrec->Pts, outrec2->Pts))
+          {
+            //OutRec1 is contained by OutRec2 ...
+            outrec2->IsHole = outrec->IsHole;
+            outrec->IsHole = !outrec2->IsHole;
+            outrec2->FirstLeft = outrec->FirstLeft;
+            outrec->FirstLeft = outrec2;
+            if (m_UsingPolyTree) FixupFirstLefts2(outrec, outrec2);
+          }
+          else
+          {
+            //the 2 polygons are separate ...
+            outrec2->IsHole = outrec->IsHole;
+            outrec2->FirstLeft = outrec->FirstLeft;
+            if (m_UsingPolyTree) FixupFirstLefts1(outrec, outrec2);
+          }
+          op2 = op; //ie get ready for the Next iteration
+        }
+        op2 = op2->Next;
+      }
+      op = op->Next;
+    }
+    while (op != outrec->Pts);
+  }
+}
+//------------------------------------------------------------------------------
+
+void ReversePath(Path& p)
+{
+  std::reverse(p.begin(), p.end());
+}
+//------------------------------------------------------------------------------
+
+void ReversePaths(Paths& p)
+{
+  for (Paths::size_type i = 0; i < p.size(); ++i)
+    ReversePath(p[i]);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType)
+{
+  Clipper c;
+  c.StrictlySimple(true);
+  c.AddPath(in_poly, ptSubject, true);
+  c.Execute(ctUnion, out_polys, fillType, fillType);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType)
+{
+  Clipper c;
+  c.StrictlySimple(true);
+  c.AddPaths(in_polys, ptSubject, true);
+  c.Execute(ctUnion, out_polys, fillType, fillType);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygons(Paths &polys, PolyFillType fillType)
+{
+  SimplifyPolygons(polys, polys, fillType);
+}
+//------------------------------------------------------------------------------
+
+inline double DistanceSqrd(const IntPoint& pt1, const IntPoint& pt2)
+{
+  double Dx = ((double)pt1.X - pt2.X);
+  double dy = ((double)pt1.Y - pt2.Y);
+  return (Dx*Dx + dy*dy);
+}
+//------------------------------------------------------------------------------
+
+double DistanceFromLineSqrd(
+  const IntPoint& pt, const IntPoint& ln1, const IntPoint& ln2)
+{
+  //The equation of a line in general form (Ax + By + C = 0)
+  //given 2 points (x1,y1) & (x2,y2) is ...
+  //(y1 - y2)x + (x2 - x1)y + (y1 - y2)x1 - (x2 - x1)y1 = 0
+  //A = (y1 - y2); B = (x2 - x1); C = (y1 - y2)x1 - (x2 - x1)y1
+  //perpendicular distance of point (x3,y3) = (Ax3 + By3 + C)/Sqrt(A^2 + B^2)
+  //see http://en.wikipedia.org/wiki/Perpendicular_distance
+  double A = double(ln1.Y - ln2.Y);
+  double B = double(ln2.X - ln1.X);
+  double C = A * ln1.X + B * ln1.Y;
+  C = A * pt.X + B * pt.Y - C;
+  return (C * C) / (A * A + B * B);
+}
+//---------------------------------------------------------------------------
+
+bool SlopesNearCollinear(const IntPoint& pt1,
+    const IntPoint& pt2, const IntPoint& pt3, double distSqrd)
+{
+  //this function is more accurate when the point that's geometrically
+  //between the other 2 points is the one that's tested for distance.
+  //ie makes it more likely to pick up 'spikes' ...
+  if (Abs(pt1.X - pt2.X) > Abs(pt1.Y - pt2.Y))
+  {
+    if ((pt1.X > pt2.X) == (pt1.X < pt3.X))
+      return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd;
+    else if ((pt2.X > pt1.X) == (pt2.X < pt3.X))
+      return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd;
+    else
+      return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd;
+  }
+  else
+  {
+    if ((pt1.Y > pt2.Y) == (pt1.Y < pt3.Y))
+      return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd;
+    else if ((pt2.Y > pt1.Y) == (pt2.Y < pt3.Y))
+      return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd;
+    else
+      return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd;
+  }
+}
+//------------------------------------------------------------------------------
+
+bool PointsAreClose(IntPoint pt1, IntPoint pt2, double distSqrd)
+{
+  double Dx = (double)pt1.X - pt2.X;
+  double dy = (double)pt1.Y - pt2.Y;
+  return ((Dx * Dx) + (dy * dy) <= distSqrd);
+}
+//------------------------------------------------------------------------------
+
+OutPt* ExcludeOp(OutPt* op)
+{
+  OutPt* result = op->Prev;
+  result->Next = op->Next;
+  op->Next->Prev = result;
+  result->Idx = 0;
+  return result;
+}
+//------------------------------------------------------------------------------
+
+void CleanPolygon(const Path& in_poly, Path& out_poly, double distance)
+{
+  //distance = proximity in units/pixels below which vertices
+  //will be stripped. Default ~= sqrt(2).
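Editorial note, not part of the vendored patch: a minimal sketch of how the cleanup helpers defined in this section are typically called. The coordinates and the wrapping function are illustrative assumptions, not code from this library.

    // assumes: #include "clipper.hpp" and using namespace ClipperLib;
    Paths CleanAndSimplify()
    {
      Path rough, cleaned;
      rough << IntPoint(0, 0) << IntPoint(1, 0)      // (1,0) falls within the default 1.415 clean-up distance of (0,0)
            << IntPoint(100, 0) << IntPoint(100, 100) << IntPoint(0, 100);
      CleanPolygon(rough, cleaned);                  // strips near-duplicate and near-collinear vertices
      Paths simple;
      SimplifyPolygon(cleaned, simple, pftEvenOdd);  // splits any self-intersections into simple polygons
      return simple;
    }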
+ + size_t size = in_poly.size(); + + if (size == 0) + { + out_poly.clear(); + return; + } + + OutPt* outPts = new OutPt[size]; + for (size_t i = 0; i < size; ++i) + { + outPts[i].Pt = in_poly[i]; + outPts[i].Next = &outPts[(i + 1) % size]; + outPts[i].Next->Prev = &outPts[i]; + outPts[i].Idx = 0; + } + + double distSqrd = distance * distance; + OutPt* op = &outPts[0]; + while (op->Idx == 0 && op->Next != op->Prev) + { + if (PointsAreClose(op->Pt, op->Prev->Pt, distSqrd)) + { + op = ExcludeOp(op); + size--; + } + else if (PointsAreClose(op->Prev->Pt, op->Next->Pt, distSqrd)) + { + ExcludeOp(op->Next); + op = ExcludeOp(op); + size -= 2; + } + else if (SlopesNearCollinear(op->Prev->Pt, op->Pt, op->Next->Pt, distSqrd)) + { + op = ExcludeOp(op); + size--; + } + else + { + op->Idx = 1; + op = op->Next; + } + } + + if (size < 3) size = 0; + out_poly.resize(size); + for (size_t i = 0; i < size; ++i) + { + out_poly[i] = op->Pt; + op = op->Next; + } + delete [] outPts; +} +//------------------------------------------------------------------------------ + +void CleanPolygon(Path& poly, double distance) +{ + CleanPolygon(poly, poly, distance); +} +//------------------------------------------------------------------------------ + +void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance) +{ + out_polys.resize(in_polys.size()); + for (Paths::size_type i = 0; i < in_polys.size(); ++i) + CleanPolygon(in_polys[i], out_polys[i], distance); +} +//------------------------------------------------------------------------------ + +void CleanPolygons(Paths& polys, double distance) +{ + CleanPolygons(polys, polys, distance); +} +//------------------------------------------------------------------------------ + +void Minkowski(const Path& poly, const Path& path, + Paths& solution, bool isSum, bool isClosed) +{ + int delta = (isClosed ? 
1 : 0); + size_t polyCnt = poly.size(); + size_t pathCnt = path.size(); + Paths pp; + pp.reserve(pathCnt); + if (isSum) + for (size_t i = 0; i < pathCnt; ++i) + { + Path p; + p.reserve(polyCnt); + for (size_t j = 0; j < poly.size(); ++j) + p.push_back(IntPoint(path[i].X + poly[j].X, path[i].Y + poly[j].Y)); + pp.push_back(p); + } + else + for (size_t i = 0; i < pathCnt; ++i) + { + Path p; + p.reserve(polyCnt); + for (size_t j = 0; j < poly.size(); ++j) + p.push_back(IntPoint(path[i].X - poly[j].X, path[i].Y - poly[j].Y)); + pp.push_back(p); + } + + solution.clear(); + solution.reserve((pathCnt + delta) * (polyCnt + 1)); + for (size_t i = 0; i < pathCnt - 1 + delta; ++i) + for (size_t j = 0; j < polyCnt; ++j) + { + Path quad; + quad.reserve(4); + quad.push_back(pp[i % pathCnt][j % polyCnt]); + quad.push_back(pp[(i + 1) % pathCnt][j % polyCnt]); + quad.push_back(pp[(i + 1) % pathCnt][(j + 1) % polyCnt]); + quad.push_back(pp[i % pathCnt][(j + 1) % polyCnt]); + if (!Orientation(quad)) ReversePath(quad); + solution.push_back(quad); + } +} +//------------------------------------------------------------------------------ + +void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed) +{ + Minkowski(pattern, path, solution, true, pathIsClosed); + Clipper c; + c.AddPaths(solution, ptSubject, true); + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +void TranslatePath(const Path& input, Path& output, const IntPoint delta) +{ + //precondition: input != output + output.resize(input.size()); + for (size_t i = 0; i < input.size(); ++i) + output[i] = IntPoint(input[i].X + delta.X, input[i].Y + delta.Y); +} +//------------------------------------------------------------------------------ + +void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed) +{ + Clipper c; + for (size_t i = 0; i < paths.size(); ++i) + { + Paths tmp; + Minkowski(pattern, paths[i], tmp, true, pathIsClosed); + c.AddPaths(tmp, ptSubject, true); + if (pathIsClosed) + { + Path tmp2; + TranslatePath(paths[i], tmp2, pattern[0]); + c.AddPath(tmp2, ptClip, true); + } + } + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution) +{ + Minkowski(poly1, poly2, solution, false, true); + Clipper c; + c.AddPaths(solution, ptSubject, true); + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +enum NodeType {ntAny, ntOpen, ntClosed}; + +void AddPolyNodeToPaths(const PolyNode& polynode, NodeType nodetype, Paths& paths) +{ + bool match = true; + if (nodetype == ntClosed) match = !polynode.IsOpen(); + else if (nodetype == ntOpen) return; + + if (!polynode.Contour.empty() && match) + paths.push_back(polynode.Contour); + for (int i = 0; i < polynode.ChildCount(); ++i) + AddPolyNodeToPaths(*polynode.Childs[i], nodetype, paths); +} +//------------------------------------------------------------------------------ + +void PolyTreeToPaths(const PolyTree& polytree, Paths& paths) +{ + paths.resize(0); + paths.reserve(polytree.Total()); + AddPolyNodeToPaths(polytree, ntAny, paths); +} +//------------------------------------------------------------------------------ + +void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths) +{ + paths.resize(0); 
+ paths.reserve(polytree.Total()); + AddPolyNodeToPaths(polytree, ntClosed, paths); +} +//------------------------------------------------------------------------------ + +void OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths) +{ + paths.resize(0); + paths.reserve(polytree.Total()); + //Open paths are top level only, so ... + for (int i = 0; i < polytree.ChildCount(); ++i) + if (polytree.Childs[i]->IsOpen()) + paths.push_back(polytree.Childs[i]->Contour); +} +//------------------------------------------------------------------------------ + +std::ostream& operator <<(std::ostream &s, const IntPoint &p) +{ + s << "(" << p.X << "," << p.Y << ")"; + return s; +} +//------------------------------------------------------------------------------ + +std::ostream& operator <<(std::ostream &s, const Path &p) +{ + if (p.empty()) return s; + Path::size_type last = p.size() -1; + for (Path::size_type i = 0; i < last; i++) + s << "(" << p[i].X << "," << p[i].Y << "), "; + s << "(" << p[last].X << "," << p[last].Y << ")\n"; + return s; +} +//------------------------------------------------------------------------------ + +std::ostream& operator <<(std::ostream &s, const Paths &p) +{ + for (Paths::size_type i = 0; i < p.size(); i++) + s << p[i]; + s << "\n"; + return s; +} +//------------------------------------------------------------------------------ + +} //ClipperLib namespace diff --git a/nemo-retriever-ocr/cpp/third_party/clipper/clipper.hpp b/nemo-retriever-ocr/cpp/third_party/clipper/clipper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9798ba57870659dd468e8d51e78f97d6e2f8b43e --- /dev/null +++ b/nemo-retriever-ocr/cpp/third_party/clipper/clipper.hpp @@ -0,0 +1,404 @@ +/******************************************************************************* +* * +* Author : Angus Johnson * +* Version : 6.4.0 * +* Date : 2 July 2015 * +* Website : http://www.angusj.com * +* Copyright : Angus Johnson 2010-2015 * +* * +* License: * +* Use, modification & distribution is subject to Boost Software License Ver 1. * +* http://www.boost.org/LICENSE_1_0.txt * +* * +* Attributions: * +* The code in this library is an extension of Bala Vatti's clipping algorithm: * +* "A generic solution to polygon clipping" * +* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. * +* http://portal.acm.org/citation.cfm?id=129906 * +* * +* Computer graphics and geometric modeling: implementation and algorithms * +* By Max K. Agoston * +* Springer; 1 edition (January 4, 2005) * +* http://books.google.com/books?q=vatti+clipping+agoston * +* * +* See also: * +* "Polygon Offsetting by Computing Winding Numbers" * +* Paper no. DETC2005-85513 pp. 565-575 * +* ASME 2005 International Design Engineering Technical Conferences * +* and Computers and Information in Engineering Conference (IDETC/CIE2005) * +* September 24-28, 2005 , Long Beach, California, USA * +* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf * +* * +*******************************************************************************/ + +#ifndef clipper_hpp +#define clipper_hpp + +#define CLIPPER_VERSION "6.2.6" + +//use_int32: When enabled 32bit ints are used instead of 64bit ints. This +//improve performance but coordinate values are limited to the range +/- 46340 +//#define use_int32 + +//use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance. +//#define use_xyz + +//use_lines: Enables line clipping. Adds a very minor cost to performance. 
+#define use_lines + +//use_deprecated: Enables temporary support for the obsolete functions +//#define use_deprecated + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ClipperLib { + +enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor }; +enum PolyType { ptSubject, ptClip }; +//By far the most widely used winding rules for polygon filling are +//EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32) +//Others rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL) +//see http://glprogramming.com/red/chapter11.html +enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative }; + +#ifdef use_int32 + // typedef int cInt; + typedef float cInt; + static cInt const loRange = 0x7FFF; + static cInt const hiRange = 0x7FFF; +#else + // typedef signed long long cInt; + typedef double cInt; + static cInt const loRange = 0x3FFFFFFF; + static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL; + typedef signed long long long64; //used by Int128 class + typedef unsigned long long ulong64; + +#endif + +struct IntPoint { + cInt X; + cInt Y; +#ifdef use_xyz + cInt Z; + IntPoint(cInt x = 0, cInt y = 0, cInt z = 0): X(x), Y(y), Z(z) {}; +#else + IntPoint(cInt x = 0, cInt y = 0): X(x), Y(y) {}; +#endif + + friend inline bool operator== (const IntPoint& a, const IntPoint& b) + { + return a.X == b.X && a.Y == b.Y; + } + friend inline bool operator!= (const IntPoint& a, const IntPoint& b) + { + return a.X != b.X || a.Y != b.Y; + } +}; +//------------------------------------------------------------------------------ + +typedef std::vector< IntPoint > Path; +typedef std::vector< Path > Paths; + +inline Path& operator <<(Path& poly, const IntPoint& p) {poly.push_back(p); return poly;} +inline Paths& operator <<(Paths& polys, const Path& p) {polys.push_back(p); return polys;} + +std::ostream& operator <<(std::ostream &s, const IntPoint &p); +std::ostream& operator <<(std::ostream &s, const Path &p); +std::ostream& operator <<(std::ostream &s, const Paths &p); + +struct DoublePoint +{ + double X; + double Y; + DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {} + DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {} +}; +//------------------------------------------------------------------------------ + +#ifdef use_xyz +typedef void (*ZFillCallback)(IntPoint& e1bot, IntPoint& e1top, IntPoint& e2bot, IntPoint& e2top, IntPoint& pt); +#endif + +enum InitOptions {ioReverseSolution = 1, ioStrictlySimple = 2, ioPreserveCollinear = 4}; +enum JoinType {jtSquare, jtRound, jtMiter}; +enum EndType {etClosedPolygon, etClosedLine, etOpenButt, etOpenSquare, etOpenRound}; + +class PolyNode; +typedef std::vector< PolyNode* > PolyNodes; + +class PolyNode +{ +public: + PolyNode(); + virtual ~PolyNode(){}; + Path Contour; + PolyNodes Childs; + PolyNode* Parent; + PolyNode* GetNext() const; + bool IsHole() const; + bool IsOpen() const; + int ChildCount() const; +private: + unsigned Index; //node index in Parent.Childs + bool m_IsOpen; + JoinType m_jointype; + EndType m_endtype; + PolyNode* GetNextSiblingUp() const; + void AddChild(PolyNode& child); + friend class Clipper; //to access Index + friend class ClipperOffset; +}; + +class PolyTree: public PolyNode +{ +public: + ~PolyTree(){Clear();}; + PolyNode* GetFirst() const; + void Clear(); + int Total() const; +private: + PolyNodes AllNodes; + friend class Clipper; //to access AllNodes +}; + +bool Orientation(const Path &poly); +double Area(const Path &poly); +int 
PointInPolygon(const IntPoint &pt, const Path &path); + +void SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType = pftEvenOdd); +void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType = pftEvenOdd); +void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd); + +void CleanPolygon(const Path& in_poly, Path& out_poly, double distance = 1.415); +void CleanPolygon(Path& poly, double distance = 1.415); +void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance = 1.415); +void CleanPolygons(Paths& polys, double distance = 1.415); + +void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed); +void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed); +void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution); + +void PolyTreeToPaths(const PolyTree& polytree, Paths& paths); +void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths); +void OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths); + +void ReversePath(Path& p); +void ReversePaths(Paths& p); + +struct IntRect { cInt left; cInt top; cInt right; cInt bottom; }; + +//enums that are used internally ... +enum EdgeSide { esLeft = 1, esRight = 2}; + +//forward declarations (for stuff used internally) ... +struct TEdge; +struct IntersectNode; +struct LocalMinimum; +struct OutPt; +struct OutRec; +struct Join; + +typedef std::vector < OutRec* > PolyOutList; +typedef std::vector < TEdge* > EdgeList; +typedef std::vector < Join* > JoinList; +typedef std::vector < IntersectNode* > IntersectList; + +//------------------------------------------------------------------------------ + +//ClipperBase is the ancestor to the Clipper class. It should not be +//instantiated directly. This class simply abstracts the conversion of sets of +//polygon coordinates into edge objects that are stored in a LocalMinima list. 
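For orientation, the typical way to drive this API is the same pattern the MinkowskiSum helpers earlier in this diff use: instantiate `Clipper` (not `ClipperBase`), add subject and clip paths, then call `Execute`. The snippet below is only an illustrative sketch against the declarations in this header; the rectangle coordinates and the choice of `pftNonZero` are arbitrary assumptions, not part of the vendored sources.

```c++
#include "clipper.hpp"
using namespace ClipperLib;

// Minimal sketch: intersect two overlapping axis-aligned rectangles.
Paths ComputeOverlap()
{
    Paths subject(1), clip(1), solution;
    subject[0] << IntPoint(0, 0) << IntPoint(100, 0)
               << IntPoint(100, 100) << IntPoint(0, 100);
    clip[0] << IntPoint(50, 50) << IntPoint(150, 50)
            << IntPoint(150, 150) << IntPoint(50, 150);

    Clipper c;
    c.AddPaths(subject, ptSubject, true);  // closed subject polygon
    c.AddPaths(clip, ptClip, true);        // closed clip polygon
    c.Execute(ctIntersection, solution, pftNonZero, pftNonZero);
    return solution;  // for these inputs: a single 50x50 quad
}
```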
+class ClipperBase +{ +public: + ClipperBase(); + virtual ~ClipperBase(); + virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed); + bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed); + virtual void Clear(); + IntRect GetBounds(); + bool PreserveCollinear() {return m_PreserveCollinear;}; + void PreserveCollinear(bool value) {m_PreserveCollinear = value;}; +protected: + void DisposeLocalMinimaList(); + TEdge* AddBoundsToLML(TEdge *e, bool IsClosed); + virtual void Reset(); + TEdge* ProcessBound(TEdge* E, bool IsClockwise); + void InsertScanbeam(const cInt Y); + bool PopScanbeam(cInt &Y); + bool LocalMinimaPending(); + bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin); + OutRec* CreateOutRec(); + void DisposeAllOutRecs(); + void DisposeOutRec(PolyOutList::size_type index); + void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2); + void DeleteFromAEL(TEdge *e); + void UpdateEdgeIntoAEL(TEdge *&e); + + typedef std::vector MinimaList; + MinimaList::iterator m_CurrentLM; + MinimaList m_MinimaList; + + bool m_UseFullRange; + EdgeList m_edges; + bool m_PreserveCollinear; + bool m_HasOpenPaths; + PolyOutList m_PolyOuts; + TEdge *m_ActiveEdges; + + typedef std::priority_queue ScanbeamList; + ScanbeamList m_Scanbeam; +}; +//------------------------------------------------------------------------------ + +class Clipper : public virtual ClipperBase +{ +public: + Clipper(int initOptions = 0); + bool Execute(ClipType clipType, + Paths &solution, + PolyFillType fillType = pftEvenOdd); + bool Execute(ClipType clipType, + Paths &solution, + PolyFillType subjFillType, + PolyFillType clipFillType); + bool Execute(ClipType clipType, + PolyTree &polytree, + PolyFillType fillType = pftEvenOdd); + bool Execute(ClipType clipType, + PolyTree &polytree, + PolyFillType subjFillType, + PolyFillType clipFillType); + bool ReverseSolution() { return m_ReverseOutput; }; + void ReverseSolution(bool value) {m_ReverseOutput = value;}; + bool StrictlySimple() {return m_StrictSimple;}; + void StrictlySimple(bool value) {m_StrictSimple = value;}; + //set the callback function for z value filling on intersections (otherwise Z is 0) +#ifdef use_xyz + void ZFillFunction(ZFillCallback zFillFunc); +#endif +protected: + virtual bool ExecuteInternal(); +private: + JoinList m_Joins; + JoinList m_GhostJoins; + IntersectList m_IntersectList; + ClipType m_ClipType; + typedef std::list MaximaList; + MaximaList m_Maxima; + TEdge *m_SortedEdges; + bool m_ExecuteLocked; + PolyFillType m_ClipFillType; + PolyFillType m_SubjFillType; + bool m_ReverseOutput; + bool m_UsingPolyTree; + bool m_StrictSimple; +#ifdef use_xyz + ZFillCallback m_ZFill; //custom callback +#endif + void SetWindingCount(TEdge& edge); + bool IsEvenOddFillType(const TEdge& edge) const; + bool IsEvenOddAltFillType(const TEdge& edge) const; + void InsertLocalMinimaIntoAEL(const cInt botY); + void InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge); + void AddEdgeToSEL(TEdge *edge); + bool PopEdgeFromSEL(TEdge *&edge); + void CopyAELToSEL(); + void DeleteFromSEL(TEdge *e); + void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2); + bool IsContributing(const TEdge& edge) const; + bool IsTopHorz(const cInt XPos); + void DoMaxima(TEdge *e); + void ProcessHorizontals(); + void ProcessHorizontal(TEdge *horzEdge); + void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt); + OutPt* AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt); + OutRec* GetOutRec(int idx); + void AppendPolygon(TEdge *e1, TEdge *e2); + void IntersectEdges(TEdge *e1, TEdge 
*e2, IntPoint &pt); + OutPt* AddOutPt(TEdge *e, const IntPoint &pt); + OutPt* GetLastOutPt(TEdge *e); + bool ProcessIntersections(const cInt topY); + void BuildIntersectList(const cInt topY); + void ProcessIntersectList(); + void ProcessEdgesAtTopOfScanbeam(const cInt topY); + void BuildResult(Paths& polys); + void BuildResult2(PolyTree& polytree); + void SetHoleState(TEdge *e, OutRec *outrec); + void DisposeIntersectNodes(); + bool FixupIntersectionOrder(); + void FixupOutPolygon(OutRec &outrec); + void FixupOutPolyline(OutRec &outrec); + bool IsHole(TEdge *e); + bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl); + void FixHoleLinkage(OutRec &outrec); + void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt); + void ClearJoins(); + void ClearGhostJoins(); + void AddGhostJoin(OutPt *op, const IntPoint offPt); + bool JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2); + void JoinCommonEdges(); + void DoSimplePolygons(); + void FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec); + void FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec); + void FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec); +#ifdef use_xyz + void SetZ(IntPoint& pt, TEdge& e1, TEdge& e2); +#endif +}; +//------------------------------------------------------------------------------ + +class ClipperOffset +{ +public: + ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25); + ~ClipperOffset(); + void AddPath(const Path& path, JoinType joinType, EndType endType); + void AddPaths(const Paths& paths, JoinType joinType, EndType endType); + void Execute(Paths& solution, double delta); + void Execute(PolyTree& solution, double delta); + void Clear(); + double MiterLimit; + double ArcTolerance; +private: + Paths m_destPolys; + Path m_srcPoly; + Path m_destPoly; + std::vector m_normals; + double m_delta, m_sinA, m_sin, m_cos; + double m_miterLim, m_StepsPerRad; + IntPoint m_lowest; + PolyNode m_polyNodes; + + void FixOrientations(); + void DoOffset(double delta); + void OffsetPoint(int j, int& k, JoinType jointype); + void DoSquare(int j, int k); + void DoMiter(int j, int k, double r); + void DoRound(int j, int k); +}; +//------------------------------------------------------------------------------ + +class clipperException : public std::exception +{ + public: + clipperException(const char* description): m_descr(description) {} + virtual ~clipperException() throw() {} + virtual const char* what() const throw() {return m_descr.c_str();} + private: + std::string m_descr; +}; +//------------------------------------------------------------------------------ + +} //ClipperLib namespace + +#endif //clipper_hpp diff --git a/nemo-retriever-ocr/cpp/trove/LICENSE b/nemo-retriever-ocr/cpp/trove/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d2d70accade4b814d6a54d4bc9b9001ddf6ee8c5 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/nemo-retriever-ocr/cpp/trove/README.md b/nemo-retriever-ocr/cpp/trove/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fa50b1e6e4c79b1d4819a7b5f81dbcbfc59108cc --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/README.md @@ -0,0 +1,138 @@ +Trove +===== + +Trove is a CUDA library that provides efficient vector loads and +stores. It works for CUDA architectures 3.0 and above, and uses no +CUDA shared memory, making it easy to integrate. It is useful for +working with data in an Array of Structures format, and also when +writing code that consumes or produces an array of data per CUDA +thread. + +How it Works +============ + +This functionality is built out of a transposition routine that uses +the warp shuffle intrinsic to redistribute data amongst threads in the +CUDA warp. For example, when every thread in +the warp is loading contiguous structures from an array, the threads +collaboratively load all the data needed by the warp, using coalesced +memory accesses, then transpose the data to redistribute it to the +correct thread. + +The following cartoon illustrates how this works, for a notional warp +with eight threads. + +![Transpose](https://raw.github.com/BryanCatanzaro/trove/master/doc/transpose.png) + +Performance +=========== + +Accesses to arrays of structures can be 6X faster than direct memory +accesses using compiler generated loads and stores. The following +benchmarks were taken on a Tesla K20c. The structure `T` being loaded +and stored in these benchmarks is made of one to sixteen 32-bit integers. + +![Contiguous](https://raw.github.com/BryanCatanzaro/trove/master/doc/contiguous.png) +![Random](https://raw.github.com/BryanCatanzaro/trove/master/doc/random.png) + +High-level Interface +==================== + +```c++ +#include <trove/ptr.h> + +template<typename T> +__global__ void +trove_gather(const int length, const int* indices, + trove::coalesced_ptr<T> src, //Wrapped pointer + trove::coalesced_ptr<T> dst) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + if (global_index < length) { + int index = indices[global_index]; + T data = src[index]; + dst[global_index] = data; + } +} +``` + +The high-level interface allows you to load and store structures +directly. Just wrap the pointers of your arrays in +`trove::coalesced_ptr`. You don't need to worry if the warp is +converged, or if the addresses you're accessing are contiguous. This +interface loses some performance, since it has to dynamically check +whether the warp is converged, and also broadcast all pointers from +all threads in each warp to all other threads in the warp, but it is +simple to use.
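A concrete way to use this, mirroring the `benchmark_shfl_gather` kernel in `tests/benchmark.cu` later in this diff, is to pass raw device pointers into the kernel and do the wrapping there. The sketch below is illustrative only: the POD struct `item` and the launch shape are assumptions, not part of the library.

```c++
#include <trove/ptr.h>

// Hypothetical payload: 16 bytes (a multiple of 4) and default-constructible,
// so it meets the two restrictions on T listed just below.
struct item { int a, b, c, d; };

__global__ void gather_items(const int* indices, item* raw_src,
                             item* raw_dst, int length) {
    int gid = threadIdx.x + blockDim.x * blockIdx.x;
    if (gid < length) {
        // Wrap the raw pointers; indexing through the wrappers routes the
        // accesses through trove's warp-cooperative loads and stores.
        trove::coalesced_ptr<item> src(raw_src);
        trove::coalesced_ptr<item> dst(raw_dst);
        item data = src[indices[gid]];  // gathered read
        dst[gid] = data;                // write at the thread's own slot
    }
}
// Host side (illustrative): gather_items<<<num_blocks, 256>>>(d_idx, d_src, d_dst, n);
```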
+ + +There are two restrictions on `T`: + + * `sizeof(T)` must be divisible by 4 + * `T` must have a default constructor + +Block Interface +=============== + +It's common for CUDA code to process or produce several values per thread. For +example, a merge operation may process 7 values per thread, to increase the +amount of serial work. For these cases, we provide a blocked interface +that enables efficient block-wise vector loads and stores. + +This interface relies on an array type `trove::array<T, s>`, where `T` +is the type of each element of the array, and `s` is an integer that +statically determines the length of the array. `trove::array` types +can be converted to and from standard C arrays (see +[array.h](http://github.com/BryanCatanzaro/trove/blob/master/trove/array.h) +), but they have value +semantics rather than reference semantics, and they can only be +indexed statically. + +With this interface, each thread is assumed to be reading or writing +from contiguous locations in the input or output array. The user is +responsible for checking for convergence, which they probably do +anyway for functional reasons. If the warp is not converged, we +provide fallback functions that load and store arrays using compiler +generated code. + +```c++ +#include <trove/block.h> + +template<typename T, int s> +__global__ void test_block_copy(const T* x, T* r, int l) { + typedef trove::array<T, s> s_ary; + int global_index = threadIdx.x + blockIdx.x * blockDim.x; + + for(int index = global_index; index < l; index += gridDim.x * blockDim.x) { + + //The block memory accesses only function + //correctly if the warp is converged. Here we check. + if (trove::warp_converged()) { + //Warp converged, indices are contiguous, call the fast + //load and store + s_ary d = trove::load_array_warp_contiguous(x, index); + trove::store_array_warp_contiguous(r, index, d); + } else { + //Warp not converged, call the slow load and store + s_ary d = trove::load_array(x, index); + trove::store_array(r, index, d); + } + } +} +``` + +Low-level Interface +=================== + +If you know your warp is converged, and that your threads are all +accessing contiguous locations in your array, you can use the +low-level interface for maximum performance. By contiguous, we mean +that if threads with indices *i* and *j* are in the same warp, +the pointers *pi* and *pj* you pass to the library must obey the
The low-level interface has the +following functions in ``: + +`template __device__ T load_warp_contiguous(const T* +src);` + +`template __device__ void store_warp_contiguous(const T& +data, T* dest);` diff --git a/nemo-retriever-ocr/cpp/trove/doc/contiguous.png b/nemo-retriever-ocr/cpp/trove/doc/contiguous.png new file mode 100644 index 0000000000000000000000000000000000000000..c91e23078027bc406c3e6cc48f3204dbb77fef61 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/doc/contiguous.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d9649776fec8fd812468d82d1a733c87fce86b302b9a52d5e70ba4ed5a1f76c +size 14320 diff --git a/nemo-retriever-ocr/cpp/trove/doc/random.png b/nemo-retriever-ocr/cpp/trove/doc/random.png new file mode 100644 index 0000000000000000000000000000000000000000..6d8a06849b643be4b6bac655be0cc2003796642a --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/doc/random.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:333c776116f1f580d76bce114e31e6b0293008973a8935968b5bf9636d5d6d25 +size 12801 diff --git a/nemo-retriever-ocr/cpp/trove/doc/transpose.png b/nemo-retriever-ocr/cpp/trove/doc/transpose.png new file mode 100644 index 0000000000000000000000000000000000000000..2ea07b14c5ca7ebfb241fb0c210ed9dbe96b45e4 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/doc/transpose.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0af509a49d5e13153409c5302146bebb27a9b2680e7dd06a817ec66f90128adc +size 69859 diff --git a/nemo-retriever-ocr/cpp/trove/tests/Makefile b/nemo-retriever-ocr/cpp/trove/tests/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..103a3626064e119b8a7bfa206386c9a76bbe52ec --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/tests/Makefile @@ -0,0 +1,14 @@ +ARCH ?= sm_70 +LOWER_BOUND ?= 1 +UPPER_BOUND ?= 16 + +CUDA_HOME ?= /usr/local/cuda +NVCC = $(CUDA_HOME)/bin/nvcc + +all: benchmark block + +benchmark: benchmark.cu + $(NVCC) -arch=$(ARCH) -I../ -Xptxas -v benchmark.cu -o benchmark -DLOWER_BOUND=$(LOWER_BOUND) -DUPPER_BOUND=$(UPPER_BOUND) + +block: block.cu + $(NVCC) -arch=$(ARCH) -I../ -Xptxas -v block.cu -o block diff --git a/nemo-retriever-ocr/cpp/trove/tests/benchmark.cu b/nemo-retriever-ocr/cpp/trove/tests/benchmark.cu new file mode 100644 index 0000000000000000000000000000000000000000..24acecded08da376282201bc3beeb47d5b067685 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/tests/benchmark.cu @@ -0,0 +1,366 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + + +#include +#include +#include +#include +#include +#include + +#include +#include + + + +#include +#include +#include "timer.h" + +using namespace trove; + +template +__global__ void +benchmark_contiguous_shfl_store(T* r) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + T data; + int size = detail::aliased_size::value; + data = counting_array::impl( + global_index * size); + store_warp_contiguous(data, r + global_index); +} + +template +__global__ void +benchmark_contiguous_direct_store(T* r) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + T data; + int size = detail::aliased_size::value; + data = counting_array::impl( + global_index * size); + r[global_index] = data; +} + +template +__global__ void +benchmark_contiguous_shfl_load(T* s, typename T::value_type* r) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + T data = load_warp_contiguous(s + global_index); + r[global_index] = sum(data); +} + +template +__global__ void +benchmark_contiguous_direct_load(T* s, typename T::value_type* r) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + T data = s[global_index]; + r[global_index] = sum(data); +} + +template +__global__ void +benchmark_shfl_gather(const int* indices, T* raw_s, T* raw_r) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + int index = indices[global_index]; + trove::coalesced_ptr s(raw_s); + trove::coalesced_ptr r(raw_r); + T data = s[index]; + r[global_index] = data; +} + +template +__global__ void +benchmark_shfl_scatter(const int* indices, T* raw_s, T* raw_r) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + int index = indices[global_index]; + trove::coalesced_ptr s(raw_s); + trove::coalesced_ptr r(raw_r); + T data = s[global_index]; + r[index] = data; +} + +template +__global__ void +benchmark_direct_gather(const int* indices, T* s, T* r) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + int index = indices[global_index]; + T data = s[index]; + r[global_index] = data; +} + +template +__global__ void +benchmark_direct_scatter(const int* indices, T* s, T* r) { + int global_index = threadIdx.x + blockDim.x * blockIdx.x; + int index = indices[global_index]; + T data = s[global_index]; + r[index] = data; +} + +template +void run_benchmark_contiguous_store(const std::string name, void (*test)(array*), + void (*gold)(array*)) { + typedef array T; + + std::cout << name << ", " << i << ", "; + int n_blocks = 80 * 8 * 100; + int block_size = 256; + int n = n_blocks * block_size - 100; + thrust::device_vector r(n); + int iterations = 10; + cuda_timer timer; + timer.start(); + for(int j = 0; j < iterations; j++) { + test<<>>(thrust::raw_pointer_cast(r.data())); + } + float time = timer.stop(); + float gbs = (float)(sizeof(T) * (iterations * n_blocks * block_size)) / (time * 1000000); + std::cout << gbs << ", "; + bool correct = true; + if (test != gold) { + thrust::device_vector g(n); + gold<<>>(thrust::raw_pointer_cast(g.data())); + correct = thrust::equal(r.begin(), r.end(), g.begin()); + } + if 
(correct) + std::cout << "Results passed"; + else + std::cout << "INCORRECT"; + std::cout << std::endl; + +} + +template +struct run_benchmark_contiguous_shfl_store { + typedef array T; + static void impl() { + run_benchmark_contiguous_store("Contiguous SHFL Store", &benchmark_contiguous_shfl_store, + &benchmark_contiguous_direct_store); + } +}; + +template +struct run_benchmark_contiguous_direct_store { + typedef array T; + static void impl() { + run_benchmark_contiguous_store("Contiguous Direct Store", &benchmark_contiguous_direct_store, + &benchmark_contiguous_direct_store); + } +}; + + + +template +void fill_test(thrust::device_vector& d) { + thrust::device_ptr p = thrust::device_ptr((int*)thrust::raw_pointer_cast(d.data())); + thrust::counting_iterator c(0); + int s = d.size() * sizeof(T) / sizeof(int); + thrust::copy(c, c+s, p); +} + +template +void run_benchmark_contiguous_load(const std::string name, void (*test)(array*, int*), + void (*gold)(array*, int*)) { + typedef array T; + + std::cout << name << ", " << i << ", "; + int n_blocks = 80 * 8 * 100; + int block_size = 256; + int n = n_blocks * block_size; + thrust::device_vector s(n); + fill_test(s); + thrust::device_vector r(n); + int iterations = 10; + cuda_timer timer; + timer.start(); + for(int j = 0; j < iterations; j++) { + test<<>>(thrust::raw_pointer_cast(s.data()), thrust::raw_pointer_cast(r.data())); + } + float time = timer.stop(); + float gbs = (float)((sizeof(T) + sizeof(int)) * (iterations * n_blocks * block_size)) / (time * 1000000); + std::cout << gbs << ", "; + bool correct = true; + if (test != gold) { + thrust::device_vector g(n); + gold<<>>(thrust::raw_pointer_cast(s.data()), thrust::raw_pointer_cast(g.data())); + correct = thrust::equal(r.begin(), r.end(), g.begin()); + } + + if (correct) + std::cout << "Results passed"; + else + std::cout << "INCORRECT"; + std::cout << std::endl; + + +} + +template +struct run_benchmark_contiguous_shfl_load { + typedef array T; + static void impl() { + run_benchmark_contiguous_load("Contiguous SHFL Load", &benchmark_contiguous_shfl_load, &benchmark_contiguous_direct_load); + } +}; + +template +struct run_benchmark_contiguous_direct_load { + typedef array T; + static void impl() { + run_benchmark_contiguous_load("Contiguous Direct Load", &benchmark_contiguous_direct_load, &benchmark_contiguous_direct_load); + } +}; + +thrust::device_vector make_device_random(int s) { + thrust::host_vector h(s); + thrust::generate(h.begin(), h.end(), rand); + thrust::device_vector d = h; + return d; +} + +thrust::device_vector make_random_permutation(int s) { + thrust::device_vector keys = make_device_random(s); + thrust::counting_iterator c(0); + thrust::device_vector values(s); + thrust::copy(c, c+s, values.begin()); + thrust::sort_by_key(keys.begin(), keys.end(), values.begin()); + return values; +} + +template +void run_benchmark_random(const std::string name, const thrust::device_vector& permutation, + void (*test)(const int*, array*, array*), + void (*gold)(const int*, array*, array*)) { + typedef array T; + + std::cout << name << ", " << i << ", "; + int n_blocks = 80 * 8 * 100; + int block_size = 256; + int n = n_blocks * block_size; + thrust::device_vector s(n); + fill_test(s); + thrust::device_vector r(n); + int iterations = 10; + cuda_timer timer; + timer.start(); + for(int j = 0; j < iterations; j++) { + test<<>>( + thrust::raw_pointer_cast(permutation.data()), + thrust::raw_pointer_cast(s.data()), + thrust::raw_pointer_cast(r.data())); + } + float time = timer.stop(); + float gbs 
= (float)(sizeof(T) * (2 * iterations * n_blocks * block_size) + sizeof(int) * iterations * n_blocks * block_size) / (time * 1000000); + std::cout << gbs << ", "; + bool correct = true; + if (test != gold) { + thrust::device_vector g(n); + gold<<>>(thrust::raw_pointer_cast(permutation.data()), + thrust::raw_pointer_cast(s.data()), thrust::raw_pointer_cast(g.data())); + correct = thrust::equal(r.begin(), r.end(), g.begin()); + } + if (correct) + std::cout << "Results passed"; + else + std::cout << "INCORRECT"; + std::cout << std::endl; +} + +template +struct run_benchmark_shfl_gather { + typedef array T; + static void impl(const thrust::device_vector& permutation) { + run_benchmark_random("SHFL Gather", permutation, &benchmark_shfl_gather, &benchmark_direct_gather); + } +}; + +template +struct run_benchmark_direct_gather { + typedef array T; + static void impl(const thrust::device_vector& permutation) { + run_benchmark_random("Direct Gather", permutation, &benchmark_direct_gather, &benchmark_direct_gather); + } +}; + +template +struct run_benchmark_shfl_scatter { + typedef array T; + static void impl(const thrust::device_vector& permutation) { + run_benchmark_random("SHFL Scatter", permutation, &benchmark_shfl_scatter, &benchmark_direct_scatter); + } +}; +template +struct run_benchmark_direct_scatter { + typedef array T; + static void impl(const thrust::device_vector& permutation) { + run_benchmark_random("Direct Scatter", permutation, &benchmark_direct_scatter, &benchmark_direct_scatter); + } +}; + +template class F, typename Cons> +struct do_tests { + static void impl() { + F::impl(); + do_tests::impl(); + } + template + static void impl(const T& t) { + F::impl(t); + do_tests::impl(t); + } +}; + +template class F> +struct do_tests { + static void impl() {} + template + static void impl(const T& t) {} +}; + +#ifndef LOWER_BOUND +#define LOWER_BOUND 1 +#endif +#ifndef UPPER_BOUND +#define UPPER_BOUND 16 +#endif + +typedef static_range sizes; + +int main() { + do_tests::impl(); + do_tests::impl(); + do_tests::impl(); + do_tests::impl(); + int size = 80 * 8 * 100 * 256; + thrust::device_vector permutation = make_random_permutation(size); + do_tests::impl(permutation); + do_tests::impl(permutation); + do_tests::impl(permutation); + do_tests::impl(permutation); + +} diff --git a/nemo-retriever-ocr/cpp/trove/tests/block.cu b/nemo-retriever-ocr/cpp/trove/tests/block.cu new file mode 100644 index 0000000000000000000000000000000000000000..66da2ce859a8cd259d7c40465faae0ba06dfacd8 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/tests/block.cu @@ -0,0 +1,70 @@ +#include +#include +#include +#include +#include +#include + +template +__global__ void test_block_write(T* r, int l) { + typedef trove::array s_ary; + int global_index = threadIdx.x + blockIdx.x * blockDim.x; + for(int index = global_index; index < l; index += gridDim.x * blockDim.x) { + //Generate some test data to write out + s_ary d = trove::counting_array::impl(s * index); + + //The high performance vector memory accesses only function correctly + //if the warp is converged. Here we check. + if (trove::warp_converged()) { + //Warp converged, indices are contiguous, so we call the + //fast store + trove::store_array_warp_contiguous(r, index, d); + } else { + //Warp is not converged, call the slow store. 
+ trove::store_array(r, index, d); + } + } +} + +template +__global__ void test_block_copy(const T* x, T* r, int l) { + typedef trove::array s_ary; + int global_index = threadIdx.x + blockIdx.x * blockDim.x; + + for(int index = global_index; index < l; index += gridDim.x * blockDim.x) { + + //The high performance vector memory accesses only function + //correctly if the warp is converged. Here we check. + if (trove::warp_converged()) { + //Warp converged, indices are contiguous, call the fast + //load and store + s_ary d = trove::load_array_warp_contiguous(x, index); + trove::store_array_warp_contiguous(r, index, d); + } else { + //Warp not converged, call the slow load and store + s_ary d = trove::load_array(x, index); + trove::store_array(r, index, d); + } + } +} + +int main() { + int l = 100000 * 256 + 17; + int int_length = l * 5; + thrust::device_vector d(int_length); + test_block_write<<<100*15, 256>>>( + thrust::raw_pointer_cast(d.data()), l); + thrust::device_vector g(int_length); + thrust::counting_iterator c(0); + thrust::copy(c, c + int_length, g.begin()); + std::cout << "test_block_write results pass: " << std::boolalpha << + thrust::equal(d.begin(), d.end(), g.begin()) << std::endl; + thrust::device_vector e(int_length); + test_block_copy<<<100*15, 256>>>( + thrust::raw_pointer_cast(g.data()), + thrust::raw_pointer_cast(e.data()), + l); + std::cout << "test_block_copy results pass: " << std::boolalpha << + thrust::equal(e.begin(), e.end(), g.begin()) << std::endl; + +} diff --git a/nemo-retriever-ocr/cpp/trove/tests/timer.h b/nemo-retriever-ocr/cpp/trove/tests/timer.h new file mode 100644 index 0000000000000000000000000000000000000000..62cd8165f23a9f83248d30d5c6fe8368e6779581 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/tests/timer.h @@ -0,0 +1,20 @@ +#pragma once + +struct cuda_timer { + cudaEvent_t start_event; + cudaEvent_t stop_event; + + void start() { + cudaEventCreate(&start_event); + cudaEventCreate(&stop_event); + cudaEventRecord(start_event, 0); + } + + float stop() { + cudaEventRecord(stop_event, 0); + cudaEventSynchronize(stop_event); + float time = 0; + cudaEventElapsedTime(&time, start_event, stop_event); + return time; + } +}; diff --git a/nemo-retriever-ocr/cpp/trove/trove/aos.h b/nemo-retriever-ocr/cpp/trove/trove/aos.h new file mode 100644 index 0000000000000000000000000000000000000000..73edd5107dca1b73df6f469fc913a0b95940947a --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/aos.h @@ -0,0 +1,279 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +#include +#include +#include +#include +#include +#include + +namespace trove { + +namespace detail { + +template +struct size_in_range { + typedef typename dismember_type::type U; + static const int size = aliased_size::value; + static const bool value = (size > 1) && (size < 64); +}; + +template::value, bool r=size_in_range::value> +struct use_shfl { + static const bool value = false; +}; + +template +struct use_shfl { + static const bool value = true; +}; + +template +struct use_direct { + static const bool value = !(use_shfl::value); +}; + +} + + +template +__device__ typename enable_if::value, T>::type +load_warp_contiguous(const T* src) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + const T* warp_begin_src = src - warp_id; + typedef typename detail::dismember_type::type U; + const U* as_int_src = (const U*)warp_begin_src; + typedef array::value> int_store; + int_store loaded = warp_load(as_int_src, warp_id); + r2c_warp_transpose(loaded); + return detail::fuse(loaded); +} + +template +__device__ typename enable_if::value, T>::type +load_warp_contiguous(const T* src) { + return detail::divergent_load(src); +} + + +template +__device__ typename enable_if::value>::type +store_warp_contiguous(const T& data, T* dest) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + T* warp_begin_dest = dest - warp_id; + typedef typename detail::dismember_type::type U; + U* as_int_dest = (U*)warp_begin_dest; + typedef array::value> int_store; + int_store lysed = detail::lyse(data); + c2r_warp_transpose(lysed); + warp_store(lysed, as_int_dest, warp_id); +} + +template +__device__ typename enable_if::value>::type +store_warp_contiguous(const T& data, T* dest) { + detail::divergent_store(data, dest); +} + + +namespace detail { + +template +__device__ typename detail::dismember_type::type* +compute_address(T* src, int div, int mod) { + typedef typename detail::dismember_type::type U; +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 + T* base_ptr = __shfl_sync(WARP_CONVERGED, src, div); +#else + T* base_ptr = __shfl(src, div); +#endif + U* result = ((U*)(base_ptr) + mod); + return result; +} + +template +struct address_constants { + typedef typename detail::dismember_type::type U; + static const int m = aliased_size::value; + static const int mod_offset = WARP_SIZE % m; + static const int div_offset = WARP_SIZE / m; +}; + +template +__device__ void update_indices(int& div, int& mod) { + mod += address_constants::mod_offset; + if (mod >= address_constants::m) { + mod -= address_constants::m; + div += 1; + } + div += address_constants::div_offset; +} + + +template +struct indexed_load { + typedef typename detail::dismember_type::type U; + __device__ + static array impl(const T* src, int div, int mod) { + U result; + U* address = compute_address(src, div, mod); + result = *address; + update_indices(div, mod); + + + return array( + result, + indexed_load::impl(src, div, mod)); + } +}; + +template +struct 
indexed_load<1, T> { + typedef typename detail::dismember_type::type U; + __device__ + static array impl(const T* src, int div, int mod) { + U result; + U* address = compute_address(src, div, mod); + result = *address; + return array(result); + } +}; + +template +struct indexed_store { + typedef typename detail::dismember_type::type U; + __device__ + static void impl(const array& src, + T* dest, int div, int mod) { + U* address = compute_address(dest, div, mod); + *address = src.head; + update_indices(div, mod); + indexed_store::impl(src.tail, dest, div, mod); + } +}; + +template +struct indexed_store<1, T> { + typedef typename detail::dismember_type::type U; + __device__ + static void impl(const array& src, + T* dest, int div, int mod) { + U* address = compute_address(dest, div, mod); + *address = src.head; + } +}; + +template +__device__ +bool is_contiguous(int warp_id, const T* ptr) { + int neighbor_idx = (warp_id == 0) ? 0 : warp_id-1; +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 + const T* neighbor_ptr = __shfl_sync(WARP_CONVERGED, ptr, neighbor_idx); +#else + const T* neighbor_ptr = __shfl(ptr, neighbor_idx); +#endif + bool neighbor_contiguous = (warp_id == 0) ? true : (ptr - neighbor_ptr == sizeof(T)); + bool result = __all(neighbor_contiguous); + return result; +} + +template +__device__ typename enable_if::value, T>::type +load_dispatch(const T* src) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + // if (detail::is_contiguous(warp_id, src)) { + // return detail::load_warp_contiguous(src); + // } else { + typedef typename detail::dismember_type::type U; + typedef array::value> u_store; + u_store loaded = + detail::indexed_load::value, T>::impl( + src, + warp_id / address_constants::m, + warp_id % address_constants::m); + r2c_warp_transpose(loaded); + return detail::fuse(loaded); + // } +} + + + +template +__device__ typename enable_if::value, T>::type +load_dispatch(const T* src) { + return detail::divergent_load(src); +} + + +template +__device__ typename enable_if::value>::type +store_dispatch(const T& data, T* dest) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + // if (detail::is_contiguous(warp_id, dest)) { + // detail::store_warp_contiguous(data, dest); + // } else { + typedef typename detail::dismember_type::type U; + typedef array::value> u_store; + u_store lysed = detail::lyse(data); + c2r_warp_transpose(lysed); + detail::indexed_store::value, T>::impl( + lysed, dest, + warp_id / address_constants::m, + warp_id % address_constants::m); + // } +} + +template +__device__ typename enable_if::value>::type +store_dispatch(const T& data, T* dest) { + detail::divergent_store(data, dest); +} + + +} + +template +__device__ T load(const T* src) { + if (warp_converged()) { + return detail::load_dispatch(src); + } else { + return detail::divergent_load(src); + } +} + +template +__device__ void store(const T& data, T* dest) { + if (warp_converged()) { + detail::store_dispatch(data, dest); + } else { + detail::divergent_store(data, dest); + } +} + +} diff --git a/nemo-retriever-ocr/cpp/trove/trove/array.h b/nemo-retriever-ocr/cpp/trove/trove/array.h new file mode 100644 index 0000000000000000000000000000000000000000..d42073c41c03a800509596943e81ed3ddf3c1eb5 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/array.h @@ -0,0 +1,269 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once + +namespace trove { + +template +struct array { + typedef T value_type; + typedef T head_type; + typedef array tail_type; + static const int size = m; + head_type head; + tail_type tail; + __host__ __device__ + array(head_type h, const tail_type& t) : head(h), tail(t) {} + __host__ __device__ + array() : head(), tail() {} + __host__ __device__ + array(const array& other) : head(other.head), tail(other.tail) {} + __host__ __device__ + array& operator=(const array& other) { + head = other.head; + tail = other.tail; + return *this; + } + __host__ __device__ + bool operator==(const array& other) const { + return (head == other.head) && (tail == other.tail); + } + __host__ __device__ + bool operator!=(const array& other) const { + return !operator==(other); + } +}; + +template +struct array { + typedef T value_type; + typedef T head_type; + static const int size = 1; + head_type head; + __host__ __device__ + array(head_type h) : head(h){} + __host__ __device__ + array() : head() {} + __host__ __device__ + array(const array& other) : head(other.head) {} + __host__ __device__ + array& operator=(const array& other) { + head = other.head; + return *this; + } + __host__ __device__ + bool operator==(const array& other) const { + return (head == other.head); + } + __host__ __device__ + bool operator!=(const array& other) const { + return !operator==(other); + } +}; + +template +struct array{}; + +namespace detail { + +template +struct get_impl { + __host__ __device__ static T& impl(array& src) { + return get_impl::impl(src.tail); + } + __host__ __device__ static T impl(const array& src) { + return get_impl::impl(src.tail); + } +}; + +template +struct get_impl { + __host__ __device__ static T& impl(array& src) { + return src.head; + } + __host__ __device__ static T impl(const array& src) { + return src.head; + } +}; + +} + +template +__host__ __device__ +T& get(array& src) { + return detail::get_impl::impl(src); +} + +template +__host__ __device__ +T get(const array& src) { + return detail::get_impl::impl(src); +} + +template +__host__ __device__ +array make_array() { + return array(); +} + +template +__host__ 
__device__ +array make_array(T a0) { + return array(a0); +} + +template +__host__ __device__ +array make_array(T a0, T a1) { + return array(a0, + make_array(a1)); +} + +template +__host__ __device__ +array make_array(T a0, T a1, T a2) { + return array(a0, + make_array(a1, a2)); +} + +template +__host__ __device__ +array make_array(T a0, T a1, T a2, T a3) { + return array(a0, + make_array(a1, a2, a3)); +} + +template +__host__ __device__ +array make_array(T a0, T a1, T a2, T a3, T a4) { + return array(a0, + make_array(a1, a2, a3, a4)); +} + +template +__host__ __device__ +array make_array(T a0, T a1, T a2, T a3, T a4, + T a5) { + return array(a0, + make_array(a1, a2, a3, a4, a5)); +} + +template +__host__ __device__ +array make_array(T a0, T a1, T a2, T a3, T a4, + T a5, T a6) { + return array(a0, + make_array(a1, a2, a3, a4, a5, + a6)); +} + +template +__host__ __device__ +array make_array(T a0, T a1, T a2, T a3, T a4, + T a5, T a6, T a7) { + return array(a0, + make_array(a1, a2, a3, a4, a5, + a6, a7)); +} + +template +__host__ __device__ +array make_array(T a0, T a1, T a2, T a3, T a4, + T a5, T a6, T a7, T a8) { + return array(a0, + make_array(a1, a2, a3, a4, a5, + a6, a7, a8)); +} + +template +__host__ __device__ +array make_array(T a0, T a1, T a2, T a3, T a4, + T a5, T a6, T a7, T a8, T a9) { + return array(a0, + make_array(a1, a2, a3, a4, a5, + a6, a7, a8, a9)); +} + + +namespace detail { + +template +struct make_array_impl { + typedef array result_type; + __host__ __device__ + static result_type impl(T ary[s]) { + return result_type(ary[0], + make_array_impl::impl(ary+1)); + } +}; + +template +struct make_array_impl { + typedef array result_type; + __host__ __device__ + static result_type impl(T ary[1]) { + return result_type(ary[0]); + } +}; + + +template +struct make_carray_impl { + typedef array array_type; + __host__ __device__ + static void impl(const array_type& ary, T result[s]) { + result[0] = ary.head; + make_carray_impl::impl(ary.tail, result+1); + } +}; + +template +struct make_carray_impl { + typedef array array_type; + __host__ __device__ + static void impl(const array_type& ary, T result[1]) { + result[0] = ary.head; + } +}; + +} //end namespace detail + +template +__host__ __device__ +array make_array(T cary[s]) { + return detail::make_array_impl::impl(cary); +} + +template +__host__ __device__ +void make_carray(const array& ary, + T result[s]) { + detail::make_carray_impl::impl(ary, result); +} + +} //end namespace trove diff --git a/nemo-retriever-ocr/cpp/trove/trove/block.h b/nemo-retriever-ocr/cpp/trove/trove/block.h new file mode 100644 index 0000000000000000000000000000000000000000..78a432c7a1cc8c04895ddfe693a096f512af8369 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/block.h @@ -0,0 +1,67 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +#include +#include + +namespace trove { + + +template +__device__ +trove::array load_array_warp_contiguous(const T* src, const I& idx) { + typedef trove::array array_type; + const array_type* src_ptr = (const array_type*)(src) + idx; + return load_warp_contiguous(src_ptr); +} + +template +__device__ +trove::array load_array(const T* src, const I& idx) { + typedef trove::array array_type; + const array_type* src_ptr = (const array_type*)(src) + idx; + return *src_ptr; +} + +template +__device__ +void store_array_warp_contiguous(T* dest, const I& idx, const trove::array& src) { + typedef trove::array array_type; + array_type* dest_ptr = (array_type*)(dest) + idx; + store_warp_contiguous(src, dest_ptr); +} + +template +__device__ +void store_array(T* dest, const I& idx, const trove::array& src) { + typedef trove::array array_type; + array_type* dest_ptr = (array_type*)(dest) + idx; + *dest_ptr = src; +} + +} diff --git a/nemo-retriever-ocr/cpp/trove/trove/detail/dismember.h b/nemo-retriever-ocr/cpp/trove/trove/detail/dismember.h new file mode 100644 index 0000000000000000000000000000000000000000..3425f6496fcdc59da6947c66368633db5e1d96ad --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/detail/dismember.h @@ -0,0 +1,132 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ + +#pragma once +#include +#include +#include + +namespace trove { +namespace detail { + + +template::value, + bool use_int2=size_multiple_power_of_two::value, + bool use_int4=size_multiple_power_of_two::value > +struct dismember_type { + typedef char type; +}; + +template +struct dismember_type { + typedef int type; +}; + +template +struct dismember_type { + typedef int2 type; +}; + +template +struct dismember_type { + typedef int4 type; +}; + + +template +struct aliased_size { + static const int value = sizeof(T) / sizeof(U); + //Assert sizeof(T) % sizeof(U) == 0 + THRUST_STATIC_ASSERT(sizeof(T) % sizeof(U) == 0); +}; + +template::type, + int r=aliased_size::value> +struct dismember { + typedef array result_type; + static const int idx = aliased_size::value - r; + __host__ __device__ + static result_type impl(const T& t) { + U tmp; + memcpy(&tmp, reinterpret_cast(&t) + idx * sizeof(U), sizeof(U)); + return result_type(tmp, dismember::impl(t)); + } +}; + +template +struct dismember { + typedef array result_type; + static const int idx = aliased_size::value - 1; + __host__ __device__ + static result_type impl(const T& t) { + U tmp; + memcpy(&tmp, reinterpret_cast(&t) + idx * sizeof(U), sizeof(U)); + return result_type(tmp); + } +}; + + +template::type, + int r=aliased_size::value> +struct remember { + static const int idx = aliased_size::value - r; + __host__ __device__ + static void impl(const array& d, T& t) { + memcpy(reinterpret_cast(&t) + idx * sizeof(U), &d.head, sizeof(d.head)); + remember::impl(d.tail, t); + } +}; + +template +struct remember { + static const int idx = aliased_size::value - 1; + __host__ __device__ + static void impl(const array& d, T& t) { + memcpy(reinterpret_cast(&t) + idx * sizeof(U), &d.head, sizeof(d.head)); + } +}; + + +template +__host__ __device__ +array::value> lyse(const T& in) { + return detail::dismember::impl(in); +} + +template +__host__ __device__ +T fuse(const array::value>& in) { + T result; + detail::remember::impl(in, result); + return result; +} + +} +} diff --git a/nemo-retriever-ocr/cpp/trove/trove/detail/fallback.h b/nemo-retriever-ocr/cpp/trove/trove/detail/fallback.h new file mode 100644 index 0000000000000000000000000000000000000000..aec551ffc34a670bf1217e0cae8a770efae3288c --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/detail/fallback.h @@ -0,0 +1,133 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +#include +#include + +namespace trove { +namespace detail { + +template +struct divergent_loader { + typedef typename detail::dismember_type::type U; + __device__ + static array impl(const U* src) { + return array(*src, + divergent_loader::impl(src+1)); + } + __device__ + static array impl(const T* src) { + return impl((U*)src); + } +}; + +template +struct divergent_loader<1, T> { + typedef typename detail::dismember_type::type U; + __device__ + static array impl(const U* src) { + return array(*src); + } + __device__ + static array impl(const T* src) { + return impl((U*)src); + } +}; + +template +struct use_divergent { + static const bool value = (sizeof(T) % 4) == 0; +}; + +template +__device__ +typename enable_if::value, T>::type +divergent_load(const T* src) { + typedef typename detail::dismember_type::type U; + typedef array::value> u_store; + u_store loaded = + detail::divergent_loader::value, T>::impl( + src); + return detail::fuse(loaded); +} + +template +__device__ +typename enable_if::value, T>::type +divergent_load(const T* src) { + return *src; +} + +template +struct divergent_storer { + typedef typename detail::dismember_type::type U; + __device__ + static void impl(const array& data, U* dest) { + *dest = data.head; + divergent_storer::impl(data.tail, dest+1); + } + __device__ + static void impl(const array& data, const T* dest) { + return impl(data, (U*)dest); + } +}; + +template +struct divergent_storer<1, T> { + typedef typename detail::dismember_type::type U; + __device__ + static void impl(const array& data, U* dest) { + *dest = data.head; + } + __device__ + static void impl(const array& data, const T* dest) { + return impl(data, (U*)dest); + } +}; + +template +__device__ +typename enable_if::value>::type +divergent_store(const T& data, T* dest) { + typedef typename detail::dismember_type::type U; + typedef array::value> u_store; + u_store lysed = detail::lyse(data); + detail::divergent_storer::value, T>::impl( + lysed, dest); +} + +template +__device__ +typename enable_if::value>::type +divergent_store(const T& data, T* dest) { + *dest = data; +} + + +} +} diff --git a/nemo-retriever-ocr/cpp/trove/trove/memory.h b/nemo-retriever-ocr/cpp/trove/trove/memory.h new file mode 100644 index 0000000000000000000000000000000000000000..b1cc19f70230e0821dcf78efddea1f941a6f7c30 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/memory.h @@ -0,0 +1,170 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +#include +#include + +namespace trove { +namespace detail { + +template +struct warp_store_array {}; + +template +struct warp_store_array > { + __host__ __device__ static void impl( + const array& d, + T* ptr, int offset, int stride) { + ptr[offset] = d.head; + warp_store_array >::impl( + d.tail, ptr, offset + stride, stride); + } +}; + +template +struct warp_store_array > { + __host__ __device__ static void impl( + const array& d, + T* ptr, int offset, int stride) { + ptr[offset] = d.head; + } +}; + +template +struct uncoalesced_store_array{}; + +template +struct uncoalesced_store_array > { + __host__ __device__ static void impl( + const array& d, + T* ptr, + int offset=0, + int stride=1) { + ptr[offset] = d.head; + uncoalesced_store_array >::impl(d.tail, ptr, offset+1, + stride); + } + __host__ __device__ static void impl( + const array& d, + volatile T* ptr, + int offset=0, + int stride=1) { + ptr[offset] = d.head; + uncoalesced_store_array >::impl(d.tail, ptr, offset+1, + stride); + } +}; + +template +struct uncoalesced_store_array > { + __host__ __device__ static void impl( + const array& d, + T* ptr, + int offset=0, + int stride=1) { + ptr[offset] = d.head; + } + __host__ __device__ static void impl( + const array& d, + volatile T* ptr, + int offset=0, + int stride=1) { + ptr[offset] = d.head; + } +}; + +template +struct warp_load_array{}; + +template +struct warp_load_array > { + __host__ __device__ static array impl(const T* ptr, + int offset, + int stride=32) { + return array(ptr[offset], + warp_load_array >::impl(ptr, offset+stride, stride)); + } + __host__ __device__ static array impl(const volatile T* ptr, + int offset, + int stride=32) { + return array(ptr[offset], + warp_load_array >::impl(ptr, offset+stride, stride)); + } +}; + +template +struct warp_load_array > { + __host__ __device__ static array impl(const T* ptr, + int offset, + int stride=32) { + return array(ptr[offset]); + } + __host__ __device__ static array impl(const volatile T* ptr, + int offset, + int stride=32) { + return array(ptr[offset]); + } +}; + +} //end namespace detail + +template +__host__ __device__ void warp_store(const Array& t, + typename Array::head_type* ptr, + int offset, int stride=32) { + detail::warp_store_array::impl(t, ptr, offset, stride); +} + +template +__host__ __device__ Array warp_load(const typename Array::head_type* ptr, + int offset, int stride=32) { + return detail::warp_load_array::impl(ptr, offset, stride); +} + +template +__host__ __device__ Array warp_load( + const volatile typename Array::head_type* ptr, + int offset, int stride=32) { + return 
detail::warp_load_array::impl(ptr, offset, stride); +} + +template +__host__ __device__ void uncoalesced_store(const Array& t, + typename Array::head_type* ptr, + int stride=1) { + detail::uncoalesced_store_array::impl(t, ptr, 0, stride); +} + +template +__host__ __device__ void uncoalesced_store(const Array& t, + volatile typename Array::head_type* ptr, + int stride=1) { + detail::uncoalesced_store_array::impl(t, ptr, 0, stride); +} + +} //end namespace trove diff --git a/nemo-retriever-ocr/cpp/trove/trove/print_array.h b/nemo-retriever-ocr/cpp/trove/trove/print_array.h new file mode 100644 index 0000000000000000000000000000000000000000..a5f11a9a88e0369d4111e43fe41faa2f5b11fb51 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/print_array.h @@ -0,0 +1,48 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +#include +#include + +namespace trove { + +template +std::ostream& operator<<(std::ostream& strm, const array& ary) { + strm << ary.head; + return strm; +} + +template +std::ostream& operator<<(std::ostream& strm, const array& ary) { + strm << ary.head << " "; + strm << ary.tail; + return strm; +} + + +} //ends namespace trove diff --git a/nemo-retriever-ocr/cpp/trove/trove/ptr.h b/nemo-retriever-ocr/cpp/trove/trove/ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..50f0fd02c44f8cb01832ebc713d1e458c4bdd528 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/ptr.h @@ -0,0 +1,80 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +#include + +namespace trove { +namespace detail { + +template +struct coalesced_ref { + T* m_ptr; + __device__ explicit coalesced_ref(T* ptr) : m_ptr(ptr) {} + + __device__ operator T() { + return trove::load(m_ptr); + } + __device__ coalesced_ref& operator=(const T& data) { + trove::store(data, m_ptr); + return *this; + } + + __device__ coalesced_ref& operator=(const coalesced_ref& other) { + if (warp_converged()) { + T data = detail::load_dispatch(other.m_ptr); + detail::store_dispatch(data, m_ptr); + } else { + T data = detail::divergent_load(other.m_ptr); + detail::divergent_store(data, m_ptr); + } + return *this; + } +}; +} + +template +struct coalesced_ptr { + T* m_ptr; + __device__ coalesced_ptr(T* ptr) : m_ptr(ptr) {} + __device__ trove::detail::coalesced_ref operator*() { + return trove::detail::coalesced_ref(m_ptr); + } + template + __device__ trove::detail::coalesced_ref operator[](const I& idx) { + return trove::detail::coalesced_ref(m_ptr + idx); + } + __device__ operator T*() { + return m_ptr; + } +}; + + + + + +} diff --git a/nemo-retriever-ocr/cpp/trove/trove/rotate.h b/nemo-retriever-ocr/cpp/trove/trove/rotate.h new file mode 100644 index 0000000000000000000000000000000000000000..12284e08fa1b17717e847e3ef84576729e1ea7b0 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/rotate.h @@ -0,0 +1,117 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +#include + +namespace trove { +namespace detail { + +template +struct rotate_elements; + +template +struct rotate_elements_helper { + static const int size = Array::size; + static const int other = (i + j) % size; + static const bool new_non_terminal = j < size-2; + __host__ __device__ + static void impl(const Array& t, int a, Array& r) { + if (a & i) + trove::get(r) = trove::get(t); + rotate_elements_helper::impl(t, a, r); + } +}; + +template +struct rotate_elements_helper { + static const int size = Array::size; + static const int other = (i + j) % size; + __host__ __device__ + static void impl(const Array& t, int a, Array& r) { + if (a & i) + trove::get(r) = trove::get(t); + } +}; + + +template +struct rotate_elements{ + static const int size = Array::size; + static const bool non_terminal = j < size-1; + __host__ __device__ + static void impl(const Array& t, int a, Array& r) { + rotate_elements_helper::impl(t, a, r); + } +}; + +template +struct rotate_impl; + +template +struct rotate_impl_helper { + static const int size = Array::size; + static const int next_i = i * 2; + __host__ __device__ + static Array impl(const Array& t, int a) { + Array rotated = t; + rotate_elements::impl(t, a, rotated); + return rotate_impl::impl(rotated, a); + } +}; + +template +struct rotate_impl_helper { + static const int size = Array::size; + __host__ __device__ + static Array impl(const Array& t, int a) { + Array rotated = t; + rotate_elements::impl(t, a, rotated); + return rotated; + } +}; + +template +struct rotate_impl { + static const int size = Array::size; + static const int next_i = i * 2; + static const bool non_terminal = next_i < size; + __host__ __device__ + static Array impl(const Array& t, int a) { + return rotate_impl_helper::impl(t, a); + } +}; + +} //ends namespace detail + +template +__host__ __device__ +array rotate(const array& t, int a) { + return detail::rotate_impl, 1>::impl(t, a); +} + +} //ends namespace trove diff --git a/nemo-retriever-ocr/cpp/trove/trove/shfl.h b/nemo-retriever-ocr/cpp/trove/trove/shfl.h new file mode 100644 index 0000000000000000000000000000000000000000..62d6dcb86ef1f62f7b2439d5681f6dbc4351efa6 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/shfl.h @@ -0,0 +1,106 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +#include +#include + +namespace trove { +namespace detail { + +template +struct shuffle { + __device__ + static void impl(array& d, const int& i) { +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 + d.head = __shfl_sync(WARP_CONVERGED, d.head, i); +#else + d.head = __shfl(d.head, i); +#endif + shuffle::impl(d.tail, i); + } +}; + +template<> +struct shuffle<1> { + __device__ + static void impl(array& d, const int& i) { +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 + d.head = __shfl_sync(WARP_CONVERGED, d.head, i); +#else + d.head = __shfl(d.head, i); +#endif + } +}; + +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 +template +struct shuffle_sync { + __device__ + static void impl(unsigned mask, array& d, const int& i) { + d.head = __shfl_sync(mask, d.head, i); + shuffle_sync::impl(mask, d.tail, i); + } +}; + +template<> +struct shuffle_sync<1> { + __device__ + static void impl(unsigned mask, array& d, const int& i) { + d.head = __shfl_sync(mask, d.head, i); + } +}; +#endif + +} +} + +template +__device__ +T __shfl(const T& t, const int& i) { + typedef trove::array::value> + lysed_array; + lysed_array lysed = trove::detail::lyse(t); + trove::detail::shuffle + ::impl(lysed, i); + return trove::detail::fuse(lysed); +} + +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 +template +__device__ +T __shfl_sync(unsigned mask, const T& t, const int& i) { + typedef trove::array::value> + lysed_array; + lysed_array lysed = trove::detail::lyse(t); + trove::detail::shuffle_sync + ::impl(mask, lysed, i); + return trove::detail::fuse(lysed); +} +#endif diff --git a/nemo-retriever-ocr/cpp/trove/trove/static_gcd.h b/nemo-retriever-ocr/cpp/trove/trove/static_gcd.h new file mode 100644 index 0000000000000000000000000000000000000000..93bb9ea527ac409d84e37e76545ea44239d123a6 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/static_gcd.h @@ -0,0 +1,77 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once +namespace trove { + +template +struct static_gcd; + +namespace detail { + +template +struct static_gcd_helper { + static const int value = static_gcd<(u>>1), (v>>1)>::value << 1; +}; + +template +struct static_gcd_helper { + static const int value = static_gcd<(u>>1), v>::value; +}; + +template +struct static_gcd_helper { + static const int value = static_gcd>1)>::value; +}; + +template +struct static_gcd_helper { + static const int reduced_u = (u > v) ? ((u - v) >> 1) : ((v - u) >> 1); + static const int reduced_v = (u > v) ? v : u; + static const int value = static_gcd::value; +}; +} + +template +struct static_gcd { + static const bool u_odd = (u & 0x1) == 1; + static const bool v_odd = (v & 0x1) == 1; + static const bool equal = u == v; + static const int value = equal ? u : detail::static_gcd_helper::value; +}; + +template +struct static_gcd<0, v> { + static const bool value = v; +}; + +template +struct static_gcd { + static const bool value = u; +}; + +} diff --git a/nemo-retriever-ocr/cpp/trove/trove/static_mod_inverse.h b/nemo-retriever-ocr/cpp/trove/trove/static_mod_inverse.h new file mode 100644 index 0000000000000000000000000000000000000000..dbdc1a8a414a472dec03f7a0bd54ee7fb78eda16 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/static_mod_inverse.h @@ -0,0 +1,54 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ + +#pragma once + +namespace trove { + +template +struct static_mod_inverse; + +template +struct static_mod_inverse_helper { + //If you get this returned, it means the mod inverse doesn't exist. + static const int value = -1; +}; + +template +struct static_mod_inverse_helper { + static const int fx = (r * a) % m; + static const bool found = (fx == 1); + static const int value = found ? r : static_mod_inverse::value; +}; + +template +struct static_mod_inverse { + static const bool done = r == m; + static const int value = static_mod_inverse_helper::value; +}; + +} diff --git a/nemo-retriever-ocr/cpp/trove/trove/transpose.h b/nemo-retriever-ocr/cpp/trove/trove/transpose.h new file mode 100644 index 0000000000000000000000000000000000000000..1cdb7f84465ecb05ec59d2f13c51dcfe33f13c2f --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/transpose.h @@ -0,0 +1,686 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ + +#pragma once +#include +#include +#include +#include +#include +#include + +namespace trove { +namespace detail { + +struct odd{}; +struct power_of_two{}; +struct composite{}; + +template::value, bool isodd=is_odd::value> +struct tx_algorithm { + typedef composite type; +}; + +template +struct tx_algorithm { + typedef power_of_two type; +}; + +template +struct tx_algorithm { + typedef odd type; +}; + +template::type> +struct c2r_offset_constants{}; + +template +struct c2r_offset_constants { + static const int offset = WARP_SIZE - static_mod_inverse::value; + static const int rotate = static_mod_inverse::value; + static const int permute = static_mod_inverse::value; +}; + +template +struct c2r_offset_constants { + static const int offset = WARP_SIZE - WARP_SIZE/m; + static const int permute = m - 1; +}; + +template +struct c2r_offset_constants { + static const int c = static_gcd::value; + static const int k = static_mod_inverse::value; +}; + +template::type> +struct r2c_offset_constants{}; + +template +struct r2c_offset_constants { + static const int permute = static_mod_inverse::value; +}; + +template +struct r2c_offset_constants : + c2r_offset_constants { +}; + +template class Permute, int position=0> +struct tx_permute_impl{}; + +template class Permute, int position> +struct tx_permute_impl, Permute, position> { + typedef array Remaining; + static const int idx = Permute::value; + template + __host__ __device__ + static Remaining impl(const Source& src) { + return Remaining( + trove::get(src), + tx_permute_impl, Permute, position+1>::impl( + src)); + } +}; + +template class Permute, int position> +struct tx_permute_impl, Permute, position> { + typedef array Remaining; + static const int idx = Permute::value; + template + __host__ __device__ + static Remaining impl(const Source& src) { + return Remaining(trove::get(src)); + } +}; + + +template +struct affine_modular_fn { + template + struct eval { + static const int value = (a * x + b) % m; + }; +}; + +template +struct composite_c2r_permute_fn { + static const int o = WARP_SIZE % m; + static const int c = static_gcd::value; + static const int p = m / c; + template + struct eval { + static const int value = (x * o - (x / p)) % m; + }; +}; + +template +struct composite_r2c_permute_fn { + template + struct eval { + static const int value = + inverse::template eval, x>::value; + }; +}; + + +template +__host__ __device__ Array c2r_tx_permute(const Array& t) { + return tx_permute_impl< + Array, + affine_modular_fn::permute>::template eval>::impl(t); +} + + + +template +__host__ __device__ Array composite_c2r_tx_permute(const Array& t) { + return tx_permute_impl< + Array, + composite_c2r_permute_fn::template eval>::impl(t); +} + +template +__host__ __device__ Array composite_r2c_tx_permute(const Array& t) { + return tx_permute_impl< + Array, + composite_r2c_permute_fn::template eval>::impl(t); +} + + + +template +__host__ __device__ Array r2c_tx_permute(const Array& t) { + return tx_permute_impl< + Array, + affine_modular_fn::permute>::template eval>::impl(t); +} + + +template +struct c2r_compute_offsets_impl{}; + +template +struct c2r_compute_offsets_impl, b, o> { + typedef array Array; + __device__ + static Array impl(int offset) { + if (offset >= b) { + offset -= b; + } //Poor man's x % b. Requires that o < b. 
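+            // Added annotation: emit this lane's offset, then recurse with offset + o for the next element; the wrap-around check above keeps every offset in [0, b) because o < b.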
+ return Array(offset, + c2r_compute_offsets_impl, b, o>:: + impl(offset + o)); + } +}; + +template +struct c2r_compute_offsets_impl, b, o> { + typedef array Array; + __device__ + static Array impl(int offset) { + if (offset >= b) { + offset -= b; + } //Poor man's x % b. Requires that o < b. + return Array(offset); + } +}; + +template +struct c2r_compute_initial_offset {}; + +template +struct c2r_compute_initial_offset { + typedef c2r_offset_constants constants; + __device__ static int impl() { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int initial_offset = ((WARP_SIZE - warp_id) * constants::offset) + & WARP_MASK; + return initial_offset; + } +}; + +template +struct c2r_compute_initial_offset { + __device__ static int impl() { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int initial_offset = ((warp_id * (WARP_SIZE + 1)) >> + static_log::value) + & WARP_MASK; + return initial_offset; + } +}; + +template +struct r2c_compute_initial_offset {}; + +template +struct r2c_compute_initial_offset { + __device__ static int impl() { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int initial_offset = (warp_id * m) & WARP_MASK; + return initial_offset; + } +}; + + +template +__device__ +array c2r_compute_offsets() { + typedef c2r_offset_constants constants; + int initial_offset = c2r_compute_initial_offset::impl(); + return c2r_compute_offsets_impl, + WARP_SIZE, + constants::offset>::impl(initial_offset); +} + +template +struct c2r_compute_composite_offsets{}; + +template +struct c2r_compute_composite_offsets, m, p> { + static const int n = WARP_SIZE; + static const int mod_n = n - 1; + static const int c = static_gcd::value; + static const int k = static_mod_inverse::value; + static const int mod_c = c - 1; + static const int log_c = static_log::value; + static const int n_div_c = n / c; + static const int mod_n_div_c = n_div_c - 1; + static const int log_n_div_c = static_log::value; + typedef array result_type; + __host__ __device__ static result_type impl(int idx, int col) { + int offset = ((((idx >> log_c) * k) & mod_n_div_c) + + ((idx & mod_c) << log_n_div_c)) & mod_n; + int new_idx = idx + n - 1; + new_idx = (p == m - c + (col & mod_c)) ? 
new_idx + m : new_idx; + return + result_type(offset, + c2r_compute_composite_offsets, m, p+1> + ::impl(new_idx, col)); + + } +}; + +template +struct c2r_compute_composite_offsets, m, p> { + static const int n = WARP_SIZE; + static const int mod_n = n - 1; + static const int c = static_gcd::value; + static const int k = static_mod_inverse::value; + static const int mod_c = c - 1; + static const int log_c = static_log::value; + static const int n_div_c = n / c; + static const int mod_n_div_c = n_div_c - 1; + static const int log_n_div_c = static_log::value; + typedef array result_type; + __host__ __device__ static result_type impl(int idx, int col) { + int offset = ((((idx >> log_c) * k) & mod_n_div_c) + + ((idx & mod_c) << log_n_div_c)) & mod_n; + return result_type(offset); + + } +}; + + +template +struct r2c_offsets { + static const int value = (offset * index) % bound; +}; + +template +struct r2c_compute_offsets_impl{}; + +template +struct r2c_compute_offsets_impl, index, m, odd> { + typedef array Array; + static const int offset = (WARP_SIZE % m * index) % m; + __device__ + static Array impl(int initial_offset) { + int current_offset = (initial_offset + offset) & WARP_MASK; + return Array(current_offset, + r2c_compute_offsets_impl, + index + 1, m, odd>::impl(initial_offset)); + } +}; + +template +struct r2c_compute_offsets_impl, index, m, odd> { + typedef array Array; + static const int offset = (WARP_SIZE % m * index) % m; + __device__ + static Array impl(int initial_offset) { + int current_offset = (initial_offset + offset) & WARP_MASK; + return Array(current_offset); + } +}; + + +template +struct r2c_compute_offsets_impl, index, m, power_of_two> { + typedef array Array; + __device__ + static Array impl(int offset, int lb) { + int new_offset = (offset == lb) ? offset + m - 1 : offset - 1; + return Array(offset, + r2c_compute_offsets_impl, index + 1, m, power_of_two>::impl(new_offset, lb)); + } +}; + +template +struct r2c_compute_offsets_impl, index, m, power_of_two> { + typedef array Array; + __device__ + static Array impl(int offset, int lb) { + return Array(offset); + } +}; + + +template +struct r2c_compute_composite_offsets{}; + +template +struct r2c_compute_composite_offsets, m> { + static const int n = WARP_SIZE; + static const int mod_n = n - 1; + static const int c = static_gcd::value; + static const int n_div_c = n / c; + static const int log_n_div_c = static_log::value; + typedef array result_type; + __host__ __device__ static result_type impl(int col, int offset, int lb, int ub) { + int new_offset = offset + 1; + new_offset = (new_offset == ub) ? 
lb : new_offset; + return + result_type(offset & mod_n, + r2c_compute_composite_offsets, m> + ::impl(col, new_offset, lb, ub)); + + } +}; + +template +struct r2c_compute_composite_offsets, m> { + static const int n = WARP_SIZE; + static const int mod_n = n - 1; + static const int c = static_gcd::value; + static const int n_div_c = n / c; + static const int log_n_div_c = static_log::value; + typedef array result_type; + __host__ __device__ static result_type impl(int col, int offset, int lb, int ub) { + return result_type(offset & mod_n); + + } +}; + +template +__device__ +array r2c_compute_offsets() { + typedef r2c_offset_constants constants; + typedef array result_type; + int initial_offset = r2c_compute_initial_offset::impl(); + return r2c_compute_offsets_impl::impl(initial_offset); +} + + +template +struct warp_shuffle {}; + +template +struct warp_shuffle, array > { + __device__ static void impl(array& d, + const array& i) { +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 + d.head = __shfl_sync(WARP_CONVERGED, d.head, i.head); +#else + d.head = __shfl(d.head, i.head); +#endif + warp_shuffle, array >::impl(d.tail, + i.tail); + } +}; + +template +struct warp_shuffle, array > { + __device__ static void impl(array& d, + const array& i) { +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 + d.head = __shfl_sync(WARP_CONVERGED, d.head, i.head); +#else + d.head = __shfl(d.head, i.head); +#endif + } +}; + + +template +struct c2r_compute_indices_impl {}; + +template +struct c2r_compute_indices_impl { + __device__ static void impl(Array& indices, int& rotation) { + indices = detail::c2r_compute_offsets(); + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int size = Array::size; + int r = detail::c2r_offset_constants::rotate; + rotation = (warp_id * r) % size; + } +}; + +template +struct c2r_compute_indices_impl { + __device__ static void impl(Array& indices, int& rotation) { + indices = detail::c2r_compute_offsets(); + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int size = Array::size; + rotation = (size - warp_id) & (size - 1); + } +}; + +template +struct c2r_compute_indices_impl { + __device__ static void impl(Array& indices, int& rotation) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + + indices = detail::c2r_compute_composite_offsets:: + impl(warp_id, warp_id); + rotation = warp_id % Array::size; + } +}; + +template +struct c2r_warp_transpose_impl {}; + +template +struct c2r_warp_transpose_impl { + __device__ static void impl(Array& src, + const Indices& indices, + const int& rotation) { + detail::warp_shuffle::impl(src, indices); + src = rotate(detail::c2r_tx_permute(src), rotation); + } +}; + +template +struct c2r_warp_transpose_impl { + __device__ static void impl(Array& src, + const Indices& indices, + const int& rotation) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int pre_rotation = warp_id >> + (LOG_WARP_SIZE - + static_log::value); + src = rotate(src, pre_rotation); + c2r_warp_transpose_impl::impl + (src, indices, rotation); + } +}; + +template +struct c2r_warp_transpose_impl { + __device__ static void impl(Array& src, + const Indices& indices, + const int& rotation) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int pre_rotation = warp_id >> static_log::value>::value; + src = rotate(src, 
pre_rotation); + detail::warp_shuffle::impl(src, indices); + src = rotate(src, rotation); + src = composite_c2r_tx_permute(src); + } +}; + +template +struct r2c_compute_indices_impl {}; + +template +struct r2c_compute_indices_impl { + __device__ static void impl(Array& indices, int& rotation) { + indices = + detail::r2c_compute_offsets(); + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int size = Array::size; + int r = + size - detail::r2c_offset_constants::permute; + rotation = (warp_id * r) % size; + } +}; + +template +struct r2c_compute_indices_impl { + static const int m = Array::size; + static const int log_m = static_log::value; + static const int clear_m = ~(m-1); + static const int n = WARP_SIZE; + static const int log_n = static_log::value; + static const int mod_n = n-1; + static const int n_div_m = WARP_SIZE / m; + static const int log_n_div_m = static_log::value; + __device__ static void impl(Array& indices, int& rotation) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + int size = Array::size; + rotation = warp_id % size; + int initial_offset = ((warp_id << log_m) + (warp_id >> log_n_div_m)) & mod_n; + int lb = initial_offset & clear_m; + indices = r2c_compute_offsets_impl::impl(initial_offset, lb); + } +}; + +template +struct r2c_compute_indices_impl { + static const int size = Array::size; + static const int c = static_gcd::value; + __device__ static void impl(Array& indices, int& rotation) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + rotation = size - (warp_id % size); + int lb = (size * warp_id) & WARP_MASK; + int ub = lb + size; + int offset = lb + warp_id / (WARP_SIZE/c); + indices = detail::r2c_compute_composite_offsets:: + impl(warp_id, offset, lb, ub); + } +}; + +template +struct r2c_warp_transpose_impl {}; + +template +struct r2c_warp_transpose_impl { + __device__ static void impl(Array& src, + const Indices& indices, + const int& rotation) { + Array rotated = rotate(src, rotation); + detail::warp_shuffle::impl(rotated, indices); + src = detail::r2c_tx_permute(rotated); + } +}; + +template +struct r2c_warp_transpose_impl { + __device__ static void impl(Array& src, + const Indices& indices, + const int& rotation) { + Array rotated = rotate(src, rotation); + detail::warp_shuffle::impl(rotated, indices); + const int size = Array::size; + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + src = rotate(detail::r2c_tx_permute(rotated), + (size-warp_id/(WARP_SIZE/size))%size); + } +}; + +template +struct r2c_warp_transpose_impl { + static const int c = static_gcd::value; + static const int size = Array::size; + __device__ static void impl(Array& src, + const Indices& indices, + const int& rotation) { + int warp_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) & WARP_MASK; + src = composite_r2c_tx_permute(src); + src = rotate(src, rotation); + detail::warp_shuffle::impl(src, indices); + src = rotate(src, size - (warp_id/(WARP_SIZE/c))); + } +}; + +} //end namespace detail + +template +__device__ void c2r_compute_indices(array& indices, int& rotation) { + typedef array Array; + detail::c2r_compute_indices_impl< + Array, + typename detail::tx_algorithm::type> + ::impl(indices, rotation); + +} + +template +__device__ void c2r_warp_transpose(array& src, + const array& indices, + int rotation) { + typedef array Array; + 
detail::c2r_warp_transpose_impl< + Array, array, + typename detail::tx_algorithm::type>:: + impl(src, indices, rotation); +} + +template +__device__ void c2r_warp_transpose(array& src) { + typedef array Array; + typedef array indices_array; + indices_array indices; + int rotation; + c2r_compute_indices(indices, rotation); + + detail::c2r_warp_transpose_impl< + Array, array, + typename detail::tx_algorithm::type>:: + impl(src, indices, rotation); +} + +template +__device__ void r2c_compute_indices(array& indices, int& rotation) { + typedef array Array; + detail::r2c_compute_indices_impl< + Array, typename detail::tx_algorithm::type> + ::impl(indices, rotation); + +} + +template +__device__ void r2c_warp_transpose(array& src, + const array& indices, + int rotation) { + typedef array Array; + detail::r2c_warp_transpose_impl< + Array, array, + typename detail::tx_algorithm::type> + ::impl(src, indices, rotation); +} + +template +__device__ void r2c_warp_transpose(array& src) { + typedef array Array; + typedef array indices_array; + indices_array indices; + int rotation; + r2c_compute_indices(indices, rotation); + + detail::r2c_warp_transpose_impl< + Array, array, + typename detail::tx_algorithm::type> + ::impl(src, indices, rotation); +} + +} //end namespace trove diff --git a/nemo-retriever-ocr/cpp/trove/trove/utility.h b/nemo-retriever-ocr/cpp/trove/trove/utility.h new file mode 100644 index 0000000000000000000000000000000000000000..836a86f82fd22b47982db7eee6059015a2c12499 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/utility.h @@ -0,0 +1,162 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ + +#pragma once +#include + +namespace trove { + +template +struct counting_array{}; + +template +struct counting_array > { + typedef array Array; + __host__ __device__ + static Array impl(T v=0, T i=1) { + return Array(v, + counting_array >::impl(v + i, i)); + } +}; + +template +struct counting_array > { + __host__ __device__ + static array impl(T v, T i=1) { + return make_array(v); + } +}; + +template +struct sum_array {}; + +template +struct sum_array > { + typedef array Array; + __host__ __device__ + static T impl(const Array& a, const T& p) { + return sum_array::impl(a.tail, p + a.head); + } +}; + +template +struct sum_array > { + typedef array Array; + __host__ __device__ + static T impl(const Array& a, const T& p) { + return p + a.head; + } +}; + +template +__host__ __device__ T sum(const array& a) { + return sum_array >::impl(a, 0); +} + +template +struct static_log { + static const int value = 1 + static_log< (m >> 1) >::value; +}; + +template<> +struct static_log<1> { + static const int value = 0; +}; + +template<> +struct static_log<0> { + //This functions as a static assertion + //Don't take the log of 0!! +}; + +template +struct is_power_of_two { + static const bool value = (m & (m-1)) == 0; +}; + +template +struct is_odd { + static const bool value = (m & 1) == 1; +}; + +template +struct value_if { + static const T value = Then::value; +}; + +template +struct value_if { + static const T value = Else::value; +}; + +template +struct value_identity { + static const T value = x; +}; + +template class Fn, T x, T p=0> +struct inverse { + static const T value = + value_if::value == x, T, + value_identity, inverse >::value; +}; + +struct null_type{}; + +template +struct cons_c { + static const T head = i; + typedef Tail tail; +}; + +template +struct static_range { + static const int head = k; + typedef static_range tail; +}; + +template +struct static_range { + static const int head = f; + typedef null_type tail; +}; + +template +struct enable_if { + typedef T type; +}; + +template +struct enable_if {}; + +template +struct size_multiple_power_of_two { + static const bool value = (sizeof(T) & ((1 << p) - 1)) == 0; +}; + + +} diff --git a/nemo-retriever-ocr/cpp/trove/trove/warp.h b/nemo-retriever-ocr/cpp/trove/trove/warp.h new file mode 100644 index 0000000000000000000000000000000000000000..1f31d576c45b6ee8485b05e9e2310c6e740f81a0 --- /dev/null +++ b/nemo-retriever-ocr/cpp/trove/trove/warp.h @@ -0,0 +1,48 @@ +/* +Copyright (c) 2013, NVIDIA Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#pragma once + +namespace trove { + +enum { + WARP_SIZE = 32, + WARP_MASK = 0x1f, + WARP_CONVERGED = 0xFFFFFFFF, + LOG_WARP_SIZE = 5 +}; + +__device__ +inline bool warp_converged() { +#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000 + return (__activemask() == WARP_CONVERGED); +#else + return (__ballot(true) == WARP_CONVERGED); +#endif +} + +} diff --git a/nemo-retriever-ocr/hatch_build.py b/nemo-retriever-ocr/hatch_build.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa9d85a649e23e410eb8994c2e3b38a899b2eae --- /dev/null +++ b/nemo-retriever-ocr/hatch_build.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import sys +import subprocess +from pathlib import Path + +from hatchling.builders.hooks.plugin.interface import BuildHookInterface + + +def _extension_up_to_date(project_root: Path) -> bool: + """Return True if a built .so exists and is newer than all sources. + + Respects the following directories: + - src/nemo_retriever_ocr_cpp (Python shim and built .so location) + - cpp/ (C++/CUDA sources) + - scripts/ (build script) + """ + extension_dir = project_root / "src" / "nemo_retriever_ocr_cpp" + candidates = list(extension_dir.glob("_nemo_retriever_ocr_cpp*.so")) + if not candidates: + return False + + newest_so_mtime = max(p.stat().st_mtime for p in candidates) + + newest_src_mtime = 0.0 + for directory in (project_root / "cpp", project_root / "scripts", extension_dir): + if not directory.exists(): + continue + for path in directory.rglob("*"): + if not path.is_file(): + continue + if path.suffix in {".cu", ".cpp", ".cuh", ".h", ".py"}: + mtime = path.stat().st_mtime + if mtime > newest_src_mtime: + newest_src_mtime = mtime + + return newest_so_mtime >= newest_src_mtime + + +class CustomBuildHook(BuildHookInterface): + def initialize(self, version: str, build_data: dict) -> None: + project_root = Path(__file__).parent + script_path = project_root / "scripts" / "build-extension.py" + + env = os.environ.copy() + # Ensure the extension actually builds during package build + env.setdefault("BUILD_CPP_EXTENSION", "1") + + # Allow users to force rebuild or skip if up-to-date + force_rebuild = env.get("BUILD_CPP_FORCE", "0") == "1" + build_enabled = env.get("BUILD_CPP_EXTENSION", "1") == "1" + + if build_enabled and not force_rebuild and _extension_up_to_date(project_root): + # Cached build found and sources unchanged; skip rebuild + return + + subprocess.run( + [ + os.fspath(sys.executable), + os.fspath(script_path), + ], + cwd=os.fspath(project_root), + env=env, + check=True, + ) + + diff --git a/nemo-retriever-ocr/pyproject.toml b/nemo-retriever-ocr/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..49e733131701324b30fc4651a48672fb17189c03 --- /dev/null +++ b/nemo-retriever-ocr/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "nemo-retriever-ocr" +version = "1.0.0" +description = 
"NeMo Retriever OCR" +authors = [{ name = "NVIDIA NeMo Retriever" }] +requires-python = ">=3.12,<3.13" +dependencies = [ + "pandas>=2.3.3", + "pillow>=12.0.0", + "scikit-learn>=1.7.2", + "shapely>=2.1.2,<3", + "torch>=2.8.0", + "torchvision>=0.23.0", +] + +[project.urls] +Homepage = "https://nvidia.com/" + +[dependency-groups] +dev = [ + "ipython>=9.6.0", + "pre-commit>=3.8.0,<4", + "pytest>=8.3.2,<9", + "pytest-cov>=5.0.0,<6", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = [ + "src/nemo_retriever_ocr", + "src/nemo_retriever_ocr_cpp", +] + +[tool.hatch.build.targets.wheel.hooks.custom] +path = "hatch_build.py" +dependencies = ["setuptools>=68", "torch>=2.0"] + +[tool.hatch.build.targets.sdist] +include = [ + "src/**", + "cpp/**", + "scripts/**", + "hatch_build.py", + "pyproject.toml", + "README.md", + "LICENSE*", +] diff --git a/nemo-retriever-ocr/scripts/build-extension.py b/nemo-retriever-ocr/scripts/build-extension.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d1af2a9e64a8b0275ccf00ca85b2eb934a17e7 --- /dev/null +++ b/nemo-retriever-ocr/scripts/build-extension.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import os +import sys +import shutil +from glob import glob + +from pathlib import Path + +def _parse_arch_list(arch_list: str) -> list[str]: + """Parse TORCH_CUDA_ARCH_LIST-formatted string into nvcc -gencode flags. + + Accepts tokens like "7.5", "8.6", "8.6+PTX", "86", "90+PTX" separated by + spaces, semicolons, or commas. + """ + tokens = arch_list.replace(";", " ").replace(",", " ").split() + flags: list[str] = [] + for token in tokens: + with_ptx = token.endswith("+PTX") + clean = token[:-4] if with_ptx else token + clean = clean.replace(".", "") + if not clean.isdigit(): + continue + mm = clean + flags.append(f"-gencode=arch=compute_{mm},code=sm_{mm}") + if with_ptx: + flags.append(f"-gencode=arch=compute_{mm},code=compute_{mm}") + return flags + + +def _detect_arch_list() -> list[str]: + """Best-effort detection of a single architecture if a GPU is visible. + + Falls back to a conservative list (sm_75 and sm_86) when detection isn't possible + (common during Docker image build). 
+ """ + try: + import torch # type: ignore + + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability(0) + mm = f"{major}{minor}" + return [f"-gencode=arch=compute_{mm},code=sm_{mm}"] + except Exception: + pass + + # Fallback: reasonably modern GPUs supported by CUDA 12.x toolchains + return [ + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_86,code=sm_86", + ] + + +# Decide architectures in this order of precedence: +# 1) Respect TORCH_CUDA_ARCH_LIST if provided +# 2) Detect from the visible GPU at runtime +# 3) Fallback to a safe default list +arch_env = os.environ.get("TORCH_CUDA_ARCH_LIST", "").strip() +cuda_architectures = _parse_arch_list(arch_env) if arch_env else _detect_arch_list() + +includes = [ + "-I/usr/local/cuda/include", + f"-I{os.getcwd()}/cpp/trove", +] + +libs = [] + +common_args = ["-std=c++17"] +gcc_args = ["-fopenmp", "-DOMP_NESTED=true"] +cuda_args = [] + +# TODO: analyze if SIMD optimizations are beneficial on ARM +if os.environ.get("ARCH") != "arm64": + gcc_args += ["-mavx2"] + +files = sorted(glob("cpp/**/*.cu", recursive=True) + glob("cpp/**/*.cpp", recursive=True)) +files = [f for f in files if "trove/tests" not in f] + +# debug=True +debug = False +if debug: + common_args += ["-g", "-O0"] + cuda_args += ["-G"] +else: + common_args += ["-O3", "-DNDEBUG"] + +compile_args = { + "cxx": common_args + gcc_args + includes, + "nvcc": common_args + cuda_args + includes + cuda_architectures, +} + + +def build() -> None: + if os.environ.get("BUILD_CPP_EXTENSION", "1") != "1": + print("Environment variable BUILD_CPP_EXTENSION=1 not set. Skipping build.") + sys.exit(0) + + from setuptools import Distribution + from torch.utils.cpp_extension import CUDAExtension, BuildExtension + + ext_modules = [ + CUDAExtension( + "_nemo_retriever_ocr_cpp", + files, + extra_compile_args=compile_args, + libraries=libs, + ) + ] + + distribution = Distribution({"name": "nemo_retriever_ocr_cpp", "ext_modules": ext_modules}) + + build_ext = BuildExtension.with_options(parallel=False) + cmd = build_ext(distribution) + cmd.ensure_finalized() + cmd.run() + + # Copy built extensions back to the project + for output in cmd.get_outputs(): + output = Path(output) + relative_extension = Path("src/nemo_retriever_ocr_cpp") / output.relative_to(cmd.build_lib) + + shutil.copyfile(output, relative_extension) + mode = os.stat(relative_extension).st_mode + mode |= (mode & 0o444) >> 2 + os.chmod(relative_extension, mode) + + +if __name__ == "__main__": + build() diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/__init__.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb59f61a04031738fb505e42b3f5c821c9f54ad5 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/__init__.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb59f61a04031738fb505e42b3f5c821c9f54ad5 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/base.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..963428be96800818044b6b1e14e84a3452e0ba81 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/base.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Target encoder.""" + +import collections +import logging +import warnings + +import torch +from nemo_retriever_ocr.inference.post_processing.data.text_region import Batch, TextRegion +from nemo_retriever_ocr.inference.post_processing.data.worker_messages import TargetEncoderMessage + +from nemo_retriever_ocr.inference.models.utils import is_named_tuple + +_PREPARED_KEY = "_prepared_base" + +logger = logging.getLogger(__name__) + + +@torch.jit.script +def are_verts_outside( + vertices: torch.Tensor, x_max: float, y_max: float, x_min: float = 0, y_min: float = 0 +): + x = vertices[:, 0] + y = vertices[:, 1] + are_outside = torch.logical_or( + torch.logical_or(x < x_min, x > x_max), torch.logical_or(y < y_min, y > y_max) + ) + + return are_outside + + +class TargetEncoderBase(object): + """Class that handles encoding of targets and sending them to the gpu.""" + + def __init__(self, input_size, amp_opt, verbose=False): + """Initializes the target encoder.""" + self.input_size = input_size + self.amp_opt = amp_opt + self.verbose = verbose + + def prepare_data(self, batch: Batch, input_sizes): + """Operates on object_batch in a mutable manner.""" + if getattr(batch, _PREPARED_KEY, None) is not None: + return + + for example, input_size in zip(batch, input_sizes): + im_width = input_size[-1] + im_height = input_size[-2] + + coal = example.coalesce_homogeneous() + are_outside = are_verts_outside(coal, im_width, im_height) + + offset = 0 + for r_i in range(len(example)): + region = example[r_i] + num_vertices = region.region.vertices.shape[0] + if not region.valid: + offset += num_vertices + continue + + tr_outside = are_outside[offset : offset + num_vertices] + offset += num_vertices + + any_outside = torch.any(tr_outside).item() + + # If it straddles the boundary, then mark it as invalid so that the + # net doesn't get penalized either way + if any_outside: + region.valid = False + + setattr(batch, _PREPARED_KEY, True) + + def _convert_labels_to_targets(self, object_batch, input_sizes): + """Place holder for labels to target conversion.""" + raise NotImplementedError("Subclasses must implement this function!") + + def convert_labels_to_targets(self, object_batch, input_sizes): + self.prepare_data(object_batch, input_sizes) + + return self._convert_labels_to_targets(object_batch, input_sizes) + + def send_targets_to_gpu(self, targets, **kwargs): + """Sends targets to the gpu.""" + if torch.is_tensor(targets): + if targets.numel() > 0: + r = targets.cuda(**kwargs) + else: + r = torch.empty(*targets.shape, dtype=targets.dtype, device="cuda") + return r + + if isinstance(targets, str): + return targets + + # Look for a dict type object + if isinstance(targets, collections.abc.Mapping): + return { + k: self.send_targets_to_gpu(v, **kwargs) + for k, v in targets.items() + if k != "__other" + } + + # Look for a namedtuple + if is_named_tuple(targets): + return type(targets)(*[self.send_targets_to_gpu(t, **kwargs) for t in targets]) + + if 
isinstance(targets, collections.abc.Iterable): + return type(targets)([self.send_targets_to_gpu(t, **kwargs) for t in targets]) + + # Nothing that can be sent to the GPU, so just return it back + return targets + + def convert_targets_to_labels( + self, target_dict, image_size, limit_idxs=None, is_gt=True, **kwargs + ) -> Batch: + raise NotImplementedError("Subclasses must implement this function!") + + def is_recognition(self): + return False + + def get_charset(self): + raise ValueError("This target encoder does not support charsets!") + + def handle_message(self, message: TargetEncoderMessage): + warnings.warn(f"No known message can handle type {type(message)}!") + + def send_messages(self, worker_comm_context, gpu_targets, **kwargs): + pass + + def get_state_dict(self): + state_dict = dict() + self.build_state_dict(state_dict) + return state_dict + + def build_state_dict(self, state_dict): + pass + + def load_state_dict(self, state_dict): + pass + + def get_rrect_and_quad(self, region: TextRegion): + verts = region.region.vertices + if verts.shape[0] < 4: + verts = torch.cat([verts, verts[:1].expand(4 - verts.shape[0], 2)]) + try: + rrect = region.region.min_rrect + except: # noqa: E722 + rrect = region.region.bounds_quad + region.valid = False + vtx_count = verts.shape[0] + + quad = verts if vtx_count == 4 else rrect + + return rrect, quad diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/recognizer_encoder.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/recognizer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b51d30292243ca3bbf0ca4a0f52faee20c2ed9d3 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/recognizer_encoder.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +"""Base target encoder for e2e recognition models.""" + +from collections import defaultdict, deque +import logging +import math +import os +from typing import Tuple, List, Optional, Callable, Dict + +import numpy as np +import torch + +import nemo_retriever_ocr.inference.post_processing.data.text_region as tr +from nemo_retriever_ocr.inference.post_processing.data.quadrangle import Quadrangle +from nemo_retriever_ocr.inference.post_processing.data.worker_messages import TargetEncoderMessage +from nemo_retriever_ocr.inference.encoders.base import TargetEncoderBase + +from nemo_retriever_ocr_cpp import ( + beam_decode, + sparse_select, + create_sbo_lm, + decode_sequences, + create_token_mapping, +) +from nemo_retriever_ocr.inference.models.utils import ( + f_measure, + tensor_all_reduce, + tensor_all_gather, +) + +logger = logging.getLogger(__name__) +logging.getLogger("shapely.geos").setLevel(logging.FATAL) + + +# Index 0 is Blank +# Index 1 is EOS +# Index 2 is 'Rare' +_NUM_SPECIAL = 3 + +# The maximum number of regions per batch that will be trained +MAX_REGIONS = 96 + + +class UpdateTrainedStatsMessage(TargetEncoderMessage): + def __init__(self, name, buffer: torch.Tensor): + super().__init__(name) + self.buffer = buffer + + def build_state(self, state): + super().build_state(state) + state["buffer"] = self.buffer + + +class RecognitionTargetEncoder(TargetEncoderBase): + def __init__( + self, + charset: str, + input_size, + sequence_length: int, + amp_opt=0, + combine_duplicates=False, + is_train=True, + lm_path=None, + verbose=False, + ): + super().__init__(input_size, amp_opt, verbose) + + self.sequence_length = sequence_length + self.combine_duplicates = combine_duplicates + self.is_train = is_train + + logger.info("Combine duplicates: {}".format(combine_duplicates)) + + self.charset = charset + self.lm_path = lm_path + self.cpp_token_mapping = None + + self._initialized = False + + self.beam_lm = None + + self.send_buffers = None + + def _initialize(self): + if self._initialized: + return + self._initialized = True + + self.idx_to_char = {i + _NUM_SPECIAL: c for i, c in enumerate(self.charset)} + + self.char_to_idx = {c: i + _NUM_SPECIAL for i, c in enumerate(self.charset)} + + if self.lm_path is not None: + if not os.path.exists(self.lm_path): + raise ValueError(f"The language model path '{self.lm_path}' doesn't exist!") + self.beam_lm = create_sbo_lm(self.lm_path, self.idx_to_char) + + self.cpp_token_mapping = create_token_mapping(self.idx_to_char) + + def __getstate__(self): + ret = dict(self.__dict__) + if self._initialized: + del ret["cpp_token_mapping"] + del ret["idx_to_char"] + del ret["char_to_idx"] + del ret["beam_lm"] + del ret["send_buffers"] + ret["_initialized"] = False + + return ret + + @property + def charset_size(self): + return _NUM_SPECIAL + len(self.charset) + + def is_recognition(self): + return True + + def get_charset(self): + return self.charset + + def cb_convert_labels_to_targets( + self, batch: tr.Batch, input_sizes, handle_region: Callable[[int, int, tr.TextRegion], None] + ): + self._initialize() + + # classes = [] + # masks = [] + region_counts = [] + geo_idxs = [] + word_lens = [] + word_use_counts = [] + + additional = self.get_additional_regions(batch, input_sizes) + + # Get an upper bound for the number of regions, and the maximum text length + loose_max_seq_len = 0 + loose_num_regions = 0 + for ex_idx, example in enumerate(batch): + loose_num_regions += len(additional[ex_idx]) + for region in example: + 
if region.valid: + loose_max_seq_len = max(loose_max_seq_len, len(region.text)) + loose_num_regions += 1 + + class_tensor = torch.empty(loose_num_regions, loose_max_seq_len + 1, dtype=torch.int64) + mask_tensor = torch.empty(loose_num_regions, loose_max_seq_len + 1, dtype=torch.float32) + + max_seq_len = 0 + enc_offset = 0 + for ex_idx, example in enumerate(batch): + regions = list(example.regions) + + if self.is_train: + regions.extend(additional[ex_idx]) + + used_regions = [] + for r_idx, region in enumerate(regions): + valid = self.is_region_valid(region, used_regions) + + if not valid: + continue + + text = region.text + + max_seq_len = max(max_seq_len, len(text) + 1) + enc_offset += 1 + + word_lens.append(len(text)) + + geo_idxs.append((ex_idx, r_idx)) + + handle_region(ex_idx, r_idx, region) + + used_regions.append(region) + region_counts.append(len(used_regions)) + + class_tensor = class_tensor[:enc_offset, :max_seq_len] + mask_tensor = mask_tensor[:enc_offset, :max_seq_len] + + region_counts = torch.tensor(region_counts, dtype=torch.int64) + word_lens = torch.tensor(word_lens, dtype=torch.int32) + + if geo_idxs: + geo_idxs = torch.tensor(geo_idxs, dtype=torch.int64) + else: + geo_idxs = torch.empty(0, 2, dtype=torch.int64) + + return { + "sequences": class_tensor, + "mask": mask_tensor, + "region_counts": region_counts, + "geo_idxs": geo_idxs, + } + + def get_additional_regions(self, batch: tr.Batch, input_sizes): + additional = [[] for _ in batch] + if self.is_train: + num_regions = sum(len(ex) for ex in batch) + dummy_quads = self.create_dummy_quads(input_sizes[0], num_regions, len(batch)) + # Weight the probability of an example by the inverse of the number of regions. + # Effectively, this means that examples with fewer regions are more likely + # to get a dummy example + probs = [(1 / len(ex)) if len(ex) > 0 else 2 for ex in batch] + t_prob = sum(probs) + probs = [p / t_prob for p in probs] + assignments = np.random.choice(len(batch), size=len(dummy_quads), p=probs) + for dummy, assign in zip(dummy_quads, assignments): + additional[assign].append(dummy) + return additional + + def is_region_valid( + self, region: tr.TextRegion, used_regions: Optional[List[tr.TextRegion]] = None + ): + if not region.valid: + return False + + if region.region.vertices.shape[0] < 4: + return False + + try: + region_area = region.region.area + if region_area < 1: + return False + except: # noqa: E722 + return False + + valid = True + valid = valid and all(c in self.char_to_idx for c in region.text) + valid = valid and getattr(region, "recog_valid", True) + + # This is one of our dummy / negative regions. For this, just ensure that it doesn't + # intersect with other valid regions + if valid and region.text == "" and used_regions is not None: + for used_region in used_regions: + if used_region.region._poly.intersects(region.region._poly): + valid = False + break + + return valid + + def limit_regions( + self, + targets: Dict[str, torch.Tensor], + select_fn: Callable[[torch.Tensor, torch.Tensor], None], + max_regions=MAX_REGIONS, + ): + in_region_counts = targets["region_counts"] + word_use_counts = targets["word_use_counts"] + + # Partly as a training optimization, and partly to ensure that we have a bounded + # memory envelope, train recognition with an upper bounded number of regions. + # To do this, we sample from the full set of regions. 
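        # A rough illustration with hypothetical counts: word_use_counts = [1, 10, 100]
        # yields weights [1.0, 0.1, 0.01], so the multinomial draw below almost always
        # keeps the rarely trained word at index 0 and usually drops index 2.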
+ if not self.is_train or word_use_counts.shape[0] <= max_regions: + return + + inv_uses = 1 / word_use_counts.float() + + sel_indices = torch.multinomial(inv_uses, max_regions, replacement=False) + sel_indices = torch.sort(sel_indices).values + + key_set = ["sequences", "mask", "geo_idxs"] + + in_buffers = [targets[k] for k in key_set] + + out_region_counts, out_buffers = sparse_select(in_region_counts, in_buffers, sel_indices) + + select_fn(in_region_counts, sel_indices) + + for k, v in zip(key_set, out_buffers): + targets[k] = v + targets["region_counts"] = out_region_counts + + def cb_convert_targets_to_labels( + self, + target_dict: Dict[str, torch.Tensor], + image_size, + limit_idxs: Optional[torch.Tensor], + is_gt, + subsel_fn: Optional[Callable[[int, int, int], Dict[str, torch.Tensor]]], + geometry_fn: Callable[[Dict, int, int, int], torch.Tensor], + **kwargs, + ): + self._initialize() + + target_dict = self.subselect_targets(target_dict, limit_idxs, subsel_fn) + + sequences = target_dict["sequences"].cpu() + region_counts = target_dict["region_counts"].cpu() + confidence = target_dict.get("confidence", None) + if confidence is not None: + confidence = confidence.cpu() + + decoded_seq_probs = None + combine_duplicates = not is_gt and self.combine_duplicates + language_model = self.beam_lm if not is_gt else None + if sequences.dim() == 3: + if sequences.shape[0] > 0: + decoded_seq_ids, decoded_seq_probs, combine_duplicates = self.convert_preds_to_idxs( + sequences, combine_duplicates, language_model + ) + else: + decoded_seq_ids = torch.empty( + 0, sequences.shape[1], dtype=torch.int64, device=sequences.device + ) + elif sequences.dim() == 2: + decoded_seq_ids = sequences + else: + raise ValueError("Unsupported sequence tensor!") + + decoded_strings = decode_sequences( + decoded_seq_ids, self.cpp_token_mapping, decoded_seq_probs + ) + + examples = [] + offset = 0 + for ex_idx, region_count in enumerate(region_counts): + region_count = region_count.item() + + regions = [] + for i in range(region_count): + text, text_conf = decoded_strings[offset] + region_conf = confidence[offset].item() if confidence is not None else 1 + geo = geometry_fn(target_dict, ex_idx, i, offset) + offset += 1 + + overall_conf = f_measure(region_conf, text_conf) + + region = tr.TextRegion( + Quadrangle(geo), text, valid=len(text) > 0 and overall_conf > 0.5 + ) + region.quad_prob = region_conf + region.text_prob = text_conf + region.confidence = overall_conf + regions.append(region) + + examples.append(tr.Example(regions)) + + return tr.Batch(examples) + + def subselect_targets( + self, + target_dict: Dict[str, torch.Tensor], + limit_idxs: torch.Tensor, + limit_fn: Optional[Callable[[int, int, int], Dict[str, torch.Tensor]]] = None, + ): + if limit_idxs is None: + return target_dict + + sequences = target_dict["sequences"].cpu() + region_counts = target_dict["region_counts"].cpu() + geo_idxs = target_dict["geo_idxs"].cpu() + confidence = target_dict.get("confidence", None) + if confidence is not None: + confidence = confidence.cpu() + + new_seqs = [] + new_counts = [] + new_confidence = [] + new_geo_idxs = [] + other_limits = defaultdict(lambda: []) + cs_region_counts = torch.cumsum(region_counts, 0) + for limit_idx in limit_idxs: + limit_idx = limit_idx.item() + start_offset = cs_region_counts[limit_idx - 1].item() if limit_idx > 0 else 0 + end_offset = cs_region_counts[limit_idx].item() + new_seqs.append(sequences[start_offset:end_offset]) + new_geo_idxs.append(geo_idxs[start_offset:end_offset]) + + if 
limit_fn is not None: + others = limit_fn(limit_idx, start_offset, end_offset) + for k, v in others.items(): + other_limits[k].append(v) + + if confidence is not None: + new_confidence.append(confidence[start_offset:end_offset]) + new_counts.append(region_counts[limit_idx].item()) + + sequences = torch.cat(new_seqs) + geo_idxs = torch.cat(new_geo_idxs) + if confidence is not None: + confidence = torch.cat(new_confidence) + region_counts = torch.tensor(new_counts, dtype=torch.int64) + for k, v in other_limits.items(): + other_limits[k] = torch.cat(v, dim=0) + + ret = {k: v for k, v in target_dict.items()} + ret.update( + sequences=sequences, + region_counts=region_counts, + geo_idxs=geo_idxs, + confidence=confidence, + ) + ret.update(other_limits) + + return ret + + def create_dummy_quads(self, input_size, num_curr_quads, batch_size): + # num_quads = max(1, min(num_curr_quads // 10, 2 * batch_size)) + num_quads = batch_size + + quads = [] + for _ in range(num_quads): + # Sample a centerpoint from the inner 3/4 of the image + center = np.random.rand(2) + center[0] = center[0] * (3 * input_size[-1] / 4) + (input_size[-1] / 8) + center[1] = center[1] * (3 * input_size[-2] / 4) + (input_size[-2] / 8) + + # Sample an angle in the range of +/- pi/4 + angle = math.pi * (np.random.rand() * 2 - 1) / 4 + + rot_mat = np.array( + [[math.cos(angle), math.sin(angle)], [-math.sin(angle), math.cos(angle)]] + ).T + + w = np.random.rand() * (input_size[-1] / 8) + h = np.random.rand() * (input_size[-2] / 8) + + vecs = np.array( + [ + [-w, -h], + [w, -h], + [w, h], + [-w, h], + ] + ) + + vecs = vecs.dot(rot_mat) + + vecs += center[None, :] + + # Clamp the quad to be within the image + vecs[:, 0] = np.minimum(np.maximum(vecs[:, 0], 0), input_size[-1]) + vecs[:, 1] = np.minimum(np.maximum(vecs[:, 1], 0), input_size[-2]) + + quads.append(torch.from_numpy(vecs)) + + rand_coords = torch.stack(quads).float() + + return [tr.TextRegion(Quadrangle(coords), "", valid=True) for coords in rand_coords] + + @staticmethod + def convert_preds_to_idxs( + seq: torch.Tensor, combine_duplicates=False, language_model=None + ) -> Tuple[torch.Tensor, torch.Tensor, bool]: + """ + Converts a prediction distribution to the set of preferred sequences. 
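        When combine_duplicates is set or a language model is supplied, decoding uses the
        CTC-style beam search in beam_decode; otherwise a greedy per-timestep argmax is taken.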
+ seq: BxTxC, where B=batch, T=timestep, C=char + Returns: Tuple[indices,probs] + """ + + if combine_duplicates or language_model is not None: + ###### CTC + output, scores = beam_decode( + seq, + 100, + lang_model=language_model, + lm_weight=1, + combine_duplicates=combine_duplicates, + ) + ###### + else: + ###### Max + scores, output = torch.max(seq, dim=2) + ###### + + return output, scores, False + + def decode_sequence( + self, seq: torch.Tensor, remove_duplicates=False, probs: torch.Tensor = None + ) -> Tuple[str, float]: + self._initialize() + + text = "" + prev = None + prob = 0 + for i, tok_idx in enumerate(seq): + tok_idx = tok_idx.item() + if tok_idx == prev and remove_duplicates: + continue + prev = tok_idx + + if tok_idx != 1 and probs is not None and probs.dim() == 1: + tok_prob = math.log(probs[i].item()) + + prob += tok_prob + + if tok_idx == 0: + continue + # text += '_' + elif tok_idx == 1: + break + elif tok_idx == 2: + text += "^" + else: + text += self.idx_to_char[tok_idx] + + prob = math.exp(prob) + if probs is not None and probs.dim() == 0: + prob = probs.item() + # logger.info(f'Sequence: {text} - {prob}') + return text, prob + + def send_messages(self, worker_comm_context, gpu_targets, name, **kwargs): + recog_sequences = gpu_targets["sequences"] + + # Figure out a shape envelope for all of the sequences + shape_tensor = torch.tensor(recog_sequences.shape, dtype=torch.int64, device="cpu") + shape_tensor = tensor_all_reduce(shape_tensor, torch.distributed.ReduceOp.MAX) + pad_sequences = torch.ones( + *shape_tensor.tolist(), dtype=recog_sequences.dtype, device=recog_sequences.device + ) + pad_sequences[: recog_sequences.shape[0], : recog_sequences.shape[1]] = recog_sequences + pad_sequences[recog_sequences.shape[0] :, 0] = -1 + + pad_sequences = tensor_all_gather(pad_sequences) + vmask = pad_sequences[:, 0] != -1 + pad_sequences = pad_sequences[vmask].short() + + if self.send_buffers is None: + self.send_buffers = deque(None for _ in range(worker_comm_context.num_workers + 1)) + + def _convert_labels_to_targets(self, batch: tr.Batch, input_sizes): + quads = [] + rrects = [] + + def handle_region(ex_idx: int, r_idx: int, region: tr.TextRegion): + rrect = region.region.min_rrect + vtx_count = region.region.vertices.shape[0] + if vtx_count == 4: + quads.append(region.region.vertices) + else: + quads.append(rrect) + rrects.append(rrect) + + ret = self.cb_convert_labels_to_targets(batch, input_sizes, handle_region) + + classes = ret["sequences"] + if classes.shape[0] > 0: + quad_tensor = torch.stack(quads) + rrect_tensor = torch.stack(rrects) + else: + quad_tensor = torch.empty(0, 4, 2, dtype=torch.float32) + rrect_tensor = torch.empty(0, 4, 2, dtype=torch.float32) + + ret.update(quads=quad_tensor, rboxes=rrect_tensor) + + def handle_select(region_counts: torch.Tensor, sel_indices: torch.Tensor): + key_set = ["quads", "rboxes"] + in_buffers = [ret[k] for k in key_set] + + _, out_buffers = sparse_select(region_counts, in_buffers, sel_indices) + + for k, v in zip(key_set, out_buffers): + ret[k] = v + + self.limit_regions(ret, handle_select) + + ret["trained_quads"] = ret["quads"].clone() + + return ret + + def convert_targets_to_labels( + self, target_dict, image_size, limit_idxs=None, is_gt=True, **kwargs + ): + def subsel_quads(limit_idx: int, start_offset: int, end_offset: int): + return {"quads": target_dict["quads"][start_offset:end_offset]} + + def get_quad(target_dict: Dict[str, torch.Tensor], ex_idx: int, r_idx: int, r_offset: int): + return 
target_dict["quads"][r_offset].cpu() + + return self.cb_convert_targets_to_labels( + target_dict, image_size, limit_idxs, is_gt, subsel_fn=subsel_quads, geometry_fn=get_quad + ) diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/relational_encoder.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/relational_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..74624d578574f4f15121e3e43d2bb5ea258a5789 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/encoders/relational_encoder.py @@ -0,0 +1,465 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +"""Multiple prior target encoder.""" + +import logging +import math + +import torch + +from nemo_retriever_ocr.inference.post_processing.data.quadrangle import Quadrangle +import nemo_retriever_ocr.inference.post_processing.data.text_region as tr +from nemo_retriever_ocr.inference.encoders.base import TargetEncoderBase +from nemo_retriever_ocr.inference.models.utils import cat + +from nemo_retriever_ocr_cpp import dense_relations_to_graph as cpp_dense_relations_to_graph + +logger = logging.getLogger(__name__) +logging.getLogger("shapely.geos").setLevel(logging.FATAL) + + +class RelationalTargetEncoder(TargetEncoderBase): + def __init__(self, input_size, amp_opt=0, is_train=True): + super().__init__(input_size, amp_opt, False) + + self.is_train = is_train + + def _convert_labels_to_targets(self, batch: tr.Batch, input_sizes): + all_relations = [] + all_line_relations = [] + all_weights = [] + all_region_valid = [] + all_relation_valid = [] + all_quads = [] + all_rrects = [] + all_ex_valid = [] + all_w_idx_to_line_map = [] + all_line_to_line_map = [] + geo_idxs = [] + r_offset = [] + region_counts = [] + + for ex_idx, example in enumerate(batch): + r_offset.append(len(geo_idxs)) + graph = example.relation_graph + + all_ex_valid.append(graph is not None) + + ( + relations, + region_valid, + relation_valid, + line_relations, + w_idx_to_line_map, + line_to_line_map, + ) = self.encode_relations(example, graph) + + all_relations.append(relations) + all_region_valid.append(region_valid) + all_relation_valid.append(relation_valid) + all_line_relations.append(line_relations) + all_w_idx_to_line_map.append(w_idx_to_line_map) + all_line_to_line_map.append(line_to_line_map) + region_counts.append(len(example)) + + quads = [] + rrects = [] + for r_idx, region in enumerate(example): + rrect, quad = self.get_rrect_and_quad(region) + + quads.append(quad) + rrects.append(rrect) + geo_idxs.append((ex_idx, r_idx)) + + if not quads: + quads = torch.empty(0, 4, 2, dtype=torch.float32) + rrects = torch.empty(0, 4, 2, dtype=torch.float32) + else: + quads = torch.stack(quads) + rrects = torch.stack(rrects) + + all_quads.append(quads) + all_rrects.append(rrects) + + def get_offsets(rels): + offsets = [0] + for rel in rels: + offsets.append(offsets[-1] + (0 if rel is None else rel.numel())) + return offsets + + offsets = get_offsets(all_relations) + + offsets = torch.tensor(offsets, dtype=torch.int64) + + all_ex_valid = torch.tensor(all_ex_valid, dtype=torch.bool) + + r_offset.append(len(geo_idxs)) + geo_idxs = torch.tensor(geo_idxs) + r_offset = torch.tensor(r_offset) + region_counts = torch.tensor(region_counts) + + all_quads = cat(all_quads, 4, 2) + all_rrects = cat(all_rrects, 4, 2) + + ret = { + "ex_valid": all_ex_valid, + "offsets": 
offsets, + "relations": all_relations, + "line_relations": all_line_relations, + "weights": all_weights, + "region_valid_mask": all_region_valid, + "relation_valid_mask": all_relation_valid, + "w_idx_to_line_map": all_w_idx_to_line_map, + "line_to_line_map": all_line_to_line_map, + "quads": all_quads, + "rboxes": all_rrects, + "ex_offsets": r_offset, + "region_counts": region_counts, + "geo_idxs": geo_idxs, + "trained_quads": all_quads.clone(), + } + + return ret + + def is_region_valid(self, region: tr.TextRegion): + if not region.valid: + return False + + if not getattr(region, "rel_valid", True): + return False + + if region.region.vertices.shape[0] < 4: + return False + + try: + region_area = region.region.area + + if region_area <= 5: + return False + except: # noqa: E722 + return False + + try: + _ = region.region.min_rrect + except: # noqa: E722 + return False + + return True + + def encode_relations(self, example: tr.Example, graph: tr.RelationGraph): + num_regions = len(example) + relations = torch.full((num_regions,), -1, dtype=torch.int64) + region_valid = torch.zeros(num_regions, dtype=torch.float32) + relation_valid = torch.empty(num_regions, num_regions, dtype=torch.float32) + line_relations = torch.zeros(num_regions, num_regions, dtype=torch.float32) + + num_lines = 0 + if graph is not None: + graph = graph.split_lines(example) + for para in graph.paragraphs: + num_lines += len(para) + + w_idx_to_line_map = torch.empty(num_regions, dtype=torch.int64) + line_to_line_map = torch.zeros(num_lines, num_lines, dtype=torch.float32) + + for i, region in enumerate(example): + valid = self.is_region_valid(region) + + if valid: + region_valid[i] = 1.0 + + if graph is not None: + # Invalidate all relations where the "from" is not valid + relation_valid = region_valid[:, None] * region_valid[None, :] + + line_idx = 0 + for paragraph in graph.paragraphs: + prev_sentence = None + for sentence in paragraph: + if prev_sentence is not None: + for prev_word in prev_sentence: + for curr_word in sentence: + line_relations[prev_word, curr_word] = 1.0 + line_to_line_map[line_idx - 1, line_idx] = 1.0 + + prev_word = None + for curr_word in sentence: + if prev_word is not None: + relations[prev_word] = curr_word + w_idx_to_line_map[curr_word] = line_idx + prev_word = curr_word + prev_sentence = sentence + line_idx += 1 + # Note: The `relation_valid` mask for lines is the outer product of region_valid and itself + else: + relation_valid.fill_(0) + + return ( + relations, + region_valid, + relation_valid, + line_relations, + w_idx_to_line_map, + line_to_line_map, + ) + + def convert_targets_to_labels( + self, target_dict, image_size, limit_idxs=None, is_gt=True, **kwargs + ): + all_word_relations = target_dict["relations"] + all_line_relations = target_dict["line_relations"] + all_line_unc = target_dict.get("line_rel_var", None) + region_counts = target_dict["region_counts"].cpu() + all_quads = target_dict["quads"] + + # These are ground truth. 
Convert them to dense form + if all_word_relations[0].dim() == 1: + all_word_relations = [sparse_to_dense(gt_rel) for gt_rel in all_word_relations] + + all_word_relations = [r.cpu() if r is not None else r for r in all_word_relations] + all_line_relations = [r.cpu() if r is not None else r for r in all_line_relations] + all_line_unc = ( + [r.cpu() if r is not None else r for r in all_line_unc] + if all_line_unc is not None + else None + ) + region_counts = region_counts.cpu() + all_quads = all_quads.cpu() + cs_region_counts = torch.cumsum(region_counts, 0) + + examples = [] + for i, (word_relations, line_relations) in enumerate( + zip(all_word_relations, all_line_relations) + ): + start_offset = cs_region_counts[i - 1] if i > 0 else 0 + end_offset = cs_region_counts[i] + quads = all_quads[start_offset:end_offset] + line_unc = all_line_unc[i] if all_line_unc is not None else None + + regions = [tr.TextRegion(Quadrangle(q), "") for q in quads] + graph = None + if end_offset > start_offset: + graph = self.dense_relations_to_graph( + word_relations, line_relations, line_unc, is_gt + ) + else: + graph = tr.RelationGraph() + + ex = tr.Example(regions, relation_graph=graph) + examples.append(ex) + + if limit_idxs is not None: + examples = [examples[idx] for idx in limit_idxs] + + return tr.Batch(examples) + + def sparse_relations_to_graph(self, relations: torch.Tensor): + relations = relations.cpu().tolist() + + in_chain = dict() + for from_idx, to_idx in enumerate(relations): + if to_idx == -1: + continue + + in_chain[to_idx] = (from_idx, 1) + + sentences = self.in_chain_to_groups(in_chain, len(relations)) + paragraphs = [[s] for s in sentences] + + return tr.RelationGraph(paragraphs) + + def dense_relations_to_graph( + self, + word_relations: torch.Tensor, + line_logits: torch.Tensor, + line_log_uncertainty: torch.Tensor = None, + is_gt=False, + ): + lines = [p[0] for p in cpp_dense_relations_to_graph(word_relations)] + + line_logits = line_logits.float() + + if is_gt: + null_conn = (line_logits.sum(dim=1, keepdim=True) == 0).to(line_logits.dtype) + line_logits = torch.cat((null_conn, line_logits), dim=1) + + if line_log_uncertainty is not None: + line_log_uncertainty = line_log_uncertainty.float() + inv_uncertainty = torch.exp(-line_log_uncertainty) + else: + inv_uncertainty = torch.ones_like(line_logits) + + w_idx_to_line_map = torch.empty(word_relations.shape[0], dtype=torch.int64) + + line_idx = 0 + for line in lines: + for word in line: + w_idx_to_line_map[word] = line_idx + line_idx += 1 + + valid_mask = line_logits != -math.inf + + null_w_idx_to_line_map = torch.cat( + (torch.tensor([0], dtype=w_idx_to_line_map.dtype), w_idx_to_line_map + 1), dim=0 + ) + + sa_w2l_idxs = null_w_idx_to_line_map.reshape(1, -1).expand(w_idx_to_line_map.shape[0], -1) + sa_l2l_idxs = w_idx_to_line_map.reshape(-1, 1).expand(-1, line_idx + 1) + + line_logits = torch.where(valid_mask, line_logits, torch.zeros_like(line_logits)) + inv_uncertainty = torch.where( + valid_mask, inv_uncertainty, torch.zeros_like(inv_uncertainty) + ) + + word_to_line_unc = torch.zeros( + w_idx_to_line_map.shape[0], line_idx + 1, dtype=line_logits.dtype + ) + word_to_line_unc.scatter_add_(dim=1, index=sa_w2l_idxs, src=inv_uncertainty) + + line_to_line_unc = torch.zeros(line_idx, line_idx + 1, dtype=line_logits.dtype) + line_to_line_unc.scatter_add_(dim=0, index=sa_l2l_idxs, src=word_to_line_unc) + + # The first index will give us the total of going from a word to a line, and the second gives us going from one word to another + unc_sums 
= line_to_line_unc[w_idx_to_line_map][:, null_w_idx_to_line_map].clamp_min(1e-6) + + unc_weights = inv_uncertainty / unc_sums + + w_logits = torch.where(valid_mask, unc_weights * line_logits, torch.zeros_like(line_logits)) + + word_to_line_logits = torch.zeros_like(word_to_line_unc) + word_to_line_logits.scatter_add_(dim=1, index=sa_w2l_idxs, src=w_logits) + + line_to_line_logits = torch.zeros_like(line_to_line_unc) + line_to_line_logits.scatter_add_(dim=0, index=sa_l2l_idxs, src=word_to_line_logits) + + self_mask = torch.full( + (line_to_line_logits.shape[0],), -math.inf, dtype=line_to_line_logits.dtype + ).diag() + self_mask = torch.cat( + (torch.zeros(self_mask.shape[0], 1, dtype=self_mask.dtype), self_mask), dim=1 + ) + + line_to_line_logits += self_mask + + line_to_line_probs = torch.softmax(line_to_line_logits, dim=1, dtype=torch.float32) + line_maxes = torch.max(line_to_line_probs, dim=1, keepdim=True).values + line_to_line_probs = torch.where( + line_to_line_probs == line_maxes, + torch.ones_like(line_to_line_probs), + torch.zeros_like(line_to_line_probs), + ) + line_to_line_probs = line_to_line_probs[:, 1:] + + rel_lines = set(tuple(p[0]) for p in cpp_dense_relations_to_graph(line_to_line_probs)) + + paragraphs = [] + for rel_line in rel_lines: + para = [lines[l_idx] for l_idx in rel_line] + paragraphs.append(para) + + return tr.RelationGraph(paragraphs) + + def construct_paragraphs(self, sentences, paragraphs): + word_to_sentence = dict() + for s in sentences: + for word_idx in s: + word_to_sentence[word_idx] = s + + new_paragraphs = [] + proc_ids = set() + for para in paragraphs: + new_para = [] + for word_idx in para: + s = word_to_sentence[word_idx] + + if s not in new_para: + new_para.append(s) + if all(proc_ids.isdisjoint(s) for s in new_para): + for s in new_para: + proc_ids.update(s) + new_paragraphs.append(new_para) + + return new_paragraphs + + def cvt_to_1_hot(self, relations: torch.Tensor): + if relations.dim() == 3: + return relations + + one_hot = torch.eye(3, dtype=torch.float32, device=relations.device) + + f_rel = relations.reshape(-1) + + oh_f_rel = one_hot[f_rel] + + oh_rel = oh_f_rel.reshape(*relations.shape, -1) + + return oh_rel + + def in_chain_to_groups(self, in_chain: dict, num_regions: int): + out_chain = {v[0]: k for k, v in in_chain.items()} + + processed = set() + groups = [] + for to_idx, (from_idx, conf) in in_chain.items(): + if to_idx in processed: + continue + + # Find the start of the chain + cycle_set = {from_idx} + is_cycle = False + while from_idx in in_chain: + to_idx = from_idx + from_idx = in_chain[to_idx][0] + if from_idx in cycle_set: + is_cycle = True + break + + # Completely ignore cycle chains + if is_cycle: + continue + + group = [from_idx] + processed.add(from_idx) + while to_idx in out_chain: + processed.add(to_idx) + group.append(to_idx) + to_idx = out_chain[to_idx] + group.append(to_idx) + processed.add(to_idx) + + groups.append(group) + + # Now add in the stragglers + for w_idx in range(num_regions): + if w_idx not in processed: + groups.append([w_idx]) + + return groups + + +def sparse_to_dense(sparse: torch.Tensor, handle_invalid=True, encode_null=False): + cols = sparse.shape[0] + if encode_null: + sparse = sparse + 1 + cols += 1 + + rel_valid_mask = None + ones = torch.ones(sparse.shape[0], dtype=torch.float32, device=sparse.device) + if handle_invalid: + rel_valid_mask = torch.where(sparse >= 0, ones, torch.zeros_like(ones)) + + if not encode_null: + sparse = sparse.clamp_min(0) + + dense = torch.zeros(sparse.shape[0], cols, 
dtype=torch.float32, device=sparse.device) + dense.scatter_(dim=1, index=sparse.unsqueeze(1), src=ones.unsqueeze(1)) + + if rel_valid_mask is not None: + dense *= rel_valid_mask[:, None] + return dense diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/__init__.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb59f61a04031738fb505e42b3f5c821c9f54ad5 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/blocks.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab8aa328a41499ba9ac02b893a6d10b5ed056c4 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/blocks.py @@ -0,0 +1,510 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Model blocks.""" + +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + + +class CReLU(nn.Module): + def __init__(self, act=F.relu): + super().__init__() + + self.act = act + + def forward(self, x): + x = torch.cat((x, -x), dim=1) + x = self.act(x) + return x + + +def get_activation(name): + """Returns a pytorch activation layer of type 'name', where 'name' is a string.""" + if isinstance(name, nn.Module): + return name + + if name == "relu": + return nn.ReLU() + if name == "elu": + return nn.ELU() + if name == "selu": + return nn.SELU() + if name == "sigmoid": + return nn.Sigmoid() + if name == "tanh": + return nn.Tanh() + if name == "softplus": + return nn.Softplus() + if name == "crelu": + return CReLU() + if name == "none": + return None + raise ValueError( + "Unsupported activation type: {}. " "Ensure activation name is all lower case.".format(name) + ) + + +def conv2d_block( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + activation="relu", + batch_norm=True, +): + """ + Returns pytorch two-dimensional convolutional layer with activation and batch_norm if requested. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (int): height and width of kernels. + stride (int): stride of the filters, default=1. + padding (int): padding added to input height x width, default=0. + dilation (int): dilation factor, default=1. + groups (int): number of convolution groups, default=1. + bias (bool): whether to use bias, default=True. + padding_mode (string): mode for applying padding, default='zeros'. + activation (string): type of activation, default='relu'. + batch_norm (bool): whether to use batch normalization, default=True. + + Returns: + conv2d_layer (nn.Sequential): pytorch two-dimensional convolution layer. 
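    Example (a minimal usage sketch; channel sizes and shapes are illustrative):
        block = conv2d_block(3, 64, kernel_size=3, padding=1)
        out = block(torch.randn(1, 3, 32, 32))  # -> shape (1, 64, 32, 32)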
+ """ + items = [ + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ), + ] + + if batch_norm: + items.append(nn.BatchNorm2d(out_channels)) + + act = get_activation(activation) + if act: + items.append(act) + + return nn.Sequential(*items) + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels, activation="relu", batch_norm=True): + super().__init__() + + self.batch_norm = batch_norm + + self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1) + self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1) + + self.downsample = None + if in_channels != out_channels: + self.downsample = nn.Conv2d(in_channels, out_channels, 1) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + self.act = get_activation(activation) + + def forward(self, x): + identity = x + + out = self.conv1(x) + if self.batch_norm: + out = self.bn1(out) + out = self.act(out) + + out = self.conv2(out) + if self.batch_norm: + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(identity) + + out = out + identity + + return self.act(out) + + +class Residual(nn.Module): + def __init__(self, inner): + super().__init__() + self.inner = inner + + def forward(self, x: torch.Tensor): + y = self.inner(x) + return x + y + + +def initialize_weights(model): + """Initializes the model weights.""" + for m in model.modules(): + if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)): + torch.nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="relu") + + +def get_no_bias_decay_params(model, l2_value): + """Returns weight decay parameters; l2_value set as the weight decay for layers with decay.""" + decay, no_decay = [], [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if name.endswith("bias"): + no_decay.append(param) + else: + decay.append(param) + + return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": l2_value}] + + +class GCContext(nn.Module): + def __init__(self, input_dim): + super().__init__() + + self.input_dim = input_dim + + self.key_proj = nn.Conv2d(input_dim, 1, 1) + + def forward(self, x): + # B,C,H,W + + attn = self.key_proj(x) / math.sqrt(self.input_dim) + # B,1,H,W + attn = attn.reshape(x.shape[0], 1, -1) + # B,1,HW + attn = F.softmax(attn, dim=-1, dtype=torch.float32) + attn = attn.reshape(x.shape[0], -1, 1) + # B,HW,1 + + rs_x = x.reshape(x.shape[0], x.shape[1], -1) + # B,C,HW + + focus = torch.bmm(rs_x, attn) + # B,C,1 + focus = focus.reshape(x.shape[0], x.shape[1], 1, 1) + # B,C,1,1 + + return focus + + +class GCTransform(nn.Module): + def __init__(self, input_dim, bottleneck_dim): + super().__init__() + + self.input_dim = input_dim + self.bottleneck_dim = bottleneck_dim + + self.conv_encode = nn.Conv2d(input_dim, bottleneck_dim, 1) + self.norm = nn.LayerNorm([bottleneck_dim, 1, 1]) + self.conv_decode = nn.Conv2d(bottleneck_dim, input_dim, 1) + + def forward(self, context): + encoded = self.conv_encode(context) + # B,R,1,1 + + encoded = self.norm(F.relu(encoded)) + + decoded = self.conv_decode(encoded) + # B,C,1,1 + + return decoded + + +class GCAttention(nn.Module): + def __init__(self, input_dim, bottleneck_dim): + super().__init__() + + self.context = GCContext(input_dim) + self.transform = GCTransform(input_dim, bottleneck_dim) + + def forward(self, x): + context = self.context(x) + + tx = 
self.transform(context) + + ret = x + tx + + return ret + + +class MAGCContext(nn.Module): + def __init__(self, input_dim, num_aspects): + super().__init__() + + if input_dim % num_aspects != 0: + raise ValueError("Number of aspects must evenly divide input_dim!") + + self.input_dim = input_dim + self.num_aspects = num_aspects + + self.split_size = input_dim // num_aspects + + self.key_proj = nn.Conv2d(input_dim, num_aspects, 1, groups=num_aspects) + + def forward(self, x): + # x: B,C,H,W + + # B,A,H,W + attn = self.key_proj(x) / math.sqrt(self.split_size) + # B,A,HW + attn = attn.reshape(attn.shape[0], attn.shape[1], -1) + attn_probs = F.softmax(attn, dim=2, dtype=torch.float32) + # B,A,1,HW + attn_probs = attn_probs.unsqueeze(dim=2) + + # B,A,C/A,HW + rs_x = x.reshape(x.shape[0], self.num_aspects, self.split_size, -1) + # B,A,HW,C/A + rs_x = rs_x.permute(0, 1, 3, 2) + + # B,A,1,C/A + focus = torch.matmul(attn_probs, rs_x) + + # B,C,1,1 + return focus.reshape(focus.shape[0], -1, 1, 1) + + +class MAGCAttention(nn.Module): + def __init__(self, input_dim, num_aspects, bottleneck_dim): + super().__init__() + + self.context = MAGCContext(input_dim, num_aspects) + + self.transform = GCTransform(input_dim, bottleneck_dim) + + def forward(self, x): + context = self.context(x) + + tx = self.transform(context) + + ret = x + tx + + return ret + + +class mCReLU_base(nn.Module): + def __init__(self, n_in, n_out, kernel_size, stride=1, pre_act=False, last_act=True): + super().__init__() + + self.pre_act = pre_act + self.last_act = last_act + self.act = F.relu + + self.conv = nn.Conv2d(n_in, n_out, kernel_size, stride=stride, padding=kernel_size // 2) + self.bn = nn.BatchNorm2d(n_out * 2) + + def forward(self, x): + if self.pre_act: + x = self.act(x) + + x = self.conv(x) + x = torch.cat((x, -x), dim=1) + x = self.bn(x) + + if self.last_act: + x = self.act(x) + + return x + + +class mCReLU_residual(nn.Module): + def __init__( + self, + n_in, + n_red, + n_kernel, + n_out, + kernel_size=3, + in_stride=1, + proj=False, + pre_act=False, + last_act=True, + ): + super().__init__() + + self.pre_act = pre_act + self.last_act = last_act + self.stride = in_stride + self.act = F.relu + + self.reduce = nn.Conv2d(n_in, n_red, 1, stride=in_stride) + self.conv = nn.Conv2d(n_red, n_kernel, kernel_size, padding=kernel_size // 2) + self.bn = nn.BatchNorm2d(n_kernel * 2) + self.expand = nn.Conv2d(n_kernel * 2, n_out, 1) + + if in_stride > 1: + assert proj + + self.proj = nn.Conv2d(n_in, n_out, 1, stride=in_stride) if proj else None + + def forward(self, x): + x_sc = x + + if self.pre_act: + x = self.act(x) + + x = self.reduce(x) + x = self.act(x) + + x = self.conv(x) + x = torch.cat((x, -x), 1) + x = self.bn(x) + x = self.act(x) + + x = self.expand(x) + + if self.last_act: + x = self.act(x) + + if self.proj: + x_sc = self.proj(x_sc) + + x = x + x_sc + + return x + + +class Inception(nn.Module): + def __init__(self, n_in, n_out, in_stride=1, preAct=False, lastAct=True, proj=False): + super(Inception, self).__init__() + + # Config + self._preAct = preAct + self._lastAct = lastAct + self.n_in = n_in + self.n_out = n_out + self.act_func = nn.ReLU + self.act = F.relu + self.in_stride = in_stride + + self.n_branches = 0 + self.n_outs = [] # number of output feature for each branch + + self.proj = nn.Conv2d(n_in, n_out, 1, stride=in_stride) if proj else None + + def add_branch(self, module, n_out): + # Create branch + br_name = "branch_{}".format(self.n_branches) + setattr(self, br_name, module) + + # Last output chns. 
+ self.n_outs.append(n_out) + + self.n_branches += 1 + + def branch(self, idx): + br_name = "branch_{}".format(idx) + return getattr(self, br_name, None) + + def add_convs(self, n_kernels, n_chns): + assert len(n_kernels) == len(n_chns) + + n_last = self.n_in + layers = [] + + stride = -1 + for k, n_out in zip(n_kernels, n_chns): + if stride == -1: + stride = self.in_stride + else: + stride = 1 + + # Initialize params + conv = nn.Conv2d( + n_last, n_out, kernel_size=k, bias=False, padding=int(k / 2), stride=stride + ) + bn = nn.BatchNorm2d(n_out) + + # Instantiate network + layers.append(conv) + layers.append(bn) + layers.append(self.act_func()) + + n_last = n_out + + self.add_branch(nn.Sequential(*layers), n_last) + + return self + + def add_poolconv(self, kernel, n_out, type="MAX"): + assert type in ["AVE", "MAX"] + + n_last = self.n_in + layers = [] + + # Pooling + if type == "MAX": + layers.append(nn.MaxPool2d(kernel, padding=int(kernel / 2), stride=self.in_stride)) + elif type == "AVE": + layers.append(nn.AvgPool2d(kernel, padding=int(kernel / 2), stride=self.in_stride)) + + # Conv - BN - Act + layers.append(nn.Conv2d(n_last, n_out, kernel_size=1)) + layers.append(nn.BatchNorm2d(n_out)) + layers.append(self.act_func()) + + self.add_branch(nn.Sequential(*layers), n_out) + + return self + + def finalize(self): + # Add 1x1 convolution + total_outs = sum(self.n_outs) + + self.last_conv = nn.Conv2d(total_outs, self.n_out, kernel_size=1) + self.last_bn = nn.BatchNorm2d(self.n_out) + + return self + + def forward(self, x): + x_sc = x + + if self._preAct: + x = self.act(x) + + # Compute branches + h = [] + for i in range(self.n_branches): + module = self.branch(i) + assert module is not None + + h.append(module(x)) + + x = torch.cat(h, dim=1) + + x = self.last_conv(x) + x = self.last_bn(x) + + if self._lastAct: + x = self.act(x) + + if x_sc.get_device() != x.get_device(): + print("Something's wrong") + + # Projection + if self.proj: + x_sc = self.proj(x_sc) + + x = x + x_sc + + return x diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/__init__.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb59f61a04031738fb505e42b3f5c821c9f54ad5 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/aspp.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/aspp.py new file mode 100644 index 0000000000000000000000000000000000000000..7ceb57b0e390d027310449af5e1b254e0c3e645d --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/aspp.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Atrous Spatial Pyramid Pooling implementation.""" + +import torch +from torch import nn + + +def _grow(rate, power): + if isinstance(rate, (list, tuple)): + return tuple(_grow(r, power) for r in rate) + return int(rate**power) + + +class ASPP(nn.Module): + """A class definining an ASPP module.""" + + def __init__(self, in_channels, num_channels, dropout=0.0, growth_rate=2): + """Initialize an ASPP. 
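        Builds seven parallel 3x3 branches whose dilation grows as growth_rate**i, plus a
        global-average-pooling branch, and fuses them with a final 1x1 projection (adding a
        residual connection when the input and output shapes match).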
+ + Args: + in_channels (int): Number of input channels. + num_channels (int): Number of channels in each branch of the ASPP. + norm_type (str): Type of normalization layer, supported: 'batch_norm', + 'sync_batch_norm', 'group_norm'. Default: 'off'. + norm_args (dict): Additional arguments given to the normalization layer. This + includes for example: + - In case 'norm_type' == 'batch_norm' or 'sync_batch_norm', the + 'momentum' parameter. + - In case 'norm_type' == 'group_norm', this includes the 'num_groups' + parameter. + """ + super().__init__() + + norm_layer = nn.BatchNorm2d + use_bias = False + + kernels = [] + num_kernels = 7 + for i in range(num_kernels): + dilation = _grow(growth_rate, i) + kernels.append( + nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=num_channels, + kernel_size=3, + stride=1, + dilation=dilation, + padding=dilation, + bias=use_bias, + ), + norm_layer(num_channels), + nn.ReLU(inplace=True), + ) + ) + + self.kernels = nn.ModuleList(kernels) + + # Global average pooling. + self.global_pool = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d( + in_channels=in_channels, + out_channels=num_channels, + kernel_size=1, + stride=1, + bias=True, + ), + nn.ReLU(inplace=True), + ) + + # Output convolution. + self.final = nn.Sequential( + nn.Conv2d( + in_channels=(1 + num_kernels) * num_channels, + out_channels=num_channels, + kernel_size=1, + stride=1, + bias=use_bias, + ), + norm_layer(num_channels), + nn.ReLU(inplace=True), + nn.Dropout(p=dropout), + ) + + def forward(self, x): + """The module forward function. + + Args: + x (torch.tensor): The input tensor. + + Returns: + torch.tensor: The output tensor. + """ + outs = [kernel(x) for kernel in self.kernels] + + global_pool = self.global_pool(x).expand(-1, -1, *x.shape[2:]) + outs.append(global_pool) + + concatenated = torch.cat(outs, dim=1) + + out = self.final(concatenated) + + if x.shape == out.shape: + return x + out + return out diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/fots_detector.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/fots_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..9521b29ac4fa039fd06d9f09a0d137374ceef5d3 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/fots_detector.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +from nemo_retriever_ocr.inference.models.detector.aspp import ASPP +from nemo_retriever_ocr.inference.models.detector import regnet + +logger = logging.getLogger(__name__) + + +def get_prior_offsets(output_shape, downsample): + """ + Returns the locations of the priors in normalized image space. + + Args: + shape (tensor): Contains the output layer dimensions as [height, width]. This is used to + normalize the prior offsets. + + Returns: + priors (HxWx2 tensor): contains prior offsets in normalized coordinates. + second dimension contains the x, y offsets. 
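        Example (illustrative, values in input-pixel units):
            get_prior_offsets((2, 2), downsample=4)
            # -> [[[2., 2.], [6., 2.]],
            #     [[2., 6.], [6., 6.]]]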
+ """ + x_priors = torch.arange(0, output_shape[1], dtype=torch.float16) * downsample + y_priors = torch.arange(0, output_shape[0], dtype=torch.float16) * downsample + + x_priors += downsample / 2 + y_priors += downsample / 2 + + x_priors = x_priors.reshape(1, -1, 1).repeat(output_shape[0], 1, 1) + y_priors = y_priors.reshape(-1, 1, 1).repeat(1, output_shape[1], 1) + + priors = torch.cat((x_priors, y_priors), dim=2) + + return priors + + +class extractor(nn.Module): + def __init__(self, backbone="regnet_y_8gf"): + super().__init__() + + backbone = getattr(regnet, backbone)(pretrained=True) + + self.depths = backbone.channel_counts + self.base = backbone.stem + self.levels = nn.ModuleList(backbone.trunk_output) + + self.step = 0 + + self.downsample = 4 + + def set_current_and_total_steps(self, current_step, total_steps): + self.step = current_step + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + # logger.info(f'Input shape: {x.shape}') + x = self.base(x) + + # out = [x] + out = [] + for m in self.levels: + x = m(x) + + out.append(x) + + # logger.info(f'Extraction levels: {[t.shape for t in out]}') + + if self.training: + self.step += 1 + return tuple(out) + + +def conv_block(*args, **kwargs): + conv = nn.Conv2d(*args, bias=False, **kwargs) + bn = nn.BatchNorm2d(conv.out_channels) + return nn.Sequential( + conv, + bn, + nn.ReLU(inplace=True), + ) + + +class merge(nn.Module): + def __init__(self, extractor_depths): + super().__init__() + + # Go from deepest to most shallow + extractor_depths = extractor_depths[::-1] + + pre_upsamples = [] + pre_sides = [] + post_upsamples = [] + next_depth = 512 + prev_depth = extractor_depths[0] + num_features = [extractor_depths[0]] + for i in range(1, len(extractor_depths)): + ds_depth = min(prev_depth // 2, 512) + # side_depth = extractor_depths[i] // 2 + side_depth = extractor_depths[i] + depth = side_depth + ds_depth + + pre_upsamples.append(nn.Sequential(conv_block(prev_depth, ds_depth, 1))) + pre_sides.append( + nn.Sequential( + # conv_block(extractor_depths[i], side_depth, 1) + nn.Identity() + ) + ) + post_upsamples.append( + nn.Sequential( + conv_block(depth, next_depth, 1), + conv_block(next_depth, next_depth, 3, padding=1), + ) + ) + num_features.append(next_depth) + prev_depth = next_depth + next_depth //= 2 + + self.pre_upsamples = nn.ModuleList(pre_upsamples) + self.pre_sides = nn.ModuleList(pre_sides) + self.post_upsamples = nn.ModuleList(post_upsamples) + self.final = nn.Sequential( + conv_block(prev_depth, prev_depth, 3, padding=1), + ) + + self.num_features = num_features + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + # From deepest to most shallow + x = x[::-1] + feats = [x[0]] + + y = x[0] + for i in range(len(x) - 1): + y = self.pre_upsamples[i](y) + y = self.interpolate(y) + side = self.pre_sides[i](x[i + 1]) + y = torch.cat((y, side), 1) + y = self.post_upsamples[i](y) + feats.append(y) + + y = self.final(y) + + feats[-1] = y + + return tuple(feats) + + def interpolate(self, x: torch.Tensor): + if x.dtype != torch.bfloat16: + return F.interpolate(x, scale_factor=2, mode="nearest") + else: + # TODO(mranzinger): Currently F.interpolate doesn't support bfloat16 + x = x.reshape(x.shape[0], x.shape[1], x.shape[2], 1, 
x.shape[3], 1) + x = x.repeat(1, 1, 1, 2, 1, 2) + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[4] * 2) + return x + + +class output(nn.Module): + def __init__(self, num_features, downsample, coordinate_mode, scope=512): + super().__init__() + + self._prior_offsets = dict() + + self.coordinate_mode = coordinate_mode + + self.downsample = downsample + + num_features = num_features[-1] + + self.scope = scope + + self.slices = [("confidence", slice(0, 1))] + if self.do_quads(): + end = self.slices[-1][1].stop + self.slices.append(("quads", slice(end, end + 8))) + if self.do_rbox(): + end = self.slices[-1][1].stop + self.slices.append(("rbox_coord", slice(end, end + 4))) + self.slices.append(("rbox_rot", slice(end + 4, end + 5))) + + self.preds = nn.Sequential( + ASPP(num_features, num_features), + ASPP(num_features, num_features), + conv_block(num_features, num_features, 3, padding=1), + nn.Conv2d(num_features, self.slices[-1][1].stop, 1, bias=False), + ) + + self.slices = {k: v for k, v in self.slices} + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.uniform_(m.weight, -0.01, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def get_prior_offsets(self, y): + output_size = (y.shape[-2], y.shape[-1]) + + if output_size in self._prior_offsets: + return self._prior_offsets[output_size].to(y) + + # HWC + prior_offsets = get_prior_offsets(output_size, self.downsample) + # CHW + prior_offsets = prior_offsets.permute(2, 0, 1).contiguous() + # (4C)HW + prior_offsets = prior_offsets.repeat(4, 1, 1) + # BCHW + prior_offsets = prior_offsets.unsqueeze(0) + + prior_offsets = prior_offsets.to(y) + + self._prior_offsets[output_size] = prior_offsets + + return prior_offsets + + def adjust_offsets(self, offsets): + return offsets + self.get_prior_offsets(offsets) + + def do_quads(self): + return self.coordinate_mode in ("QUAD", "BOTH") + + def do_rbox(self): + return self.coordinate_mode in ("RBOX", "BOTH") + + def forward( + self, feats: List[torch.Tensor] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], List[torch.Tensor]]: + x = feats[-1] + + preds = self.preds(x) + + conf = preds[:, self.slices["confidence"]].squeeze(1).contiguous() + + if self.do_quads(): + offsets = preds[:, self.slices["quads"]] + offsets = torch.tanh(offsets) * self.scope + offsets = self.adjust_offsets(offsets) + offsets = offsets.permute(0, 2, 3, 1).contiguous() + offsets = offsets.reshape(*offsets.shape[:-1], 4, 2) + else: + offsets = None + + if self.do_rbox(): + rboxes = preds[:, self.slices["rbox_coord"]] + rboxes = F.relu(rboxes, inplace=True) * self.scope + + rot = preds[:, self.slices["rbox_rot"]] + rot = F.hardtanh(rot, min_val=-1, max_val=1) * math.pi + + rboxes = torch.cat((rboxes, rot), dim=1) + rboxes = rboxes.permute(0, 2, 3, 1).contiguous() + else: + rboxes = None + + return conf, offsets, rboxes, x + + +class NaNHook: + def __init__(self, name): + self.name = name + logger.info(f"Hooking {name}") + + def __call__(self, module, input, output): + if module.training: + return + + def nan_test(t): + if torch.any(torch.isnan(t)): + print(f"Checking {self.name}...") + print("input\n", input) + print("output\n", output) + raise ValueError(f'Module {module} with name "{self.name}" produced a nan value!') + + if isinstance(output, (list, tuple)): + for t in output: + nan_test(t) + else: + nan_test(output) + + +class GradReversalFunction(torch.autograd.Function): + 
@staticmethod + def forward(ctx, x): + # The view_as forces pytorch to call backward + return x.view_as(x) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.neg() + + +class GradReversal(nn.Module): + def forward(self, x): + return GradReversalFunction.apply(x) + + +class FOTSDetector(nn.Module): + def __init__( + self, verbose=True, coordinate_mode: str = "RBOX", backbone: str = "regnet_y_8gf", **kwargs + ): + super().__init__() + + self.extractor = extractor(backbone, **kwargs) + self.merge = merge(self.extractor.depths) + self.num_features = self.merge.num_features + self.output = output(self.num_features, self.extractor.downsample, coordinate_mode) + self.verbose = verbose + self.inference_mode = False + + self.downsample = self.extractor.downsample + + self.register_buffer( + "input_mean", + torch.tensor([0.485, 0.456, 0.406], dtype=torch.float16).reshape(1, -1, 1, 1), + ) + self.register_buffer( + "input_std", + torch.tensor([0.229, 0.224, 0.225], dtype=torch.float16).reshape(1, -1, 1, 1), + ) + + def set_current_and_total_steps(self, current_step, total_steps): + self.extractor.set_current_and_total_steps(current_step, total_steps) + + def forward( + self, x: torch.Tensor + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + List[torch.Tensor], + List[torch.Tensor], + ]: + x = (x - self.input_mean) / self.input_std + feats = self.extractor(x) + + mg = self.merge(feats) + + main_op = self.output(mg) + + return main_op diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/regnet.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3fa87a8f73c575dbf75a7c7db42292395e029142 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/detector/regnet.py @@ -0,0 +1,798 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Modified from +# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/anynet.py +# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py + + +import math +import torch + +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, List, Optional, Tuple +from torch import nn, Tensor + +from torchvision._internally_replaced_utils import load_state_dict_from_url + +__all__ = [ + "RegNet", + "regnet_y_400mf", + "regnet_y_800mf", + "regnet_y_1_6gf", + "regnet_y_3_2gf", + "regnet_y_8gf", + "regnet_y_16gf", + "regnet_y_32gf", + "regnet_x_400mf", + "regnet_x_800mf", + "regnet_x_1_6gf", + "regnet_x_3_2gf", + "regnet_x_8gf", + "regnet_x_16gf", + "regnet_x_32gf", +] + + +model_urls = { + "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", + "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", + "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", + "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", + "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", +} + + +def barrier(): + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + +def get_rank() -> int: + if not torch.distributed.is_initialized(): + return 0 + return torch.distributed.get_rank() + + +def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
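+    # Worked example (illustrative values, not from the original source):
+    # _make_divisible(30, 8) -> 32, while _make_divisible(19, 8) first rounds to 16;
+    # because 16 < 0.9 * 19, the guard below bumps the result up to 24.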
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvNormActivation(torch.nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: bool = True, + ) -> None: + if padding is None: + padding = (kernel_size - 1) // 2 * dilation + layers = [ + torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=dilation, + groups=groups, + bias=norm_layer is None, + ) + ] + if norm_layer is not None: + layers.append(norm_layer(out_channels)) + if activation_layer is not None: + layers.append(activation_layer(inplace=inplace)) + super().__init__(*layers) + self.out_channels = out_channels + + +class SqueezeExcitation(torch.nn.Module): + def __init__( + self, + input_channels: int, + squeeze_channels: int, + activation: Callable[..., torch.nn.Module] = torch.nn.ReLU, + scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, + ) -> None: + super().__init__() + self.avgpool = torch.nn.AdaptiveAvgPool2d(1) + self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) + self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) + self.activation = activation() + self.scale_activation = scale_activation() + + def _scale(self, input: Tensor) -> Tensor: + scale = self.avgpool(input) + scale = self.fc1(scale) + scale = self.activation(scale) + scale = self.fc2(scale) + return self.scale_activation(scale) + + def forward(self, input: Tensor) -> Tensor: + scale = self._scale(input) + return scale * input + + +class SimpleStemIN(ConvNormActivation): + """Simple stem for ImageNet: 3x3, BN, ReLU.""" + + def __init__( + self, + width_in: int, + width_out: int, + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + ) -> None: + super().__init__( + width_in, + width_out, + kernel_size=3, + stride=2, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + + +class BottleneckTransform(nn.Sequential): + """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1.""" + + def __init__( + self, + width_in: int, + width_out: int, + stride: int, + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + group_width: int, + bottleneck_multiplier: float, + se_ratio: Optional[float], + ) -> None: + layers: OrderedDict[str, nn.Module] = OrderedDict() + w_b = int(round(width_out * bottleneck_multiplier)) + g = w_b // group_width + + layers["a"] = ConvNormActivation( + width_in, + w_b, + kernel_size=1, + stride=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + layers["b"] = ConvNormActivation( + w_b, + w_b, + kernel_size=3, + stride=stride, + groups=g, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + + if se_ratio: + # The SE reduction ratio is defined with respect to the + # beginning of the block + width_se_out = int(round(se_ratio * width_in)) + layers["se"] = SqueezeExcitation( + input_channels=w_b, + squeeze_channels=width_se_out, + activation=activation_layer, + ) + + layers["c"] = ConvNormActivation( + w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None + ) + super().__init__(layers) + + +class ResBottleneckBlock(nn.Module): + """Residual bottleneck block: x + F(x), F = bottleneck transform.""" + + def __init__( + 
self, + width_in: int, + width_out: int, + stride: int, + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + group_width: int = 1, + bottleneck_multiplier: float = 1.0, + se_ratio: Optional[float] = None, + ) -> None: + super().__init__() + + # Use skip connection with projection if shape changes + self.proj = None + should_proj = (width_in != width_out) or (stride != 1) + if should_proj: + self.proj = ConvNormActivation( + width_in, + width_out, + kernel_size=1, + stride=stride, + norm_layer=norm_layer, + activation_layer=None, + ) + self.f = BottleneckTransform( + width_in, + width_out, + stride, + norm_layer, + activation_layer, + group_width, + bottleneck_multiplier, + se_ratio, + ) + self.activation = activation_layer(inplace=True) + + def forward(self, x: Tensor) -> Tensor: + if self.proj is not None: + x = self.proj(x) + self.f(x) + else: + x = x + self.f(x) + return self.activation(x) + + +class AnyStage(nn.Sequential): + """AnyNet stage (sequence of blocks w/ the same output shape).""" + + def __init__( + self, + width_in: int, + width_out: int, + stride: int, + depth: int, + block_constructor: Callable[..., nn.Module], + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + group_width: int, + bottleneck_multiplier: float, + se_ratio: Optional[float] = None, + stage_index: int = 0, + ) -> None: + super().__init__() + + for i in range(depth): + block = block_constructor( + width_in if i == 0 else width_out, + width_out, + stride if i == 0 else 1, + norm_layer, + activation_layer, + group_width, + bottleneck_multiplier, + se_ratio, + ) + + self.add_module(f"block{stage_index}-{i}", block) + + +class BlockParams: + def __init__( + self, + depths: List[int], + widths: List[int], + group_widths: List[int], + bottleneck_multipliers: List[float], + strides: List[int], + se_ratio: Optional[float] = None, + ) -> None: + self.depths = depths + self.widths = widths + self.group_widths = group_widths + self.bottleneck_multipliers = bottleneck_multipliers + self.strides = strides + self.se_ratio = se_ratio + + @classmethod + def from_init_params( + cls, + depth: int, + w_0: int, + w_a: float, + w_m: float, + group_width: int, + bottleneck_multiplier: float = 1.0, + se_ratio: Optional[float] = None, + **kwargs: Any, + ) -> "BlockParams": + """ + Programatically compute all the per-block settings, + given the RegNet parameters. + + The first step is to compute the quantized linear block parameters, + in log space. Key parameters are: + - `w_a` is the width progression slope + - `w_0` is the initial width + - `w_m` is the width stepping in the log space + + In other terms + `log(block_width) = log(w_0) + w_m * block_capacity`, + with `bock_capacity` ramping up following the w_0 and w_a params. + This block width is finally quantized to multiples of 8. + + The second step is to compute the parameters per stage, + taking into account the skip connection and the final 1x1 convolutions. + We use the fact that the output width is constant within a stage. + """ + + QUANT = 8 + STRIDE = 2 + + if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0: + raise ValueError("Invalid RegNet settings") + # Compute the block widths. 
Each stage has one unique block width + widths_cont = torch.arange(depth) * w_a + w_0 + block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m)) + block_widths = ( + (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT) + .int() + .tolist() + ) + num_stages = len(set(block_widths)) + + # Convert to per stage parameters + split_helper = zip( + block_widths + [0], + [0] + block_widths, + block_widths + [0], + [0] + block_widths, + ) + splits = [w != wp or r != rp for w, wp, r, rp in split_helper] + + stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t] + stage_depths = ( + torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist() + ) + + strides = [STRIDE] * num_stages + bottleneck_multipliers = [bottleneck_multiplier] * num_stages + group_widths = [group_width] * num_stages + + # Adjust the compatibility of stage widths and group widths + stage_widths, group_widths = cls._adjust_widths_groups_compatibilty( + stage_widths, bottleneck_multipliers, group_widths + ) + + return cls( + depths=stage_depths, + widths=stage_widths, + group_widths=group_widths, + bottleneck_multipliers=bottleneck_multipliers, + strides=strides, + se_ratio=se_ratio, + ) + + def _get_expanded_params(self): + return zip( + self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers + ) + + @staticmethod + def _adjust_widths_groups_compatibilty( + stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int] + ) -> Tuple[List[int], List[int]]: + """ + Adjusts the compatibility of widths and groups, + depending on the bottleneck ratio. + """ + # Compute all widths for the current settings + widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)] + group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)] + + # Compute the adjusted widths so that stage and group widths fit + ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)] + stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)] + return stage_widths, group_widths_min + + +class RegNet(nn.Module): + def __init__( + self, + block_params: BlockParams, + num_classes: int = 1000, + stem_width: int = 32, + stem_type: Optional[Callable[..., nn.Module]] = None, + block_type: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + + if stem_type is None: + stem_type = SimpleStemIN + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if block_type is None: + block_type = ResBottleneckBlock + if activation is None: + activation = nn.ReLU + + # Ad hoc stem + self.stem = stem_type( + 3, # width_in + stem_width, + norm_layer, + activation, + ) + + current_width = stem_width + + self.channel_counts = block_params.widths + + blocks = [] + for i, ( + width_out, + stride, + depth, + group_width, + bottleneck_multiplier, + ) in enumerate(block_params._get_expanded_params()): + blocks.append( + ( + f"block{i+1}", + AnyStage( + current_width, + width_out, + stride, + depth, + block_type, + norm_layer, + activation, + group_width, + bottleneck_multiplier, + block_params.se_ratio, + stage_index=i + 1, + ), + ) + ) + + current_width = width_out + + self.trunk_output = nn.Sequential(OrderedDict(blocks)) + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(in_features=current_width, out_features=num_classes) + + # Init weights 
and good to go + self._reset_parameters() + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + x = self.trunk_output(x) + + x = self.avgpool(x) + x = x.flatten(start_dim=1) + x = self.fc(x) + + return x + + def _reset_parameters(self) -> None: + # Performs ResNet-style weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # Note that there is no bias due to BN + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out)) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + nn.init.zeros_(m.bias) + + +def _regnet( + arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any +) -> RegNet: + model = RegNet( + block_params, norm_layer=partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1), **kwargs + ) + if pretrained: + if arch not in model_urls: + raise ValueError(f"No checkpoint is available for model type {arch}") + # Essentially, let rank 0 fetch the model first, which may trigger a download. + # this prevents multiple processes/nodes from corrupting the cache + barrier() + if get_rank() == 0: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + barrier() + if get_rank() > 0: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_400MF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=14, w_0=32, w_a=27.89, w_m=2.09, group_width=32, se_ratio=0.25, **kwargs + ) + return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs) + + +def regnet_y_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_800MF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=16, w_0=64, w_a=38.84, w_m=2.4, group_width=32, se_ratio=0.25, **kwargs + ) + return _regnet("regnet_y_800mf", params, pretrained, progress, **kwargs) + + +def regnet_y_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_1.6GF architecture from + `"Designing Network Design Spaces" `_. 
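+
+    A minimal usage sketch (illustrative; assumes the default 1000-class ImageNet head)::
+
+        model = regnet_y_1_6gf(pretrained=False)
+        logits = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 1000)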
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=27, w_0=64, w_a=26.71, w_m=2.65, group_width=32, se_ratio=0.25, **kwargs + ) + return _regnet("regnet_y_1_6gf", params, pretrained, progress, **kwargs) + + +def regnet_y_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_3.2GF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=21, w_0=64, w_a=42.63, w_m=2.66, group_width=32, se_ratio=0.25, **kwargs + ) + return _regnet("regnet_y_3_2gf", params, pretrained, progress, **kwargs) + + +def regnet_y_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_8GF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + params = BlockParams.from_init_params( + depth=17, + w_0=192, + w_a=76.82, + w_m=2.19, + group_width=56 if pretrained else 64, + se_ratio=0.25, + **kwargs, + ) + return _regnet("regnet_y_8gf", params, pretrained, progress, **kwargs) + + +def regnet_y_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_16GF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + params = BlockParams.from_init_params( + depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs + ) + return _regnet("regnet_y_16gf", params, pretrained, progress, **kwargs) + + +def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_32GF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + params = BlockParams.from_init_params( + depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs + ) + return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) + + +def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_400MF architecture from + `"Designing Network Design Spaces" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=22, w_0=32, w_a=24.48, w_m=2.54, group_width=32, **kwargs + ) + return _regnet("regnet_x_400mf", params, pretrained, progress, **kwargs) + + +def regnet_x_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_800MF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=16, w_0=64, w_a=35.73, w_m=2.28, group_width=32, **kwargs + ) + return _regnet("regnet_x_800mf", params, pretrained, progress, **kwargs) + + +def regnet_x_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_1.6GF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=18, w_0=64, w_a=34.01, w_m=2.25, group_width=32, **kwargs + ) + return _regnet("regnet_x_1_6gf", params, pretrained, progress, **kwargs) + + +def regnet_x_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_3.2GF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=25, w_0=64, w_a=26.31, w_m=2.25, group_width=64, **kwargs + ) + return _regnet("regnet_x_3_2gf", params, pretrained, progress, **kwargs) + + +def regnet_x_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_8GF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + if pretrained: + params = BlockParams.from_init_params( + depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs + ) + else: + params = BlockParams.from_init_params( + depth=23, w_0=64, w_a=49.56, w_m=2.88, group_width=96, **kwargs + ) + return _regnet("regnet_x_8gf", params, pretrained, progress, **kwargs) + + +def regnet_x_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_16GF architecture from + `"Designing Network Design Spaces" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + params = BlockParams.from_init_params( + depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs + ) + return _regnet("regnet_x_16gf", params, pretrained, progress, **kwargs) + + +def regnet_x_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_32GF architecture from + `"Designing Network Design Spaces" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + params = BlockParams.from_init_params( + depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs + ) + return _regnet("regnet_x_32gf", params, pretrained, progress, **kwargs) + + +# TODO(kazhang): Add RegNetZ_500MF and RegNetZ_4GF diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/recognizer.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/recognizer.py new file mode 100644 index 0000000000000000000000000000000000000000..28facdae677277ec476127b5deaeb21540828148 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/recognizer.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Optional + +import torch +import torch.nn as nn + +from nemo_retriever_ocr.inference.models import blocks + +logger = logging.getLogger(__name__) + + +class TransformerRecognizer(nn.Module): + def __init__(self, nic: int, num_tokens: int, max_width: int) -> None: + super().__init__() + self.num_tokens = num_tokens + self.fixed_width = max_width > 0 + self.inference_mode = False + depth = 128 + + max_width = abs(max_width) + + final_depth = depth * 2 + + self.feature_depth = final_depth + + CONV_SHAPE = (3, 3) + PAD_SHAPE = tuple((c - 1) // 2 for c in CONV_SHAPE) + + self.encoder = nn.Sequential( + blocks.conv2d_block(nic, nic, 3, padding=1), + blocks.conv2d_block(nic, nic * 2, 3, padding=1), + nn.MaxPool2d((2, 1)), + blocks.conv2d_block(nic * 2, nic * 2, CONV_SHAPE, padding=PAD_SHAPE), + blocks.conv2d_block(nic * 2, depth * 2, CONV_SHAPE, padding=PAD_SHAPE), + nn.MaxPool2d((2, 1)), + blocks.conv2d_block(depth * 2, final_depth, CONV_SHAPE, padding=PAD_SHAPE), + blocks.conv2d_block(final_depth, final_depth * 2, CONV_SHAPE, padding=PAD_SHAPE), + nn.MaxPool2d((2, 1)), + blocks.conv2d_block(final_depth * 2, final_depth, 1), + ) + + self.position_encoding = nn.Parameter(torch.randn(1, final_depth, max_width)) + + self.tx = nn.TransformerEncoder( + nn.TransformerEncoderLayer(final_depth, 8, final_depth, dropout=0.0, batch_first=True), + num_layers=3, + ) + self.classifier = nn.Linear(final_depth, num_tokens) + + def forward(self, x: torch.Tensor, cpu_region_counts: Optional[torch.Tensor] = None): + x = self.encoder(x).squeeze(2) + + if self.fixed_width: + pos = self.position_encoding + else: + pos = self.position_encoding[..., : x.shape[2]] + x = x + pos + + # B,T,C + x = x.permute(0, 2, 1).contiguous() + + x = self.tx(x) + + y = self.classifier(x) + + return y, x diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/relational.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/relational.py new file mode 100644 index 
0000000000000000000000000000000000000000..e93a03c72f1aa81138963b2cef095defb2c28af9 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/relational.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import math + +import torch +import torch.nn as nn + +from nemo_retriever_ocr.inference.models import blocks +from nemo_retriever_ocr.inference.models.utils import options + +from nemo_retriever_ocr_cpp import ( + quad_rectify_calc_quad_width, + ragged_quad_all_2_all_distance_v2, +) + + +logger = logging.getLogger(__name__) + + +NULL_CONNECTION_WEIGHT = -math.inf + + +class GlobalRelationalModel(nn.Module): + def __init__(self, num_input_channels, recog_feature_depth, k=32, dropout=0.1, num_layers=4): + super().__init__() + + num_input_channels = num_input_channels[-1] + self.pos_channels = 14 # 64 + self.current_step = 0 + self.total_steps = 1 + self.k = k + self.quad_rectify_grid_size = (2, 3) + self.quad_downscale = 1024.0 + self.inference_mode = False + + self.grid_size = [2, 3] + self.isotropic = False + + self.initial_depth = 128 + + self.rect_proj = blocks.conv2d_block(num_input_channels, num_input_channels, 1) + + self.recog_tx = nn.Linear(recog_feature_depth, num_input_channels) + + # Reserve 2 channels for the joint distance and angle embeddings + initial_depth = self.initial_depth - 1 - self.pos_channels + cb_input = 2 * num_input_channels # + self.pos_channels + self.combined_proj = nn.Sequential( + nn.Linear(cb_input, cb_input), + nn.BatchNorm1d(cb_input), + nn.ReLU(), + nn.Linear(cb_input, initial_depth), + nn.BatchNorm1d(initial_depth), + nn.ReLU(), + ) + + dim = 2 * self.initial_depth + + self.encoder = nn.Sequential( + nn.TransformerEncoder( + nn.TransformerEncoderLayer( + dim, 8, 2 * dim, batch_first=True, dropout=dropout, norm_first=True + ), + num_layers=num_layers, + ), + nn.Linear(dim, 3), + ) + + def get_target_rects( + self, + quads: torch.Tensor, + curr_rects: torch.Tensor, + curr_centers: torch.Tensor, + gt_relations: torch.Tensor, + ): + to_rects = curr_rects.unsqueeze(0).expand(curr_rects.shape[0], -1, -1) + to_centers = curr_centers.unsqueeze(0).expand(quads.shape[0], -1, -1) + + k = max(0, min(to_rects.shape[1] - 1, self.k - 1)) + + if k == 0: + to_rects = torch.zeros( + curr_rects.shape[0], 1, curr_rects.shape[1] + 2, **options(curr_rects) + ) + closest_other_idxs = torch.zeros( + curr_rects.shape[0], 1, dtype=torch.int64, device=curr_rects.device + ) + else: + all_dists = get_cdist(quads, curr_centers) + + closest_other_idxs = torch.topk( + all_dists, k=k, dim=1, largest=False, sorted=False + ).indices + + # K,K-1,D + to_rects = torch.gather( + to_rects, + dim=1, + index=closest_other_idxs.unsqueeze(2).expand(-1, -1, curr_rects.shape[1]), + ) + # K,K-1 + all_dists = torch.gather(all_dists, dim=1, index=closest_other_idxs) + # K,K-1,2 + to_centers = torch.gather( + to_centers, dim=1, index=closest_other_idxs.unsqueeze(2).expand(-1, -1, 2) + ) + + # Add the null column to rects + to_rects = torch.cat( + ( + torch.zeros(to_rects.shape[0], 1, to_rects.shape[2], **options(to_rects)), + to_rects, + ), + dim=1, + ) + + # Add the null column to the dists + all_dists = torch.cat( + ( + torch.full((all_dists.shape[0], 1), -1, **options(all_dists)), + all_dists, + ), + dim=1, + ) + + directions = get_directions(quads, to_centers) + directions = torch.cat( + ( + torch.full((directions.shape[0], 1), -2, 
**options(directions)), + directions, + ), + dim=1, + ) + + # Add the pairwise geometric encodings + to_rects = torch.cat( + ( + to_rects, + all_dists.unsqueeze(2), + directions.unsqueeze(2), + ), + dim=2, + ) + + # Add the null column + closest_other_idxs = torch.cat( + [ + torch.zeros(closest_other_idxs.shape[0], 1, **options(closest_other_idxs)), + closest_other_idxs + 1, + ], + dim=1, + ) + + return to_rects, closest_other_idxs + + def prohibit_self_connection(self, dots: torch.Tensor, closest_other_idxs: torch.Tensor = None): + dots = dots.float() + + if closest_other_idxs is None: + neg_inf = torch.full((dots.shape[-2],), NULL_CONNECTION_WEIGHT, **options(dots)).diag() + + neg_inf = torch.cat( + (torch.zeros(neg_inf.shape[0], 1, **options(neg_inf)), neg_inf), dim=1 + ) + + if dots.ndim == 3: + neg_inf.unsqueeze_(0) + + dots = dots + neg_inf + else: + neg_inf = torch.full( + (*dots.shape[:-2], dots.shape[-2], dots.shape[-2] + 1), + NULL_CONNECTION_WEIGHT, + **options(dots), + ) + + if dots.ndim == 3: + closest_other_idxs = closest_other_idxs.unsqueeze(0).expand_as(dots) + + neg_inf = torch.scatter(neg_inf, dim=-1, index=closest_other_idxs, src=dots) + dots = neg_inf + + return dots + + def get_input_encoding( + self, + rectified_quads: torch.Tensor, + original_quads: torch.Tensor, + region_counts: torch.Tensor, + recog_features: torch.Tensor, + ): + cs_rg = torch.cumsum(region_counts, 0) + cs_rg = torch.cat([torch.tensor([0]), cs_rg]) + ex_offsets = cs_rg + + g_height, g_width = self.quad_rectify_grid_size + if self.isotropic: + # Figure out the number of valid positions for each quad, which we can use to compute the mean + quad_widths = quad_rectify_calc_quad_width( + original_quads, rectified_quads.shape[-2], 1, rectified_quads.shape[-1] + ) + else: + quad_widths = torch.full((original_quads.shape[0],), g_width, **options(original_quads)) + num_valid_pos = (quad_widths * g_height).clamp_min(1) + + # Ensure that these values aren't very large + original_quads = original_quads / self.quad_downscale + mid_pts = original_quads.detach().mean(dim=1, dtype=torch.float32) + + rectified_quads = self.rect_proj(rectified_quads) + avg_rects = rectified_quads.flatten(2).sum( + dim=2, dtype=torch.float32 + ) / num_valid_pos.unsqueeze(1) + + recog_encoding = self.recog_tx(recog_features.detach()).mean(dim=1, dtype=torch.float32) + + semantic_encoding = self.combined_proj(torch.cat((avg_rects, recog_encoding), dim=1)) + + h1 = original_quads[:, 3] - original_quads[:, 0] + h2 = original_quads[:, 2] - original_quads[:, 1] + + mp1 = original_quads[:, 0] + (h1 / 2) + mp2 = original_quads[:, 1] + (h2 / 2) + + d1 = mp2 - mp1 + + wdth = d1.norm(dim=1, keepdim=True) + + d1 = d1 / wdth.clamp_min(1e-6) + + hts = ((h1 + h2) / 2).norm(dim=1, keepdim=True) + + d2 = torch.stack([-d1[:, 1], d1[:, 0]], dim=-1) + + # Prevent overfitting to specific quad positions by translating all positions + # by some random offset (thus preserving inter-quad relationships, but not absolute positions) + if self.training: + rand_quad_offset = torch.rand(1, 1, 2, **options(original_quads)) * 4 - 2 + + quads_enc = original_quads + rand_quad_offset + else: + quads_enc = original_quads + + # The last 5 tensors represent the geometric encoding for each word + + full_encoding = torch.cat( + (semantic_encoding, quads_enc.flatten(1), d1, d2, wdth, hts), dim=1 + ) + + return full_encoding, ex_offsets, region_counts, mid_pts + + def forward( + self, + rectified_quads: torch.Tensor, + original_quads: torch.Tensor, + region_counts: 
torch.Tensor, + recog_features: torch.Tensor, + ): + rectified_quads = rectified_quads.float() + recog_features = recog_features.float() + + assert torch.all(torch.isfinite(rectified_quads)) + assert torch.all(torch.isfinite(recog_features)) + + proj_rects, ex_offsets, region_counts, mid_pts = self.get_input_encoding( + rectified_quads, original_quads, region_counts, recog_features + ) + + quads = original_quads / self.quad_downscale + + all_dots = dict(words=[], lines=[], line_log_var_unc=[]) + + if not self.inference_mode: + assert torch.all(torch.isfinite(proj_rects)), "Not all proj_rects were finite!" + + for i, (offset, region_count) in enumerate(zip(ex_offsets, region_counts)): + # K,D + curr_rects = proj_rects[offset : offset + region_count] + curr_centers = mid_pts[offset : offset + region_count] + curr_quads = quads[offset : offset + region_count] + + from_rects = curr_rects + + to_rects, closest_other_idxs = self.get_target_rects( + curr_quads, curr_rects, curr_centers, None + ) + + # K,Z,D + from_rects = from_rects.unsqueeze(1).expand(-1, to_rects.shape[1], -1) + + # K,Z+1,D*2 + enc_input = torch.cat((from_rects, to_rects), dim=2) + + # K,Z+1,2 + if enc_input.shape[0]: + dots = self.encoder(enc_input) + else: + dots = torch.empty(0, 1, 3, dtype=enc_input.dtype, device=enc_input.device) + + # 2,K,Z+1 + dots = dots.permute(2, 0, 1) + + dots = self.prohibit_self_connection(dots, closest_other_idxs) + + word_pred = dots[0] + line_pred = dots[1] + line_log_var_pred = dots[2] + + all_dots["words"].append(word_pred) + all_dots["lines"].append(line_pred) + all_dots["line_log_var_unc"].append(line_log_var_pred) + + return { + "words": all_dots["words"], + "lines": all_dots["lines"], + "line_log_var_unc": all_dots["line_log_var_unc"], + } + + +def get_cdist( + quads: torch.Tensor, centers: torch.Tensor, x_factor: float = 1.0, y_factor: float = 1.0 +): + region_counts = torch.tensor([quads.shape[0]], dtype=torch.int64, device=quads.device) + + ret = ragged_quad_all_2_all_distance_v2( + quads.unsqueeze(0), region_counts, x_factor, y_factor, allow_self_distance=False + )[0] + + return ret + + +def get_directions(quads: torch.Tensor, to_centers: torch.Tensor): + quads = quads.detach() + to_centers = to_centers.detach() + + # quads: N,4,2 + # to_centers: N,K,2 + + pt0 = (quads[:, 0] + quads[:, 3]) / 2 + pt1 = (quads[:, 1] + quads[:, 2]) / 2 + + direction = pt1 - pt0 + direction /= direction.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6) + direction = direction.unsqueeze(1).expand(-1, to_centers.shape[1], -1) + + centers = (pt0 + pt1) / 2 + + vec_other = to_centers - centers.unsqueeze(1) + dir_other = vec_other / vec_other.norm(p=2, dim=2, keepdim=True).clamp_min(1e-6) + + cos_angle = torch.einsum("ftv,ftv->ft", direction, dir_other) + + return cos_angle diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/utils.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c867b2c2a7652c5ab5e18d9ea6eec44329ea26d7 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/models/utils.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +from typing import Dict, Any + +import torch + +logger = logging.getLogger(__name__) + + +def is_named_tuple(obj): + """ + Return where or not the specified instance is a namedtuple. 
+ + NOTE: Not guaranteed to be correct, but close. + + Args: + obj (object): Some object to test. + """ + return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None + + +def find_checkpoint(checkpoint_path): + checkpoint_path = os.path.join(checkpoint_path, "best_checkpoint.pth") + + return checkpoint_path + + +def cat(tensors, *rest_shape, dtype=torch.float32) -> torch.Tensor: + if tensors: + return torch.cat(tensors) + else: + return torch.empty(0, *rest_shape, dtype=dtype) + + +def options(tensor: torch.Tensor) -> Dict[str, Any]: + """ + Returns as a dict the dtype and device options for a tensor. This allows you + to construct a new tensor with a compatible format. + + e.g. + new_tensor = torch.empty(, **options(other_tensor)) + """ + return {"dtype": tensor.dtype, "device": tensor.device} + + +def f_measure(*args): + acc = 0 + for v in args: + if torch.is_tensor(v): + v = v.clamp_min(1e-8) + elif v <= 0: + v = 1e-8 + acc += 1.0 / v + + fmeasure = len(args) / acc + + return fmeasure + + +def tensor_all_reduce(tensor, reduce_op=torch.distributed.ReduceOp.SUM): + if torch.distributed.is_initialized(): + was_cuda = tensor.is_cuda + tensor = tensor.cuda(non_blocking=True) + torch.distributed.all_reduce(tensor, reduce_op) + if not was_cuda: + tensor = tensor.cpu() + return tensor + + +def tensor_all_gather(tensor: torch.Tensor, dim=0): + if not torch.distributed.is_initialized(): + return tensor + + tensor = tensor.contiguous() + orig_dtype = tensor.dtype + if tensor.dtype == torch.bool: + tensor = tensor.to(torch.int64) + orig_device = tensor.device + if not tensor.is_cuda: + tensor = tensor.cuda(non_blocking=True) + # Scalar tensor + if len(tensor.shape) == 0: + tensor = tensor.reshape(1) + + buffers = [torch.empty_like(tensor) for i in range(torch.distributed.get_world_size())] + + torch.distributed.all_gather(buffers, tensor) + + return torch.cat(buffers, dim=dim).to(dtype=orig_dtype, device=orig_device) diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/pipeline.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..707cbd209d16cad40de18220b3474e5f46d2ebd7 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/pipeline.py @@ -0,0 +1,400 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import base64 +import io +import json +import os +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from nemo_retriever_ocr.inference.encoders.recognizer_encoder import RecognitionTargetEncoder +from nemo_retriever_ocr.inference.encoders.relational_encoder import RelationalTargetEncoder +from nemo_retriever_ocr.inference.models.detector.fots_detector import FOTSDetector +from nemo_retriever_ocr.inference.models.recognizer import TransformerRecognizer +from nemo_retriever_ocr.inference.models.relational import GlobalRelationalModel +from nemo_retriever_ocr.inference.post_processing.indirect_grid_sample import IndirectGridSample +from nemo_retriever_ocr.inference.post_processing.data.text_region import TextBlock +from nemo_retriever_ocr.inference.post_processing.quad_rectify import QuadRectify +from nemo_retriever_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes +from nemo_retriever_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square +from nemo_retriever_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads +from PIL import Image, ImageDraw, ImageFont +from torch import amp +from torchvision.io import read_image, decode_image +from torchvision.transforms.functional import convert_image_dtype + +PAD_COLOR = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float16) +INFER_LENGTH = 1024 +DETECTOR_DOWNSAMPLE = 4 +NMS_PROB_THRESHOLD = 0.5 +NMS_IOU_THRESHOLD = 0.5 +NMS_MAX_REGIONS = 0 + +MERGE_LEVELS = {"word", "sentence", "paragraph"} +DEFAULT_MERGE_LEVEL = "paragraph" + + +class NemoRetrieverOCR: + """ + A high-level pipeline for performing OCR on images. + """ + + def __init__(self, model_dir="./checkpoints"): + self._model_dir = Path(model_dir) + + self._load_models() + self._load_charset() + self._initialize_processors() + + def _load_models(self): + """Loads all necessary models into memory.""" + self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False) + self.detector.load_state_dict(torch.load(self._model_dir / "detector.pth"), strict=True) + + self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32) + self.recognizer.load_state_dict(torch.load(self._model_dir / "recognizer.pth"), strict=True) + + self.relational = GlobalRelationalModel( + num_input_channels=self.detector.num_features, + recog_feature_depth=self.recognizer.feature_depth, + dropout=0.1, + k=16, + num_layers=4, + ) + self.relational.load_state_dict(torch.load(self._model_dir / "relational.pth"), strict=True) + + for model in (self.detector, self.recognizer, self.relational): + model = model.cuda() + model.eval() + model.inference_mode = True + + def _initialize_processors(self): + """Initializes helper classes for pre/post-processing.""" + self.recognizer_quad_rectifier = QuadRectify(8, 32) + self.relational_quad_rectifier = QuadRectify(2, 3, isotropic=False) + self.grid_sampler = IndirectGridSample() + + self.recog_encoder = RecognitionTargetEncoder( + charset=self.charset, + input_size=[1024, 1920], + sequence_length=32, + amp_opt=2, + combine_duplicates=False, + is_train=False, + ) + self.relation_encoder = RelationalTargetEncoder(input_size=[1024, 1920], amp_opt=2, is_train=False) + + def _load_charset(self): + with open(self._model_dir / "charset.txt", "r", encoding="utf-8") as file: + self.charset = json.load(file) + + def __call__(self, image, 
merge_level=DEFAULT_MERGE_LEVEL, visualize=False): + """ + Performs OCR on a single image. + + Args: + image (str | bytes | np.ndarray | Image.Image): The input image. Can be a: + - file path (str) + - base64 encoded string (bytes) + - NumPy array (H, W, C) + - In-memory byte stream (io.BytesIO) + merge_level (str): The granularity of text merging ('word', 'sentence', 'paragraph'). + visualize (bool): If True, saves an annotated image. + + Returns: + list: A list of prediction dictionaries. + """ + image_tensor = self._load_image_to_tensor(image) + + predictions = self._process_tensor(image_tensor, merge_level) + + original_path = image if isinstance(image, str) and Path(image).is_file() else None + if visualize: + if original_path is None: + raise ValueError("Visualization is only supported when the input is a file path.") + self._save_annotated_image(original_path, predictions) + + return predictions + + def _load_image_to_tensor(self, image): + """ + Loads an image from various sources and converts it to a standardized tensor. + """ + if isinstance(image, str): + image_path = Path(image) + if not image_path.is_file(): + raise FileNotFoundError(f"Input string is not a valid file path: {image}") + img_tensor = read_image(str(image_path), mode="RGB") + + elif isinstance(image, bytes): + try: + img_bytes = base64.b64decode(image) + img_tensor = decode_image(torch.frombuffer(img_bytes, dtype=torch.uint8), mode="RGB") + except (ValueError, TypeError, base64.binascii.Error) as e: + raise ValueError("Input is not a valid base64-encoded image.") from e + + elif isinstance(image, np.ndarray): + # PyTorch expects CHW, NumPy use HWC, so we permute + if image.ndim == 2: # Handle grayscale by stacking + image = np.stack([image] * 3, axis=-1) + # Handle RGBA images by stripping the alpha channel + if image.shape[2] == 4: + image = image[..., :3] + img_tensor = torch.from_numpy(image).permute(2, 0, 1) + + elif isinstance(image, io.BytesIO): + image.seek(0) + img_bytes = image.getvalue() + img_tensor = decode_image(torch.frombuffer(img_bytes, dtype=torch.uint8), mode="RGB") + + else: + raise TypeError( + f"Unsupported input type: {type(image)}. " + "Supported types are file path (str), base64 (str/bytes), NumPy array, and io.BytesIO" + ) + + return convert_image_dtype(img_tensor, dtype=torch.float16) + + def _process_tensor(self, image_tensor, merge_level): + """ + Runs the core OCR inference pipeline on a standardized image tensor. + """ + if merge_level not in MERGE_LEVELS: + raise ValueError(f"Invalid merge level: {merge_level}. 
Must be one of {MERGE_LEVELS}.") + + original_shape = image_tensor.shape[1:] + padded_length = max(original_shape) + + padded_image = interpolate_and_pad( + pad_to_square(image_tensor, padded_length, how="bottom_right").unsqueeze(0), + PAD_COLOR, + INFER_LENGTH, + ) + + with amp.autocast("cuda", enabled=True), torch.no_grad(): + det_conf, _, det_rboxes, det_feature_3 = self.detector(padded_image.cuda()) + + with amp.autocast("cuda", enabled=True), torch.no_grad(): + e2e_det_conf = torch.sigmoid(det_conf) + e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE) + + # FIXME: quad_non_maximal_suppression fails with batch size > 1 + all_quads = [] + all_confidence = [] + all_region_counts = [] + + for idx in range(e2e_det_coords.shape[0]): + quads, confidence, region_counts = quad_non_maximal_suppression( + e2e_det_coords[idx].unsqueeze(0), + e2e_det_conf[idx].unsqueeze(0), + prob_threshold=NMS_PROB_THRESHOLD, + iou_threshold=NMS_IOU_THRESHOLD, + kernel_height=2, + kernel_width=3, + max_regions=NMS_MAX_REGIONS, + verbose=False, + )[:3] + all_quads.append(quads) + all_confidence.append(confidence) + all_region_counts.append(region_counts) + + quads = torch.cat(all_quads, dim=0) + confidence = torch.cat(all_confidence, dim=0) + region_counts = torch.cat(all_region_counts, dim=0) + + if quads.shape[0] == 0: + rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device) + rel_rectified_quads = torch.empty(0, 128, 2, 3, dtype=torch.float32, device=padded_image.device) + else: + rec_rectified_quads = self.recognizer_quad_rectifier( + quads.detach(), padded_image.shape[2], padded_image.shape[3] + ) + rel_rectified_quads = self.relational_quad_rectifier( + quads.cuda().detach(), padded_image.shape[2], padded_image.shape[3] + ) + + input_indices = region_counts_to_indices(region_counts, quads.shape[0]) + + rec_rectified_quads = self.grid_sampler(det_feature_3.float(), rec_rectified_quads.float(), input_indices) + rel_rectified_quads = self.grid_sampler( + det_feature_3.float().cuda(), + rel_rectified_quads, + input_indices.cuda(), + ) + + if rec_rectified_quads.shape[0] == 0: + rec_output = torch.empty(0, 32, 858, dtype=torch.float16, device=rec_rectified_quads.device) + rec_features = torch.empty(0, 32, 256, dtype=torch.float16, device=rec_rectified_quads.device) + else: + with amp.autocast("cuda", enabled=True), torch.no_grad(): + rec_output, rec_features = self.recognizer(rec_rectified_quads.cuda()) + + predictions = [] + + if region_counts.sum() > 0: + rel_output = self.relational( + rel_rectified_quads.cuda(), + quads.cuda(), + region_counts.cpu(), + rec_features.cuda(), + ) + words, lines, line_var = ( + rel_output["words"], + rel_output["lines"], + rel_output["line_log_var_unc"], + ) + + with amp.autocast("cuda", enabled=True), torch.no_grad(): + words = [F.softmax(r, dim=1, dtype=torch.float32)[:, 1:] for r in words] + + output = { + "sequences": F.softmax(rec_output, dim=2, dtype=torch.float32), + "region_counts": region_counts, + "quads": quads, + "raw_detector_confidence": e2e_det_conf, + "confidence": confidence, + "relations": words, + "line_relations": lines, + "line_rel_var": line_var, + "fg_colors": None, + "fonts": None, + "tt_log_var_uncertainty": None, + "e2e_recog_features": rec_features, + } + + quads = output["quads"] + + lengths = [padded_length / INFER_LENGTH] * region_counts.item() + + lengths_tensor = torch.tensor(lengths, dtype=torch.float32, device=quads.device).view(quads.shape[0], 1, 1) + + quads *= lengths_tensor + + 
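+        # Note (descriptive comment, added for clarity): multiplying by
+        # padded_length / INFER_LENGTH maps the detected quads from the 1024-pixel
+        # padded inference resolution back to the original image's pixel coordinates
+        # before the recognition/relational encoders convert the outputs to labels below.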
# TODO: Incorporate the quad scale factor + batch = self.recog_encoder.convert_targets_to_labels(output, image_size=None, is_gt=False) + relation_batch = self.relation_encoder.convert_targets_to_labels(output, image_size=None, is_gt=False) + + for example, rel_example in zip(batch, relation_batch): + example.relation_graph = rel_example.relation_graph + example.prune_invalid_relations() + + for example in batch: + if example.relation_graph is None: + continue + for paragraph in example.relation_graph: + block = [] + for line in paragraph: + for relational_idx in line: + block.append(example[relational_idx]) + if block: + example.blocks.append(TextBlock(block)) + + for example in batch: + for text_region in example: + text_region.region = text_region.region.vertices + + for example in batch: + boxes, texts, scores = parse_relational_results(example, level=merge_level) + boxes, texts, scores = reorder_boxes(boxes, texts, scores, mode="top_left", dbscan_eps=10) + + orig_h, orig_w = original_shape + + if len(boxes) == 0: + boxes = ["nan"] + texts = ["nan"] + scores = ["nan"] + else: + # Convert to numpy array and reshape to (N, 4, 2) for easier processing + boxes_array = np.array(boxes).reshape(-1, 4, 2) + + # Divide X coordinates by orig_w and Y coordinates by orig_h + boxes_array[:, :, 0] = boxes_array[:, :, 0] / orig_w # X coordinates + boxes_array[:, :, 1] = boxes_array[:, :, 1] / orig_h # Y coordinates + boxes = boxes_array.astype(np.float16).tolist() + + for box, text, conf in zip(boxes, texts, scores): + if box == "nan": + break + predictions.append( + { + "text": text, + "confidence": conf, + "left": min(p[0] for p in box), + "upper": max(p[1] for p in box), + "right": max(p[0] for p in box), + "lower": min(p[1] for p in box), + } + ) + + return predictions + + def _save_annotated_image(self, image_path, predictions): + """Saves a copy of the image with bounding boxes overlaid.""" + output_path = os.path.splitext(image_path)[0] + "-annotated" + os.path.splitext(image_path)[1] + + pil_image = Image.open(image_path).convert("RGB") + draw = ImageDraw.Draw(pil_image) + + font = ImageFont.load_default() + small_font = ImageFont.load_default() + + img_width, img_height = pil_image.size + + color = (255, 0, 0) + + for pred in predictions: + if isinstance(pred.get("left"), str) and pred["left"] == "nan": + continue + + left = int(pred["left"] * img_width) + right = int(pred["right"] * img_width) + upper = int(pred["upper"] * img_height) + lower = int(pred["lower"] * img_height) + + confidence = pred["confidence"] + text = pred["text"] + + draw.rectangle([left, lower, right, upper], outline=color, width=2) + + display_text = f"{text}" + conf_text = f"({confidence:.2f})" + + text_y = max(0, upper - 25) + + text_bbox = draw.textbbox((left, text_y), display_text, font=font) + conf_bbox = draw.textbbox((left, text_y + 18), conf_text, font=small_font) + + draw.rectangle( + [ + text_bbox[0] - 2, + text_bbox[1] - 2, + text_bbox[2] + 2, + text_bbox[3] + 2, + ], + fill=(255, 255, 255, 180), + outline=color, + ) + draw.rectangle( + [ + conf_bbox[0] - 2, + conf_bbox[1] - 2, + conf_bbox[2] + 2, + conf_bbox[3] + 2, + ], + fill=(255, 255, 255, 180), + outline=color, + ) + + draw.text((left, text_y), display_text, fill=color, font=font) + draw.text((left, text_y + 18), conf_text, fill=color, font=small_font) + + pil_image.save(output_path) + + print(f"Annotated image saved to: {output_path}") + print( + f"Total predictions overlaid: {len([p for p in predictions if not (isinstance(p.get('left'), str) and 
p['left'] == 'nan')])}" + ) diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/__init__.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb59f61a04031738fb505e42b3f5c821c9f54ad5 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/data_container.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/data_container.py new file mode 100644 index 0000000000000000000000000000000000000000..a8bb425ef39b7a4e27a80555f742b1361ec76a23 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/data_container.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Data container containing uniform calls to data manipulation functions.""" + + +class DataContainer(object): + """ + Business object helper that allows for low-level data operations to be uniformly called. + + Essentially, this prevents us from having to write a wrapper call for low level objects. + + e.g. The `Path` object has two `Polyline`s which support these operations. By supporting the + `__iter__` function, it means that it can inherit from this object instead of implementing + and forwarding all of these calls to the `Polyline` object. Similarly, the `Example` object + contains multiple `Path`s, and the `Batch` contains multiple `Example`s. + """ + + def __iter__(self): + """Make iterable.""" + raise NotImplementedError("Subclasses must implement this!") + + def apply_stm(self, *args, **kwargs): + """Applies the homogeneous transformation matrix.""" + for sub_item in self: + sub_item.apply_stm(*args, **kwargs) + + def translate(self, *args, **kwargs): + """Translates a set of vertices.""" + for sub_item in self: + sub_item.translate(*args, **kwargs) + + def scale(self, *args, **kwargs): + """Multiplies a set of vertices by a scale factor.""" + for sub_item in self: + sub_item.scale(*args, **kwargs) + + def rotate(self, *args, **kwargs): + for sub_item in self: + sub_item.rotate(*args, **kwargs) + + def apply(self, fn): + for sub_item in self: + sub_item.apply(fn) + + def validate(self): + return all(sub.validate() for sub in self) + + def mark_dirty(self): + for sub_item in self: + sub_item.mark_dirty() diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/quadrangle.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/quadrangle.py new file mode 100644 index 0000000000000000000000000000000000000000..c8684f6b6481b9a45d3f6e1e564582910d570c57 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/quadrangle.py @@ -0,0 +1,255 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Quadrangle class.""" + +import logging +from typing import Union + +from shapely.geometry import Polygon +import numpy +import torch + +from nemo_retriever_ocr_cpp import calc_poly_min_rrect, get_poly_bounds_quad + +logger = logging.getLogger(__name__) + + +def apply_single_stm(vertices, stm): + """ + Applies the single homogeneous transformation matrix. + + Args: + vertices (torch tensor): Array of 2D vertices that form the polyline. + stm (torch.tensor): 3x3 homogeneous matrix. + """ + homogenous_vertices = torch.cat((vertices, torch.ones(vertices.shape[0], 1)), dim=1) + transformed = torch.matmul(homogenous_vertices, stm) + norm_factor = 1.0 / transformed[:, 2:] + # Handle divide by zero case. + norm_factor[transformed[:, 2:] == 0] = 0 + return transformed[:, :2].contiguous() * norm_factor + + +class AABB: + def __init__(self, x, y, width, height): + self.x = x + self.y = y + self.width = width + self.height = height + + @property + def area(self): + return self.width * self.height + + def contains(self, x, y): + return ( + x >= self.x and x < (self.x + self.width) and y >= self.y and y < (self.y + self.height) + ) + + def to_quad(self): + vertices = [ + self.x, + self.y, + self.x + self.width, + self.y, + self.x + self.width, + self.y + self.height, + self.x, + self.y + self.height, + ] + return Quadrangle(torch.tensor(vertices, dtype=torch.float32).reshape(4, 2)) + + def __str__(self): + return ( + f"[x: {self.x:0.02f}, y: {self.y:0.02f}, w: {self.width:0.02f}, h: {self.height:0.02f}]" + ) + + +class Quadrangle: + def __init__(self, vertices): + self.vertices = torch.as_tensor(vertices) + + if self.vertices.shape[-1] != 2: + raise ValueError("The vertices must be 2-dimensional!") + + self._reset_cache() + + @staticmethod + def from_size(height, width): + vertices = torch.tensor( + [[0, 0], [width, 0], [width, height], [0, height]], dtype=torch.float32 + ) + + return Quadrangle(vertices) + + def _reset_cache(self): + self._bounds = None + self.__poly = None + self._min_rrect = None + + def clone(self): + return Quadrangle(vertices=self.vertices.clone()) + + def apply_stm(self, stm, *args, **kwargs): + """ + Applies the homogeneous transformation matrix. + + Args: + stm (torch.tensor): 3x3 homogeneous matrix. + """ + self.vertices = apply_single_stm(self.vertices, stm) + self._reset_cache() + + def translate(self, delta_vector): + """ + Translates all of the points by the given 2d delta. + + Args: + delta_vector (torch.tensor): 2d translation vector. + """ + self.vertices += delta_vector + self._reset_cache() + + def scale(self, scale_vector, **kwargs): + """ + Scales the points by the given 2d size vector. + + Args: + scale_vector (torch.tensor): 2d scale vector. + E.g. [2.0, 0.5] scales x by 2 and y by 0.5. 
+ """ + self.vertices *= scale_vector + self._reset_cache() + + def rotate(self, rot_mat): + self.vertices = self.vertices @ rot_mat.t() + self._reset_cache() + + def mark_dirty(self): + self._reset_cache() + + def shortest_edge_length(self) -> float: + return min(self.get_magnitudes()) + + @property + def valid(self): + bds = self.bounds + width = (bds[1, 0] - bds[0, 0]).item() + height = (bds[3, 1] - bds[0, 1]).item() + return width > 0 and height > 0 + + @property + def bounds_quad(self) -> torch.Tensor: + if self._bounds is None: + self._bounds = get_poly_bounds_quad(self.vertices) + return self._bounds + + @property + def _poly(self) -> Polygon: + if self.__poly is None: + self.__poly = Polygon(self.vertices.numpy()) + return self.__poly + + @property + def min_rrect(self) -> torch.Tensor: + """Returns a rotated rect set of vertices""" + if self._min_rrect is None: + self._min_rrect = calc_poly_min_rrect(self.vertices) + return self._min_rrect + + @property + def area(self): + return self._poly.area + + def get_intersection(self, other_quad): + return self._poly.intersection(other_quad._poly) + + def get_union(self, other_quad): + return self._poly.union(other_quad._poly) + + def orient(self): + vertices = self.vertices.numpy() + + # print('--------------') + # print('Input Vertices:\n{}'.format(vertices)) + + if not is_clockwise(vertices): + # print('Counter Clockwise') + vertices = numpy.flip(vertices, 0) + # else: + # print('Clockwise') + + # Super lazy, but top-left will be considered quite literally + start_idx = numpy.argmin(vertices.sum(axis=1)) + + # print('Start Idx: {}'.format(start_idx)) + + out_verts = numpy.empty_like(vertices) + for i in range(4): + d_i = (start_idx + i) % 4 + out_verts[i] = vertices[d_i] + + # print('Out Vertices:\n{}\n'.format(out_verts)) + + self.vertices = torch.from_numpy(out_verts) + + def get_magnitudes(self): + return get_magnitudes(self.vertices.numpy()) + + def apply(self, fn): + fn(self) + + def validate(self): + return torch.all(torch.isfinite(self.vertices)) + + +def is_clockwise(vertices) -> bool: + v = 0 + for i in range(4): + d_i = (i + 1) % 4 + v += (vertices[d_i, 0] - vertices[i, 0]) * (vertices[d_i, 1] + vertices[i, 1]) + + return v < 0 + + +mag_inds_b = [(i + 1) % 4 for i in range(4)] + + +def get_magnitudes(vertices: Union[torch.Tensor, numpy.ndarray]): + if isinstance(vertices, numpy.ndarray): + dkwd = {"axis": -1} + sqrt = numpy.sqrt + else: + dkwd = {"dim": -1} + sqrt = torch.sqrt + + b = vertices[..., mag_inds_b, :] + a = vertices + + d_v = (b - a) ** 2 + d_v = d_v.sum(**dkwd) + d_v = sqrt(d_v) + + return d_v + + +def box_2_quad(bds: torch.Tensor): + tl = bds[..., 0, :] + br = bds[..., 1, :] + + p1 = tl + p2 = torch.stack((br[..., 0], tl[..., 1]), dim=-1) + p3 = br + p4 = torch.stack((tl[..., 0], br[..., 1]), dim=-1) + + ret = torch.stack((p1, p2, p3, p4), dim=-2) + + return ret + + +def get_quad_height(vertices: Union[torch.Tensor, numpy.ndarray]): + mags = get_magnitudes(vertices) + + return (mags[..., 1] + mags[..., 3]) / 2 diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/relation_graph.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/relation_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..219dedec8ed20687a5006f4d216958deb075adf4 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/relation_graph.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & 
AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Optional, List, Tuple, TYPE_CHECKING + +# This prevents a cyclic import while allowing typing +if TYPE_CHECKING: + from data.text_region import TextRegion, Example + + +import torch + +from nemo_retriever_ocr_cpp import get_rel_continuation_cos as _get_rel_continuation_cos + +NEW_LINE_THRESHOLD = math.cos(80 / 180 * math.pi) + + +def _clone_paragraph(paragraph: List[List[int]]): + return [list(s) for s in paragraph] + + +class RelationGraph: + def __init__(self, paragraphs: Optional[List[List[List[int]]]] = None): + self.paragraphs = paragraphs if paragraphs is not None else [] + + def __len__(self): + return len(self.paragraphs) + + def __getitem__(self, idx): + return self.paragraphs[idx] + + def __iter__(self): + return iter(self.paragraphs) + + def __str__(self): + return str(self.paragraphs) + + def __repr__(self): + return str(self) + + @property + def word_count(self): + ct = 0 + for paragraph in self: + ct += sum(len(s) for s in paragraph) + return ct + + @property + def is_proper_unique(self): + items = set() + for paragraph in self: + for sentence in paragraph: + for word in sentence: + if word in items: + return False + items.add(word) + return True + + def is_valid_for_example(self, example: "Example"): + if self.word_count != len(example): + return False + return self.is_proper_unique + + def non_trivial(self): + nt = [] + for paragraph in self: + num_words = sum(len(s) for s in paragraph) + if num_words > 1: + nt.append(_clone_paragraph(paragraph)) + + return RelationGraph(nt) + + def non_trivial_paragraph(self): + nt = [] + for paragraph in self: + if len(paragraph) > 1: + nt.append(_clone_paragraph(paragraph)) + return RelationGraph(nt) + + def multi_line(self, example: "Example"): + ml = [] + for paragraph in self: + if is_relation_multiline(example, paragraph): + ml.append(_clone_paragraph(paragraph)) + return RelationGraph(ml) + + def flatten(self): + ret = [] + for paragraph in self: + flat = sum(paragraph, []) + if flat: + ret.append([flat]) + return RelationGraph(ret) + + def isolate_sentences(self): + ret = [] + for paragraph in self: + for sentence in paragraph: + ret.append([list(sentence)]) + return RelationGraph(ret) + + def split_lines(self, example: "Example"): + fg = self.flatten() + + ret = [] + for paragraph in fg: + sentence = paragraph[0] + out_para = [] + out_sentence = [sentence[0]] + for i in range(1, len(sentence)): + region_a = example[out_sentence[-1]] + region_b = example[sentence[i]] + if is_new_line(region_a, region_b): + out_para.append(out_sentence) + out_sentence = [] + out_sentence.append(sentence[i]) + if out_sentence: + out_para.append(out_sentence) + ret.append(out_para) + return RelationGraph(ret) + + def graph_to_sparse_tensor(self): + num_words = self.word_count + ret = torch.full((num_words,), -1, dtype=torch.int64) + flat = self.flatten() + for para in flat: + sent = para[0] + for i in range(1, len(sent)): + p_idx = sent[i - 1] + c_idx = sent[i] + ret[p_idx] = c_idx + return ret + + def get_connection_pairs(self, example: "Example" = None): + pairs = [] + for paragraph in self.flatten(): + sentence = paragraph[0] + for i in range(1, len(sentence)): + pairs.append((sentence[i - 1], sentence[i])) + + if example is not None: + pairs = self.filter_valid_pairs(example, pairs) + + return pairs + + def filter_valid_pairs(self, example: "Example", pairs: List[Tuple[int, int]]): + out_pairs = [] + for a, b in pairs: + a_valid = 
example[a].valid + b_valid = example[b].valid + + if a_valid and b_valid: + out_pairs.append((a, b)) + return out_pairs + + +def is_relation_multiline(example: "Example", paragraph: List[List[int]]): + """ + Look to see if the absolute angle between two relations is greater than 80 degrees + """ + flat_text = sum(paragraph, []) + + for i in range(1, len(flat_text)): + region_a = example[flat_text[i - 1]] + region_b = example[flat_text[i]] + + if is_new_line(region_a, region_b): + return True + + return False + + +def get_rel_continuation_cos(region_a: "TextRegion", region_b: "TextRegion"): + try: + rect_a = region_a.region.min_rrect + rect_b = region_b.region.min_rrect + except RuntimeError: + return 1.0 + + return _get_rel_continuation_cos(rect_a, rect_b) + + +def is_new_line(region_a: "TextRegion", region_b: "TextRegion"): + cos_t = get_rel_continuation_cos(region_a, region_b) + + return cos_t < NEW_LINE_THRESHOLD diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/text_region.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/text_region.py new file mode 100644 index 0000000000000000000000000000000000000000..68c81f8b922ebb1cade6cb684d47009608e0b5ea --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/text_region.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Text region class.""" + +from typing import List, Iterator, Optional + +import torch + +from nemo_retriever_ocr.inference.post_processing.data.data_container import DataContainer +from nemo_retriever_ocr.inference.post_processing.data.quadrangle import Quadrangle +from nemo_retriever_ocr.inference.post_processing.data.relation_graph import RelationGraph + +from nemo_retriever_ocr_cpp import text_region_grouping + + +HEUR_HORIZONTAL_TOLERANCE = 2.0 +HEUR_VERTICAL_TOLERANCE = 0.5 + + +class TextRegion(DataContainer): + def __init__( + self, + region: Quadrangle, + text: str, + valid: Optional[bool] = True, + language=None, + confidence=1, + ): + self.region = region + self.text = text + self.valid = valid + self.quad_prob = 1 + self.text_prob = 1 + self.confidence = confidence + self.language = language + + def __iter__(self) -> Iterator[Quadrangle]: + yield self.region + + def to_string(self, indent="") -> str: + """Creates a string representation of the region for easy printing.""" + vertices = self.region + if isinstance(vertices, Quadrangle): + vertices = vertices.vertices + + ret = '{}Region (T="{}", Valid={}) {}'.format(indent, self.text, self.valid, vertices) + return ret + + def __str__(self): + return self.to_string() + + def __repr__(self) -> str: + return self.to_string() + + def clone(self): + ret = TextRegion( + region=self.region.clone(), + text=self.text, + valid=self.valid, + ) + ret.quad_prob = self.quad_prob + ret.text_prob = self.text_prob + ret.confidence = self.confidence + return ret + + +class TextBlock(DataContainer): + def __init__(self, regions: List[TextRegion]): + self.regions = regions + + def __len__(self) -> int: + return len(self.regions) + + def __iter__(self) -> Iterator[TextRegion]: + return iter(self.regions) + + def __getitem__(self, idx) -> TextRegion: + return self.regions[idx] + + def __str__(self): + return self.to_string() + + def __repr__(self) -> str: + return self.to_string() + + def to_string(self, indent="") -> str: + ret = indent + " ".join(tr.text for tr in 
self.regions) + return ret + + +class Example(DataContainer): + def __init__( + self, + regions: List[TextRegion], + label=None, + is_synthetic=False, + relation_graph: Optional[RelationGraph] = None, + ): + self.regions = regions + self.valid = True + self.label = label + self.is_synthetic = is_synthetic + self.relation_graph = relation_graph + self._coalesced = None + self.bounds: Quadrangle = None + self.blocks: List[TextBlock] = [] + + def __len__(self): + return len(self.regions) + + def __iter__(self) -> Iterator[TextRegion]: + return iter(self.regions) + + def __getitem__(self, idx) -> TextRegion: + return self.regions[idx] + + def to_string(self, indent="") -> str: + ret = "{}Example:\n".format(indent) + for r in self.regions: + ret += "{}\n".format(r.to_string(indent=indent + "\t")) + return ret + + def __str__(self): + return self.to_string() + + def __repr__(self) -> str: + return self.to_string() + + def clone(self): + ret = Example([r.clone() for r in self.regions], label=self.label) + ret.valid = self.valid + return ret + + def compute_text_relations(self): + if self.relation_graph is not None: + return self.relation_graph + Batch([self]).compute_text_relations() + return self.relation_graph + + def prune_invalid_relations(self): + if self.relation_graph is None: + return + + def is_valid(sentence): + return any(self[w_idx].valid for w_idx in sentence) + + nt = [] + for paragraph in self.relation_graph: + flat_para = sum(paragraph, []) + if is_valid(flat_para): + nt.append(paragraph) + self.relation_graph = RelationGraph(nt) + + def relations_str(self, graph: RelationGraph = None): + if graph is None: + graph = self.relation_graph + + if graph is None: + return [r.text for r in self] + + ret = [] + for paragraph in graph: + st = "\n".join(" ".join(self[i].text for i in s) for s in paragraph) + ret.append(st) + + return ret + + def coalesce_homogeneous(self): + """ + Coalesce all of the quad buffers into a single tensor, and make the tensors + homogeneous + """ + if self._coalesced is not None: + return self._coalesced + + num_vertices = 0 + tensors = [] + for tr in self: + tensors.append(tr.region.vertices) + num_vertices += tensors[-1].shape[0] + + coal = torch.empty(num_vertices, 3, dtype=torch.float32) + coal[:, 2] = 1 + if tensors: + torch.cat(tensors, dim=0, out=coal[:, :2]) + + self._coalesced = coal + offset = 0 + for tr in self: + ex_verts = tr.region.vertices + coal_slice = coal[offset : offset + ex_verts.shape[0], :2] + tr.region.vertices = coal_slice + offset += ex_verts.shape[0] + + return self._coalesced + + def apply_stm(self, stm, perspective=False, **kwargs): + if self.bounds is not None: + self.bounds.apply_stm(stm, perspective, **kwargs) + + if self._coalesced is None: + return super().apply_stm(stm, perspective=perspective, **kwargs) + + tx = torch.matmul(self._coalesced, stm) + self._coalesced.copy_(tx) + + if perspective: + self._coalesced /= self._coalesced[:, 2:] + + def translate(self, delta_vector): + if self.bounds is not None: + self.bounds.translate(delta_vector) + + if self._coalesced is None: + return super().translate(delta_vector) + + self._coalesced[:, :2] += delta_vector + + def scale(self, scale_vector, **kwargs): + if self.bounds is not None: + self.bounds.scale(scale_vector, **kwargs) + + if self._coalesced is None: + return super().scale(scale_vector, **kwargs) + + self._coalesced[:, :2] *= scale_vector + + def rotate(self, rot_mat): + if self.bounds is not None: + self.bounds.rotate(rot_mat) + + if self._coalesced is None: + return 
super().rotate(rot_mat) + + view = self._coalesced[:, :2] + tx = torch.matmul(view, rot_mat.t()) + view.copy_(tx) + + def validate(self): + coal = self.coalesce_homogeneous() + return torch.all(torch.isfinite(coal[:, :2])).item() + + +class Batch(DataContainer): + def __init__(self, examples: List[Example]): + self.examples = examples + + def __len__(self): + return len(self.examples) + + def __iter__(self) -> Iterator[Example]: + return iter(self.examples) + + def __getitem__(self, idx) -> Example: + return self.examples[idx] + + def to_string(self, indent="") -> str: + ret = "{}Batch:\n".format(indent) + for ex in self.examples: + ret += "{}\n".format(ex.to_string(indent=indent + "\t")) + return ret + + def __str__(self): + return self.to_string() + + def __repr__(self) -> str: + return self.to_string() + + def clone(self): + return Batch([ex.clone() for ex in self.examples]) + + def compute_text_relations(self): + graphs = self.get_text_relations() + + for i, ex in enumerate(self): + ex.relation_graph = graphs[i] + + def get_text_relations(self): + cts = torch.zeros(len(self), dtype=torch.int64) + total_ct = 0 + for i, ex in enumerate(self): + cts[i] = len(ex) + total_ct += len(ex) + + quads = torch.zeros(total_ct, 4, 2, dtype=torch.float32) + off = 0 + for ex in self: + for tr in ex: + quads[off] = tr.region.vertices + off += 1 + + all_relations = text_region_grouping( + quads, + cts, + horizontal_tolerance=HEUR_HORIZONTAL_TOLERANCE, + vertical_tolerance=HEUR_VERTICAL_TOLERANCE, + ) + + graphs = [] + + for ex, phrases in zip(self, all_relations): + paragraphs = [] + for i in range(len(phrases) - 1, -1, -1): + r_flat = [] + for line in phrases[i]: + r_flat.extend(line) + + any_valid = False + for tr_idx in r_flat: + if ex[tr_idx].valid: + any_valid = True + break + + if not any_valid: + del phrases[i] + else: + curr_sentence = [] + sentences = [] + for word_idx in r_flat: + tr = ex[word_idx] + curr_sentence.append(word_idx) + if tr.text.endswith((".", "?", "!")): + sentences.append(curr_sentence) + curr_sentence = [] + if curr_sentence: + sentences.append(curr_sentence) + paragraphs.append(sentences) + + graphs.append(RelationGraph(paragraphs)) + + return graphs + + def get_valid_list(self): + return [sub.validate() for sub in self] diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/worker_messages.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/worker_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff5b8fd9cc7f3b242b33d998a62120f729da61d --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/data/worker_messages.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
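# --- Editorial usage sketch (not part of this patch) ---------------------------
# How the classes in text_region.py above are intended to compose: wrap detected
# quads in TextRegion objects, group them into an Example, and let the grouping
# heuristic (text_region_grouping from the compiled nemo_retriever_ocr_cpp
# extension) populate the RelationGraph. The word geometry below is invented for
# illustration, and this assumes the C++ extension is built and importable.

import torch

from nemo_retriever_ocr.inference.post_processing.data.quadrangle import Quadrangle
from nemo_retriever_ocr.inference.post_processing.data.text_region import Example, TextRegion


def make_word(x, y, w, h, text):
    # Axis-aligned quad in (TL, TR, BR, BL) order as a 4x2 float tensor.
    verts = torch.tensor(
        [[x, y], [x + w, y], [x + w, y + h], [x, y + h]], dtype=torch.float32
    )
    return TextRegion(region=Quadrangle(verts), text=text, confidence=0.9)


example = Example([make_word(0, 0, 40, 12, "Hello"), make_word(46, 0, 50, 12, "world.")])
example.compute_text_relations()  # wraps Batch([example]).compute_text_relations()
print(example.relations_str())    # one string per grouped paragraph, sentences joined by newlines
# --------------------------------------------------------------------------------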
+# SPDX-License-Identifier: Apache-2.0 + +class WorkerMessage: + def __init__(self): + pass + + #### + # Pickle methods + #### + def __getstate__(self): + state = dict() + self.build_state(state) + return state + + def __setstate__(self, state): + self.update_state(state) + + #### + + def build_state(self, state): + pass + + def update_state(self, state): + for k, v in state.items(): + setattr(self, k, v) + + +class TargetEncoderMessage(WorkerMessage): + def __init__(self, name): + self.name = name + + def build_state(self, state): + state["name"] = self.name diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/indirect_grid_sample.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/indirect_grid_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..87bd364249b5f23e3b9188728c2c02218b3a78a0 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/indirect_grid_sample.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +import torch +from torch.autograd import Function +from nemo_retriever_ocr_cpp import ( + indirect_grid_sample_forward, + indirect_grad_sample_backward, +) + +logger = logging.getLogger(__name__) + + +class IndirectGridSampleFunction(Function): + @staticmethod + def forward(ctx, input, grid, input_indices, mode="bilinear"): + val = indirect_grid_sample_forward(input, grid, input_indices, mode) + + ctx.mode = mode + ctx.save_for_backward(input, grid, input_indices) + + return val + + @staticmethod + def backward(ctx, grad_output): + input, grid, input_indices = ctx.saved_tensors + + grad_input, grad_grid = indirect_grad_sample_backward( + grad_output, input, grid, input_indices, ctx.mode + ) + + return grad_input, grad_grid, None, None + + +def indirect_grid_sample( + input: torch.Tensor, grid: torch.Tensor, input_indices: torch.Tensor, mode="bilinear" +): + return IndirectGridSampleFunction.apply(input, grid, input_indices, mode) + + +class IndirectGridSample(torch.nn.Module): + def __init__(self, mode="bilinear"): + super().__init__() + self.mode = mode + + def forward(self, input: torch.Tensor, grid: torch.Tensor, input_indices: torch.Tensor): + return indirect_grid_sample(input, grid, input_indices, self.mode) diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/quad_rectify.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/quad_rectify.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc208bb3b83e3973e006a0f853630019432bf9b --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/quad_rectify.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
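# --- Editorial usage sketch (not part of this patch) ---------------------------
# The WorkerMessage base class in worker_messages.py above routes pickling
# through build_state()/update_state(), so a subclass only declares which fields
# to serialize. A quick round-trip check of that protocol:

import pickle

from nemo_retriever_ocr.inference.post_processing.data.worker_messages import TargetEncoderMessage

msg = TargetEncoderMessage(name="target_encoder")
restored = pickle.loads(pickle.dumps(msg))  # __getstate__ -> build_state, __setstate__ -> update_state
assert restored.name == "target_encoder"
# --------------------------------------------------------------------------------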
+# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch.autograd import Function +from nemo_retriever_ocr_cpp import ( + quad_rectify_backward, + quad_rectify_calc_quad_width, + quad_rectify_forward, +) + + +def _quad_wrap(fn, quads, *args, **kwargs): + orig_type = quads.dtype + if quads.dtype == torch.float16: + quads = quads.to(torch.float32) + ret = fn(quads, *args, **kwargs) + + return ret.to(orig_type) + + +def _quad_rectify_forward(*args, **kwargs): + return _quad_wrap(quad_rectify_forward, *args, **kwargs) + + +def _quad_rectify_backward(*args, **kwargs): + return _quad_wrap(quad_rectify_backward, *args, **kwargs) + + +class QuadRectifyFunction(Function): + @staticmethod + def forward( + ctx, + quads, + image_height, + image_width, + output_height, + output_width, + round_factor=16, + isotropic=True, + ): + if output_width <= 0: + widths = quad_rectify_calc_quad_width( + quads.float(), output_height, round_factor, -output_width + ) + output_width = widths.max().item() + + output = _quad_rectify_forward( + quads, + int(image_height), + int(image_width), + int(output_height), + int(output_width), + isotropic=isotropic, + ) + ctx.save_for_backward(quads) + ctx.image_shape = (int(image_height), int(image_width)) + + return output + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + if ctx.needs_input_grad[0]: + grad_input = _quad_rectify_backward( + ctx.saved_variables[0], grad_output, *ctx.image_shape + ) + return grad_input, None, None, None, None, None, None + + +class QuadRectify(torch.nn.Module): + def __init__(self, output_height, output_width, round_factor=16, isotropic=True): + super().__init__() + self.output_height = output_height + self.output_width = output_width + self.round_factor = round_factor + self.isotropic = isotropic + + def forward(self, quads, image_height, image_width): + return QuadRectifyFunction.apply( + quads, + image_height, + image_width, + self.output_height, + self.output_width, + self.round_factor, + self.isotropic, + ) diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/research_ops.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/research_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..2eeac9ef8b8ee4b356c4a020918049024e3a1061 --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/research_ops.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pandas as pd +from sklearn.cluster import DBSCAN + + +def parse_relational_results(result, level="sentence"): + """ + Parses the relational results from the OCR model. + Supported levels: + - "word" returns a list of words, each with a bounding box, text, and confidence. + - "sentence" returns a list of sentences, each with a bounding box, text, and confidence. + - "paragraph" returns a list of paragraphs, each with a bounding box, text, and confidence. + + Args: + result (object): The result object from the OCR model. + level (str, optional): The level to parse the results to. Defaults to "sentence". + + Returns: + np array [N x 4 x 2]: The bounding boxes of the OCR results. + np array [N]: The text of the OCR results + np array [N]: The confidence scores of the OCR results + """ + if level not in ["word", "sentence", "paragraph"]: + raise ValueError( + f"Invalid level: {level}. 
Supported levels are 'word', 'sentence', and 'paragraph'." + ) + results = [] + for block_ids in result.relation_graph: + sentences = [] + for sentence_ids in block_ids: + regions = [result.regions[idx] for idx in sentence_ids] + bboxes = [region.region.numpy() for region in regions] + texts = [region.text for region in regions] + confs = [region.confidence for region in regions] + + if level == "word": + for bbox, text, conf in zip(bboxes, texts, confs): + results.append( + { + "bbox": bbox, + "text": text, + "confidence": conf, + } + ) + else: + bboxes = np.stack(bboxes) + xmin = bboxes[:, :, 0].min().item() + ymin = bboxes[:, :, 1].min().item() + xmax = bboxes[:, :, 0].max().item() + ymax = bboxes[:, :, 1].max().item() + + sentences.append( + { + "bbox": [ + [xmin, ymin], + [xmax, ymin], + [xmax, ymax], + [xmin, ymax], + ], + "text": " ".join(texts), + "confidence": np.mean(confs), + } + ) + if level == "word": + pass + elif level == "sentence": + results += sentences + else: + bboxes = np.stack([s["bbox"] for s in sentences]) + texts = [s["text"] for s in sentences] + confs = [s["confidence"] for s in sentences] + xmin = bboxes[:, :, 0].min().item() + ymin = bboxes[:, :, 1].min().item() + xmax = bboxes[:, :, 0].max().item() + ymax = bboxes[:, :, 1].max().item() + results.append( + { + "bbox": [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]], + "text": " ".join(texts), + "confidence": np.mean(confs), + } + ) + + boxes = np.array([r["bbox"] for r in results]) + texts = np.array([r["text"] for r in results]) + confidences = np.array([r["confidence"] for r in results]) + return boxes, texts, confidences + + +def reorder_boxes(boxes, texts, confs, mode="center", dbscan_eps=10): + """ + Reorders the boxes in reading order. + If mode is "center", the boxes are reordered using bbox center. + If mode is "top_left", the boxes are reordered using the top left corner. + If dbscan_eps is not 0, the boxes are reordered using DBSCAN clustering. + + Args: + boxes (np array [n x 4 x 2]): The bounding boxes of the OCR results. + texts (np array [n]): The text of the OCR results. + confs (np array [n]): The confidence scores of the OCR results. + mode (str, optional): The mode to reorder the boxes. Defaults to "center". + dbscan_eps (float, optional): The epsilon parameter for DBSCAN. Defaults to 10. + + Returns: + np array [n x 4 x 2]: The reordered bounding boxes. + np array [n]: The reordered texts. + np array [n]: The reordered confidence scores. 
+ """ + df = pd.DataFrame( + [[b, t, c] for b, t, c in zip(boxes, texts, confs)], + columns=["bbox", "text", "conf"], + ) + + if mode == "center": + df["x"] = df["bbox"].apply(lambda box: (box[0][0] + box[2][0]) / 2) + df["y"] = df["bbox"].apply(lambda box: (box[0][1] + box[2][1]) / 2) + elif mode == "top_left": + df["x"] = df["bbox"].apply(lambda box: (box[0][0])) + df["y"] = df["bbox"].apply(lambda box: (box[0][1])) + + if dbscan_eps: + do_naive_sorting = False + try: + dbscan = DBSCAN(eps=dbscan_eps, min_samples=1) + dbscan.fit(df["y"].values[:, None]) + df["cluster"] = dbscan.labels_ + df["cluster_centers"] = df.groupby("cluster")["y"].transform("mean").astype(int) + df = df.sort_values(["cluster_centers", "x"], ascending=[True, True], ignore_index=True) + except ValueError: + do_naive_sorting = True + else: + do_naive_sorting = True + + if do_naive_sorting: + df["y"] = np.round((df["y"] - df["y"].min()) // 5, 0) + df = df.sort_values(["y", "x"], ascending=[True, True], ignore_index=True) + + bboxes = [p.tolist() for p in df["bbox"].values.tolist()] + texts = df["text"].values.tolist() + confs = df["conf"].values.tolist() + return bboxes, texts, confs diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/rrect_to_quads.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/rrect_to_quads.py new file mode 100644 index 0000000000000000000000000000000000000000..796b92983b71a8fc909f020e136713c7d8a3ffec --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/post_processing/rrect_to_quads.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch.autograd import Function +from nemo_retriever_ocr_cpp import rrect_to_quads, rrect_to_quads_backward + + +class RRectToQuadsFunction(Function): + @staticmethod + def forward(ctx, rrects: torch.Tensor, cell_size: float): + ctx.save_for_backward(rrects) + + return rrect_to_quads(rrects, cell_size) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + rrects = ctx.saved_variables[0] + + grad_input = rrect_to_quads_backward(rrects, grad_output) + + return grad_input, None + + +class RRectToQuads(torch.nn.Module): + def __init__(self, cell_size: float): + super().__init__() + self.cell_size = cell_size + + def forward(self, rrects: torch.Tensor): + return RRectToQuadsFunction.apply(rrects, self.cell_size) diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/pre_processing.py b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/pre_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..f42e253d3f9714bf0f0b534afac616f000329e5a --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr/inference/pre_processing.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch +import torch.nn.functional as F + + +def pad_to_square(img: torch.Tensor, target_length: int, how="center") -> torch.Tensor: + """ + Pads the input image to a square shape with the specified size. + + Args: + img (torch.Tensor): Input image tensor of shape (C, H, W). + size (int): The target size for both height and width. + + Returns: + torch.Tensor: Padded image tensor of shape (C, size, size). 
+ """ + _, h, w = img.shape + + if how == "center": + pad_h = (target_length - h) // 2 + pad_w = (target_length - w) // 2 + return F.pad( + img, (pad_w, target_length - w - pad_w, pad_h, target_length - h - pad_h), value=1.0 + ) + elif how == "bottom_right": + pad_h = target_length - h + pad_w = target_length - w + return F.pad(img, (0, pad_w, 0, pad_h), value=1.0) + else: + raise ValueError(f"Unsupported padding method: {how}") + + +def interpolate_and_pad( + images: torch.Tensor, pad_color: torch.Tensor, infer_length: int +) -> torch.Tensor: + """ + Interpolates the input images to a specified height and pads them to a specified width. + + Args: + images (torch.Tensor): Input image tensor of shape (B, C, H, W). + infer_height (int): The target height for interpolation. + pad_infer_width (int): The target width for padding. + + Returns: + torch.Tensor: Interpolated and padded image tensor of shape (B, C, infer_height, pad_infer_width). + """ + pad_infer_width = int(math.ceil(infer_length / 128) * 128) + + rs_images = F.interpolate( + images, size=(infer_length, infer_length), mode="bilinear", align_corners=True + ) + + padded = ( + pad_color.reshape(1, -1, 1, 1) + .expand(images.shape[0], -1, infer_length, pad_infer_width) + .contiguous() + ) + padded[..., :infer_length].copy_(rs_images) + + return padded diff --git a/nemo-retriever-ocr/src/nemo_retriever_ocr_cpp/__init__.py b/nemo-retriever-ocr/src/nemo_retriever_ocr_cpp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..83addf23f2d1a5ba1b6f0de78fa9a7cba0c0333e --- /dev/null +++ b/nemo-retriever-ocr/src/nemo_retriever_ocr_cpp/__init__.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import importlib + +# Ensure torch is imported so its shared libs (e.g., libc10.so) are available +try: + import torch # noqa: F401 +except Exception: # torch may be cpu-only/cuda-only, still attempt import of extension + pass + +from ._nemo_retriever_ocr_cpp import * # noqa: F403 \ No newline at end of file diff --git a/nemo-retriever-ocr/tests/test_nms.py b/nemo-retriever-ocr/tests/test_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..984613a3ad9fbf42c26b27d58876f60916ea575e --- /dev/null +++ b/nemo-retriever-ocr/tests/test_nms.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import torch + +from nemo_retriever_ocr_cpp import quad_non_maximal_suppression + +quads = [ + [[0, 2], [2, 2], [2, 0], [0, 0]], + [[1, 3], [3, 3], [3, 1], [1, 1]], + [[6, 6], [10, 6], [10, 5], [6, 5]], + [[6, 5.5], [8, 6], [10, 5.5], [8, 5]], + [[7, 6], [9, 6], [9, 5], [7, 5]], +] + +quads = torch.tensor([[quads]], dtype=torch.float32).cuda() + +probs = [0.7, 0.8, 0.55, 0.9, 0.85] + +probs = torch.tensor([[probs]], dtype=torch.float32).cuda() + +print("in_quads") +print(quads) +print("in_probs") +print(probs) + +out_quads, out_probs, out_region_counts = quad_non_maximal_suppression( + quads, probs, prob_threshold=0.5, iou_threshold=0.1, kernel_height=2, kernel_width=2 +)[:3] + +print("out_quads") +print(out_quads) +print("out_probs") +print(out_probs) +print("out_region_counts") +print(out_region_counts) diff --git a/nemo-retriever-ocr/tests/test_quad_rectify.py b/nemo-retriever-ocr/tests/test_quad_rectify.py new file mode 100644 index 0000000000000000000000000000000000000000..28055c43d6340e64bf4a2622546b842aa1501d07 --- /dev/null +++ b/nemo-retriever-ocr/tests/test_quad_rectify.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import shutil + +import torch +import torch.nn.functional as F +from torch.autograd import gradcheck +from torchvision.transforms import ToTensor, ToPILImage +from PIL import Image + +from nemo_retriever_ocr.inference.post_processing.quad_rectify import QuadRectify + +example = 3 +isotropic = True + +# root_dir = '/mnt/fsx-datasets-a-1-new/mranzinger/ocr/scene-text/icdar/incidental_text/train' +root_dir = "/home/dcg-adlr-mranzinger-data.cosmos1100/ocr/scene-text/icdar/incidental_text/train" +# root_dir = '/raid/local_datasets/scene-text/icdar/focused_text/train' + +image_path = "{}/images/img_{}.jpg".format(root_dir, example) +label_path = "{}/gt/gt_img_{}.txt".format(root_dir, example) + +shutil.copyfile(image_path, "original.jpg") + +image = Image.open(image_path) + +image = ToTensor()(image) + +print("image", image.shape) + +quads = [] + +with open(label_path, "r") as fd: + for i, line in enumerate(fd.readlines()): + line = line.strip() + print(line) + # Skip the unicode character + if i == 0: + line = line[1:] + line = line.split(",") + coords = [float(t) for t in line[:8]] + + word = line[-1] + + if word != "###": + quads.append(coords) + +quads = torch.tensor(quads).reshape(-1, 4, 2) + +# quads[:,:,0] /= image.shape[-1] +# quads[:,:,1] /= image.shape[-2] + +# print(quads) + +image = image.unsqueeze(0).repeat(quads.shape[0], 1, 1, 1) + +image = image.cuda() +quads = quads.cuda() + +aspect = image.shape[-1] / image.shape[-2] +qr = QuadRectify(60, 400, 0, isotropic=isotropic) +# qr = QuadRectify(60, -1000, aspect) + +grid = qr(quads, *image.shape[2:]) + +print("grid", grid.shape) + +resampled = F.grid_sample(image, grid, align_corners=False) + +print(resampled.shape) + +resampled = resampled.permute(1, 0, 2, 3).contiguous().reshape(3, -1, resampled.shape[-1]) +resampled = resampled.cpu() + +pilOutput = ToPILImage()(resampled) + +pilOutput.save("rectified.jpg") + +# sys.exit(0) + +print("checking gradients") + +quads = quads.double() +cuda_quads = quads.cuda() +quads.requires_grad_() +cuda_quads.requires_grad_() + +qr = QuadRectify(9, 12, isotropic=isotropic) + +print("check GPU gradients:") +# gradcheck(qr.forward, cuda_quads, eps=1e-3, atol=1e-3) +gradcheck(qr.forward, (cuda_quads, *image.shape[2:])) 
+print("check CPU gradients:") +gradcheck(qr.forward, (quads, *image.shape[2:])) +# #gradcheck(qr.forward, quads, eps=1e-3, atol=1e-3) +# gradcheck(qr.forward, quads) + +print("checking performance...") + +quads = quads.detach() + +quads = torch.rand(256, 4, 2, dtype=torch.float64) +cuda_quads = quads.cuda() + +qr = QuadRectify(8, 60, isotropic=isotropic) + +# num_passes = 500 +# for target_device in [torch.device('cuda:0'), torch.device('cpu')]: +# for target_type in [torch.float16, torch.float32, torch.float64]: +# test_quads = quads.to(target_type).to(target_device).clone() + +# test_quads.requires_grad_() + +# print('type:', target_type, 'device:', target_device) + +# try: +# with torch.autograd.profiler.profile(use_cuda=test_quads.is_cuda) as prof: +# for _ in range(num_passes): +# g = qr(test_quads) +# l = (1 - g.mean()) +# l.backward() +# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + +# # print('\tforward...') +# # torch.cuda.synchronize() +# # start_time = time.time() +# # for i in range(num_passes): +# # qr(test_quads).sum() +# # torch.cuda.synchronize() +# # end_time = time.time() +# # fp_sec_per_call = (end_time - start_time) / num_passes +# # print('\t\tTook', fp_sec_per_call, "sec/call") + +# # print('\tbackward...') +# # torch.cuda.synchronize() +# # start_time = time.time() +# # for i in range(num_passes): +# # g = qr(test_quads) +# # g.sum().backward() +# # torch.cuda.synchronize() +# # end_time = time.time() +# # bp_sec_per_call = (end_time - start_time) / num_passes - fp_sec_per_call +# # print('\t\tTook', bp_sec_per_call, "sec/call") +# except Exception as e: +# print('\t\t', e) + +# # for target_type in [torch.float16, torch.float32]: +# # print('DType:', target_type) +# # h_cuda_quads = cuda_quads.clone().to(target_type) +# # h_cuda_quads.requires_grad_() + +# # v = qr(h_cuda_quads).sum() +# # v.backward() + +# # print(h_cuda_quads.grad) diff --git a/nemo-retriever-ocr/tests/test_rrect_to_quads.py b/nemo-retriever-ocr/tests/test_rrect_to_quads.py new file mode 100644 index 0000000000000000000000000000000000000000..a5825e03eea5f11cd98f307ca2ee9f224f810e9e --- /dev/null +++ b/nemo-retriever-ocr/tests/test_rrect_to_quads.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch +from torch.autograd import gradcheck + + +from nemo_retriever_ocr.inference.post_processing.rrect_to_quads import RRectToQuads + + +def get_rrects(b, h, w): + rrects = torch.rand(b, h, w, 5, dtype=torch.float64) + rrects[:, :, :, :4] *= 10 + rrects[:, :, :, 4] *= 2 * math.pi + return rrects + + +rrects = get_rrects(2, 5, 5) + +cell_size = 4 +r2q = RRectToQuads(cell_size) + +quads = r2q(rrects) + +rrects.requires_grad_() + +print("check CPU gradients") +gradcheck(r2q.forward, rrects) + +rrects = get_rrects(4, 10, 10) +rrects.requires_grad_() + +print("check GPU gradients") +gradcheck(r2q.forward, rrects) diff --git a/nemo-retriever-ocr/uv.lock b/nemo-retriever-ocr/uv.lock new file mode 100644 index 0000000000000000000000000000000000000000..6f497319f98b6d1ebfd9e221704e18404685a1af --- /dev/null +++ b/nemo-retriever-ocr/uv.lock @@ -0,0 +1,864 @@ +version = 1 +revision = 3 +requires-python = "==3.12.*" + +[[package]] +name = "asttokens" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978, upload-time = "2024-11-30T04:30:14.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" }, +] + +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114, upload-time = "2023-08-12T20:38:17.776Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249, upload-time = "2023-08-12T20:38:16.269Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "coverage" +version = "7.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/38/ee22495420457259d2f3390309505ea98f98a5eed40901cf62196abad006/coverage-7.11.0.tar.gz", hash = "sha256:167bd504ac1ca2af7ff3b81d245dfea0292c5032ebef9d66cc08a7d28c1b8050", size = 811905, upload-time = "2025-10-15T15:15:08.542Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c4/db/86f6906a7c7edc1a52b2c6682d6dd9be775d73c0dfe2b84f8923dfea5784/coverage-7.11.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9c49e77811cf9d024b95faf86c3f059b11c0c9be0b0d61bc598f453703bd6fd1", size = 216098, upload-time = "2025-10-15T15:13:02.916Z" }, + { url = "https://files.pythonhosted.org/packages/21/54/e7b26157048c7ba555596aad8569ff903d6cd67867d41b75287323678ede/coverage-7.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a61e37a403a778e2cda2a6a39abcc895f1d984071942a41074b5c7ee31642007", size = 216331, upload-time = "2025-10-15T15:13:04.403Z" }, + { url = "https://files.pythonhosted.org/packages/b9/19/1ce6bf444f858b83a733171306134a0544eaddf1ca8851ede6540a55b2ad/coverage-7.11.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:c79cae102bb3b1801e2ef1511fb50e91ec83a1ce466b2c7c25010d884336de46", size = 247825, upload-time = "2025-10-15T15:13:05.92Z" }, + { url = "https://files.pythonhosted.org/packages/71/0b/d3bcbbc259fcced5fb67c5d78f6e7ee965f49760c14afd931e9e663a83b2/coverage-7.11.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:16ce17ceb5d211f320b62df002fa7016b7442ea0fd260c11cec8ce7730954893", size = 250573, upload-time = "2025-10-15T15:13:07.471Z" }, + { url = "https://files.pythonhosted.org/packages/58/8d/b0ff3641a320abb047258d36ed1c21d16be33beed4152628331a1baf3365/coverage-7.11.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:80027673e9d0bd6aef86134b0771845e2da85755cf686e7c7c59566cf5a89115", size = 251706, upload-time = "2025-10-15T15:13:09.4Z" }, + { url = "https://files.pythonhosted.org/packages/59/c8/5a586fe8c7b0458053d9c687f5cff515a74b66c85931f7fe17a1c958b4ac/coverage-7.11.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:4d3ffa07a08657306cd2215b0da53761c4d73cb54d9143b9303a6481ec0cd415", size = 248221, upload-time = "2025-10-15T15:13:10.964Z" }, + { url = "https://files.pythonhosted.org/packages/d0/ff/3a25e3132804ba44cfa9a778cdf2b73dbbe63ef4b0945e39602fc896ba52/coverage-7.11.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a3b6a5f8b2524fd6c1066bc85bfd97e78709bb5e37b5b94911a6506b65f47186", size = 249624, upload-time = "2025-10-15T15:13:12.5Z" }, + { url = "https://files.pythonhosted.org/packages/c5/12/ff10c8ce3895e1b17a73485ea79ebc1896a9e466a9d0f4aef63e0d17b718/coverage-7.11.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:fcc0a4aa589de34bc56e1a80a740ee0f8c47611bdfb28cd1849de60660f3799d", size = 247744, upload-time = "2025-10-15T15:13:14.554Z" }, + { url = "https://files.pythonhosted.org/packages/16/02/d500b91f5471b2975947e0629b8980e5e90786fe316b6d7299852c1d793d/coverage-7.11.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:dba82204769d78c3fd31b35c3d5f46e06511936c5019c39f98320e05b08f794d", size = 247325, upload-time = "2025-10-15T15:13:16.438Z" }, + { url = "https://files.pythonhosted.org/packages/77/11/dee0284fbbd9cd64cfce806b827452c6df3f100d9e66188e82dfe771d4af/coverage-7.11.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:81b335f03ba67309a95210caf3eb43bd6fe75a4e22ba653ef97b4696c56c7ec2", size = 249180, upload-time = "2025-10-15T15:13:17.959Z" }, + { url = "https://files.pythonhosted.org/packages/59/1b/cdf1def928f0a150a057cab03286774e73e29c2395f0d30ce3d9e9f8e697/coverage-7.11.0-cp312-cp312-win32.whl", hash = "sha256:037b2d064c2f8cc8716fe4d39cb705779af3fbf1ba318dc96a1af858888c7bb5", size = 218479, upload-time = 
"2025-10-15T15:13:19.608Z" }, + { url = "https://files.pythonhosted.org/packages/ff/55/e5884d55e031da9c15b94b90a23beccc9d6beee65e9835cd6da0a79e4f3a/coverage-7.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:d66c0104aec3b75e5fd897e7940188ea1892ca1d0235316bf89286d6a22568c0", size = 219290, upload-time = "2025-10-15T15:13:21.593Z" }, + { url = "https://files.pythonhosted.org/packages/23/a8/faa930cfc71c1d16bc78f9a19bb73700464f9c331d9e547bfbc1dbd3a108/coverage-7.11.0-cp312-cp312-win_arm64.whl", hash = "sha256:d91ebeac603812a09cf6a886ba6e464f3bbb367411904ae3790dfe28311b15ad", size = 217924, upload-time = "2025-10-15T15:13:23.39Z" }, + { url = "https://files.pythonhosted.org/packages/5f/04/642c1d8a448ae5ea1369eac8495740a79eb4e581a9fb0cbdce56bbf56da1/coverage-7.11.0-py3-none-any.whl", hash = "sha256:4b7589765348d78fb4e5fb6ea35d07564e387da2fc5efff62e0222971f155f68", size = 207761, upload-time = "2025-10-15T15:15:06.439Z" }, +] + +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, +] + +[[package]] +name = "distlib" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, +] + +[[package]] +name = "executing" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488, upload-time = "2025-09-01T09:48:10.866Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, +] + +[[package]] +name = "filelock" +version = "3.20.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, +] + +[[package]] +name = "fsspec" +version = "2025.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/de/e0/bab50af11c2d75c9c4a2a26a5254573c0bd97cea152254401510950486fa/fsspec-2025.9.0.tar.gz", hash = "sha256:19fd429483d25d28b65ec68f9f4adc16c17ea2c7c7bf54ec61360d478fb19c19", size = 304847, upload-time = "2025-09-02T19:10:49.215Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/71/70db47e4f6ce3e5c37a607355f80da8860a33226be640226ac52cb05ef2e/fsspec-2025.9.0-py3-none-any.whl", hash = "sha256:530dc2a2af60a414a832059574df4a6e10cce927f6f4a78209390fe38955cfb7", size = 199289, upload-time = "2025-09-02T19:10:47.708Z" }, +] + +[[package]] +name = "identify" +version = "2.6.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/e7/685de97986c916a6d93b3876139e00eef26ad5bbbd61925d670ae8013449/identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf", size = 99311, upload-time = "2025-10-02T17:43:40.631Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/1c/e5fd8f973d4f375adb21565739498e2e9a1e54c858a97b9a8ccfdc81da9b/identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757", size = 99183, upload-time = "2025-10-02T17:43:39.137Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + +[[package]] +name = "ipython" +version = "9.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "decorator" }, + { name = "ipython-pygments-lexers" }, + { name = "jedi" }, + { name = "matplotlib-inline" }, + { name = "pexpect", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit" }, + { name = "pygments" }, + { name = "stack-data" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/34/29b18c62e39ee2f7a6a3bba7efd952729d8aadd45ca17efc34453b717665/ipython-9.6.0.tar.gz", hash = "sha256:5603d6d5d356378be5043e69441a072b50a5b33b4503428c77b04cb8ce7bc731", size = 4396932, upload-time = "2025-09-29T10:55:53.948Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/c5/d5e07995077e48220269c28a221e168c91123ad5ceee44d548f54a057fc0/ipython-9.6.0-py3-none-any.whl", hash = "sha256:5f77efafc886d2f023442479b8149e7d86547ad0a979e9da9f045d252f648196", size = 616170, upload-time = "2025-09-29T10:55:47.676Z" }, +] + +[[package]] +name = "ipython-pygments-lexers" 
+version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393, upload-time = "2025-01-17T11:24:34.505Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, +] + +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "joblib" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077, upload-time = "2025-08-27T12:15:46.575Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/99/7690b6d4034fffd95959cbe0c02de8deb3098cc577c67bb6a24fe5d7caa7/markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698", size = 80313, upload-time = "2025-09-27T18:37:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/72/147da192e38635ada20e0a2e1a51cf8823d2119ce8883f7053879c2199b5/markupsafe-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e", size = 11615, upload-time = "2025-09-27T18:36:30.854Z" }, + { url = "https://files.pythonhosted.org/packages/9a/81/7e4e08678a1f98521201c3079f77db69fb552acd56067661f8c2f534a718/markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce", size = 12020, upload-time = "2025-09-27T18:36:31.971Z" }, + { url = "https://files.pythonhosted.org/packages/1e/2c/799f4742efc39633a1b54a92eec4082e4f815314869865d876824c257c1e/markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d", size = 24332, upload-time = "2025-09-27T18:36:32.813Z" }, + { url = "https://files.pythonhosted.org/packages/3c/2e/8d0c2ab90a8c1d9a24f0399058ab8519a3279d1bd4289511d74e909f060e/markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d", size = 22947, upload-time = "2025-09-27T18:36:33.86Z" }, + { url = "https://files.pythonhosted.org/packages/2c/54/887f3092a85238093a0b2154bd629c89444f395618842e8b0c41783898ea/markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a", size = 21962, upload-time = "2025-09-27T18:36:35.099Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2f/336b8c7b6f4a4d95e91119dc8521402461b74a485558d8f238a68312f11c/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b", size = 23760, upload-time = "2025-09-27T18:36:36.001Z" }, + { url = "https://files.pythonhosted.org/packages/32/43/67935f2b7e4982ffb50a4d169b724d74b62a3964bc1a9a527f5ac4f1ee2b/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f", size = 21529, upload-time = "2025-09-27T18:36:36.906Z" }, + { url = "https://files.pythonhosted.org/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b", size = 23015, upload-time = "2025-09-27T18:36:37.868Z" }, + { url = "https://files.pythonhosted.org/packages/2f/e1/78ee7a023dac597a5825441ebd17170785a9dab23de95d2c7508ade94e0e/markupsafe-3.0.3-cp312-cp312-win32.whl", hash = "sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d", size = 14540, upload-time = "2025-09-27T18:36:38.761Z" }, + { url = "https://files.pythonhosted.org/packages/aa/5b/bec5aa9bbbb2c946ca2733ef9c4ca91c91b6a24580193e891b5f7dbe8e1e/markupsafe-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c", size = 15105, upload-time = "2025-09-27T18:36:39.701Z" }, + { url = "https://files.pythonhosted.org/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906, upload-time = "2025-09-27T18:36:40.689Z" }, +] + +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159, upload-time = "2024-04-15T13:44:44.803Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, +] + +[[package]] +name = "nemo-retriever-ocr" +version = "1.2.0.dev2" +source = { editable = "." } +dependencies = [ + { name = "pandas" }, + { name = "pillow" }, + { name = "scikit-learn" }, + { name = "shapely" }, + { name = "torch" }, + { name = "torchvision" }, +] + +[package.dev-dependencies] +dev = [ + { name = "ipython" }, + { name = "pre-commit" }, + { name = "pytest" }, + { name = "pytest-cov" }, +] + +[package.metadata] +requires-dist = [ + { name = "pandas", specifier = ">=2.3.3" }, + { name = "pillow", specifier = ">=12.0.0" }, + { name = "scikit-learn", specifier = ">=1.7.2" }, + { name = "shapely", specifier = ">=2.1.2,<3" }, + { name = "torch", specifier = ">=2.9.0" }, + { name = "torchvision", specifier = ">=0.24.0" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "ipython", specifier = ">=9.6.0" }, + { name = "pre-commit", specifier = ">=3.8.0,<4" }, + { name = "pytest", specifier = ">=8.3.2,<9" }, + { name = "pytest-cov", specifier = ">=5.0.0,<6" }, +] + +[[package]] +name = "networkx" +version = "3.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" }, +] + +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437, upload-time = "2024-06-04T18:44:11.171Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = 
"sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, +] + +[[package]] +name = "numpy" +version = "2.3.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/f4/098d2270d52b41f1bd7db9fc288aaa0400cb48c2a3e2af6fa365d9720947/numpy-2.3.4.tar.gz", hash = "sha256:a7d018bfedb375a8d979ac758b120ba846a7fe764911a64465fd87b8729f4a6a", size = 20582187, upload-time = "2025-10-15T16:18:11.77Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/7a/02420400b736f84317e759291b8edaeee9dc921f72b045475a9cbdb26b17/numpy-2.3.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ef1b5a3e808bc40827b5fa2c8196151a4c5abe110e1726949d7abddfe5c7ae11", size = 20957727, upload-time = "2025-10-15T16:15:44.9Z" }, + { url = "https://files.pythonhosted.org/packages/18/90/a014805d627aa5750f6f0e878172afb6454552da929144b3c07fcae1bb13/numpy-2.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c2f91f496a87235c6aaf6d3f3d89b17dba64996abadccb289f48456cff931ca9", size = 14187262, upload-time = "2025-10-15T16:15:47.761Z" }, + { url = "https://files.pythonhosted.org/packages/c7/e4/0a94b09abe89e500dc748e7515f21a13e30c5c3fe3396e6d4ac108c25fca/numpy-2.3.4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f77e5b3d3da652b474cc80a14084927a5e86a5eccf54ca8ca5cbd697bf7f2667", size = 5115992, upload-time = "2025-10-15T16:15:50.144Z" }, + { url = "https://files.pythonhosted.org/packages/88/dd/db77c75b055c6157cbd4f9c92c4458daef0dd9cbe6d8d2fe7f803cb64c37/numpy-2.3.4-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:8ab1c5f5ee40d6e01cbe96de5863e39b215a4d24e7d007cad56c7184fdf4aeef", size = 6648672, upload-time = "2025-10-15T16:15:52.442Z" }, + { url = "https://files.pythonhosted.org/packages/e1/e6/e31b0d713719610e406c0ea3ae0d90760465b086da8783e2fd835ad59027/numpy-2.3.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:77b84453f3adcb994ddbd0d1c5d11db2d6bda1a2b7fd5ac5bd4649d6f5dc682e", size = 14284156, upload-time = "2025-10-15T16:15:54.351Z" }, + { url = "https://files.pythonhosted.org/packages/f9/58/30a85127bfee6f108282107caf8e06a1f0cc997cb6b52cdee699276fcce4/numpy-2.3.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4121c5beb58a7f9e6dfdee612cb24f4df5cd4db6e8261d7f4d7450a997a65d6a", size = 16641271, upload-time = "2025-10-15T16:15:56.67Z" }, + { url = "https://files.pythonhosted.org/packages/06/f2/2e06a0f2adf23e3ae29283ad96959267938d0efd20a2e25353b70065bfec/numpy-2.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:65611ecbb00ac9846efe04db15cbe6186f562f6bb7e5e05f077e53a599225d16", size = 16059531, upload-time = "2025-10-15T16:15:59.412Z" }, + { url = "https://files.pythonhosted.org/packages/b0/e7/b106253c7c0d5dc352b9c8fab91afd76a93950998167fa3e5afe4ef3a18f/numpy-2.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dabc42f9c6577bcc13001b8810d300fe814b4cfbe8a92c873f269484594f9786", size = 18578983, upload-time = "2025-10-15T16:16:01.804Z" }, + { url = "https://files.pythonhosted.org/packages/73/e3/04ecc41e71462276ee867ccbef26a4448638eadecf1bc56772c9ed6d0255/numpy-2.3.4-cp312-cp312-win32.whl", hash = "sha256:a49d797192a8d950ca59ee2d0337a4d804f713bb5c3c50e8db26d49666e351dc", size = 6291380, upload-time = "2025-10-15T16:16:03.938Z" }, + { url = "https://files.pythonhosted.org/packages/3d/a8/566578b10d8d0e9955b1b6cd5db4e9d4592dd0026a941ff7994cedda030a/numpy-2.3.4-cp312-cp312-win_amd64.whl", hash = 
"sha256:985f1e46358f06c2a09921e8921e2c98168ed4ae12ccd6e5e87a4f1857923f32", size = 12787999, upload-time = "2025-10-15T16:16:05.801Z" }, + { url = "https://files.pythonhosted.org/packages/58/22/9c903a957d0a8071b607f5b1bff0761d6e608b9a965945411f867d515db1/numpy-2.3.4-cp312-cp312-win_arm64.whl", hash = "sha256:4635239814149e06e2cb9db3dd584b2fa64316c96f10656983b8026a82e6e4db", size = 10197412, upload-time = "2025-10-15T16:16:07.854Z" }, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cusparse-cu12" }, + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, +] + +[[package]] +name = "nvidia-nvshmem-cu12" +version = "3.3.20" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d00f26d3f9b2e3c3065be895e3059d6479ea5c638a3f38c9fec49b1b9dd7c1e5", size = 124657145, upload-time = "2025-08-04T20:25:19.995Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "pandas" +version = "2.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/01/d40b85317f86cf08d853a4f495195c73815fdf205eef3993821720274518/pandas-2.3.3.tar.gz", hash = "sha256:e05e1af93b977f7eafa636d043f9f94c7ee3ac81af99c13508215942e64c993b", size = 4495223, upload-time = "2025-09-29T23:34:51.853Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/fb/231d89e8637c808b997d172b18e9d4a4bc7bf31296196c260526055d1ea0/pandas-2.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d21f6d74eb1725c2efaa71a2bfc661a0689579b58e9c0ca58a739ff0b002b53", size = 11597846, upload-time = "2025-09-29T23:19:48.856Z" }, + { url = "https://files.pythonhosted.org/packages/5c/bd/bf8064d9cfa214294356c2d6702b716d3cf3bb24be59287a6a21e24cae6b/pandas-2.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3fd2f887589c7aa868e02632612ba39acb0b8948faf5cc58f0850e165bd46f35", size = 10729618, upload-time = "2025-09-29T23:39:08.659Z" }, + { url = "https://files.pythonhosted.org/packages/57/56/cf2dbe1a3f5271370669475ead12ce77c61726ffd19a35546e31aa8edf4e/pandas-2.3.3-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecaf1e12bdc03c86ad4a7ea848d66c685cb6851d807a26aa245ca3d2017a1908", size = 11737212, upload-time = "2025-09-29T23:19:59.765Z" }, + { url = "https://files.pythonhosted.org/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b3d11d2fda7eb164ef27ffc14b4fcab16a80e1ce67e9f57e19ec0afaf715ba89", size = 12362693, upload-time = "2025-09-29T23:20:14.098Z" }, + { url = "https://files.pythonhosted.org/packages/a6/de/8b1895b107277d52f2b42d3a6806e69cfef0d5cf1d0ba343470b9d8e0a04/pandas-2.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:a68e15f780eddf2b07d242e17a04aa187a7ee12b40b930bfdd78070556550e98", size = 12771002, upload-time = "2025-09-29T23:20:26.76Z" }, + { url = "https://files.pythonhosted.org/packages/87/21/84072af3187a677c5893b170ba2c8fbe450a6ff911234916da889b698220/pandas-2.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:371a4ab48e950033bcf52b6527eccb564f52dc826c02afd9a1bc0ab731bba084", size = 13450971, upload-time = "2025-09-29T23:20:41.344Z" }, + { url = "https://files.pythonhosted.org/packages/86/41/585a168330ff063014880a80d744219dbf1dd7a1c706e75ab3425a987384/pandas-2.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:a16dcec078a01eeef8ee61bf64074b4e524a2a3f4b3be9326420cabe59c4778b", size = 10992722, upload-time = "2025-09-29T23:20:54.139Z" }, +] + +[[package]] +name = "parso" +version = "0.8.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/de/53e0bcf53d13e005bd8c92e7855142494f41171b34c2536b86187474184d/parso-0.8.5.tar.gz", hash = "sha256:034d7354a9a018bdce352f48b2a8a450f05e9d6ee85db84764e9b6bd96dafe5a", size = 401205, upload-time = "2025-08-23T15:15:28.028Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl", hash = "sha256:646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887", size = 106668, upload-time = "2025-08-23T15:15:25.663Z" }, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772, upload-time = "2023-11-25T06:56:14.81Z" }, +] + +[[package]] +name = "pillow" +version = "12.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/b0/cace85a1b0c9775a9f8f5d5423c8261c858760e2466c79b2dd184638b056/pillow-12.0.0.tar.gz", hash = "sha256:87d4f8125c9988bfbed67af47dd7a953e2fc7b0cc1e7800ec6d2080d490bb353", size = 47008828, upload-time = "2025-10-15T18:24:14.008Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/90/4fcce2c22caf044e660a198d740e7fbc14395619e3cb1abad12192c0826c/pillow-12.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:53561a4ddc36facb432fae7a9d8afbfaf94795414f5cdc5fc52f28c1dca90371", size = 5249377, upload-time = "2025-10-15T18:22:05.993Z" }, + { url = "https://files.pythonhosted.org/packages/fd/e0/ed960067543d080691d47d6938ebccbf3976a931c9567ab2fbfab983a5dd/pillow-12.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:71db6b4c1653045dacc1585c1b0d184004f0d7e694c7b34ac165ca70c0838082", size = 4650343, upload-time = "2025-10-15T18:22:07.718Z" }, + { url = "https://files.pythonhosted.org/packages/e7/a1/f81fdeddcb99c044bf7d6faa47e12850f13cee0849537a7d27eeab5534d4/pillow-12.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2fa5f0b6716fc88f11380b88b31fe591a06c6315e955c096c35715788b339e3f", size = 6232981, upload-time = 
"2025-10-15T18:22:09.287Z" }, + { url = "https://files.pythonhosted.org/packages/88/e1/9098d3ce341a8750b55b0e00c03f1630d6178f38ac191c81c97a3b047b44/pillow-12.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82240051c6ca513c616f7f9da06e871f61bfd7805f566275841af15015b8f98d", size = 8041399, upload-time = "2025-10-15T18:22:10.872Z" }, + { url = "https://files.pythonhosted.org/packages/a7/62/a22e8d3b602ae8cc01446d0c57a54e982737f44b6f2e1e019a925143771d/pillow-12.0.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:55f818bd74fe2f11d4d7cbc65880a843c4075e0ac7226bc1a23261dbea531953", size = 6347740, upload-time = "2025-10-15T18:22:12.769Z" }, + { url = "https://files.pythonhosted.org/packages/4f/87/424511bdcd02c8d7acf9f65caa09f291a519b16bd83c3fb3374b3d4ae951/pillow-12.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b87843e225e74576437fd5b6a4c2205d422754f84a06942cfaf1dc32243e45a8", size = 7040201, upload-time = "2025-10-15T18:22:14.813Z" }, + { url = "https://files.pythonhosted.org/packages/dc/4d/435c8ac688c54d11755aedfdd9f29c9eeddf68d150fe42d1d3dbd2365149/pillow-12.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c607c90ba67533e1b2355b821fef6764d1dd2cbe26b8c1005ae84f7aea25ff79", size = 6462334, upload-time = "2025-10-15T18:22:16.375Z" }, + { url = "https://files.pythonhosted.org/packages/2b/f2/ad34167a8059a59b8ad10bc5c72d4d9b35acc6b7c0877af8ac885b5f2044/pillow-12.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:21f241bdd5080a15bc86d3466a9f6074a9c2c2b314100dd896ac81ee6db2f1ba", size = 7134162, upload-time = "2025-10-15T18:22:17.996Z" }, + { url = "https://files.pythonhosted.org/packages/0c/b1/a7391df6adacf0a5c2cf6ac1cf1fcc1369e7d439d28f637a847f8803beb3/pillow-12.0.0-cp312-cp312-win32.whl", hash = "sha256:dd333073e0cacdc3089525c7df7d39b211bcdf31fc2824e49d01c6b6187b07d0", size = 6298769, upload-time = "2025-10-15T18:22:19.923Z" }, + { url = "https://files.pythonhosted.org/packages/a2/0b/d87733741526541c909bbf159e338dcace4f982daac6e5a8d6be225ca32d/pillow-12.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9fe611163f6303d1619bbcb653540a4d60f9e55e622d60a3108be0d5b441017a", size = 7001107, upload-time = "2025-10-15T18:22:21.644Z" }, + { url = "https://files.pythonhosted.org/packages/bc/96/aaa61ce33cc98421fb6088af2a03be4157b1e7e0e87087c888e2370a7f45/pillow-12.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad", size = 2436012, upload-time = "2025-10-15T18:22:23.621Z" }, +] + +[[package]] +name = "platformdirs" +version = "4.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/61/33/9611380c2bdb1225fdef633e2a9610622310fed35ab11dac9620972ee088/platformdirs-4.5.0.tar.gz", hash = "sha256:70ddccdd7c99fc5942e9fc25636a8b34d04c24b335100223152c2803e4063312", size = 21632, upload-time = "2025-10-08T17:44:48.791Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = 
"sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pre-commit" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/10/97ee2fa54dff1e9da9badbc5e35d0bbaef0776271ea5907eccf64140f72f/pre_commit-3.8.0.tar.gz", hash = "sha256:8bb6494d4a20423842e198980c9ecf9f96607a07ea29549e180eef9ae80fe7af", size = 177815, upload-time = "2024-07-28T19:59:01.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/92/caae8c86e94681b42c246f0bca35c059a2f0529e5b92619f6aba4cf7e7b6/pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f", size = 204643, upload-time = "2024-07-28T19:58:59.335Z" }, +] + +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, +] + +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762, upload-time = "2020-12-28T15:15:30.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993, upload-time = "2020-12-28T15:15:28.35Z" }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752, upload-time = "2024-07-21T12:58:21.801Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pytest" +version = "8.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, +] + +[[package]] +name = "pytest-cov" +version = "5.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/67/00efc8d11b630c56f15f4ad9c7f9223f1e5ec275aaae3fa9118c6a223ad2/pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857", size = 63042, upload-time = "2024-03-24T20:16:34.856Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/3a/af5b4fa5961d9a1e6237b530eb87dd04aea6eb83da09d2a4073d81b54ccf/pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652", size = 21990, upload-time = "2024-03-24T20:16:32.444Z" }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, +] + +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, +] + +[[package]] +name = "pyyaml" +version = "6.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, + { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" }, + { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, + { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, + { url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, 
upload-time = "2025-09-25T21:32:21.167Z" }, + { url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, +] + +[[package]] +name = "scikit-learn" +version = "1.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/c2/a7855e41c9d285dfe86dc50b250978105dce513d6e459ea66a6aeb0e1e0c/scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda", size = 7193136, upload-time = "2025-09-09T08:21:29.075Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/aa/3996e2196075689afb9fce0410ebdb4a09099d7964d061d7213700204409/scikit_learn-1.7.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8d91a97fa2b706943822398ab943cde71858a50245e31bc71dba62aab1d60a96", size = 9259818, upload-time = "2025-09-09T08:20:43.19Z" }, + { url = "https://files.pythonhosted.org/packages/43/5d/779320063e88af9c4a7c2cf463ff11c21ac9c8bd730c4a294b0000b666c9/scikit_learn-1.7.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:acbc0f5fd2edd3432a22c69bed78e837c70cf896cd7993d71d51ba6708507476", size = 8636997, upload-time = "2025-09-09T08:20:45.468Z" }, + { url = "https://files.pythonhosted.org/packages/5c/d0/0c577d9325b05594fdd33aa970bf53fb673f051a45496842caee13cfd7fe/scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e5bf3d930aee75a65478df91ac1225ff89cd28e9ac7bd1196853a9229b6adb0b", size = 9478381, upload-time = "2025-09-09T08:20:47.982Z" }, + { url = "https://files.pythonhosted.org/packages/82/70/8bf44b933837ba8494ca0fc9a9ab60f1c13b062ad0197f60a56e2fc4c43e/scikit_learn-1.7.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4d6e9deed1a47aca9fe2f267ab8e8fe82ee20b4526b2c0cd9e135cea10feb44", size = 9300296, upload-time = "2025-09-09T08:20:50.366Z" }, + { url = "https://files.pythonhosted.org/packages/c6/99/ed35197a158f1fdc2fe7c3680e9c70d0128f662e1fee4ed495f4b5e13db0/scikit_learn-1.7.2-cp312-cp312-win_amd64.whl", hash = "sha256:6088aa475f0785e01bcf8529f55280a3d7d298679f50c0bb70a2364a82d0b290", size = 8731256, upload-time = "2025-09-09T08:20:52.627Z" }, +] + +[[package]] +name = "scipy" +version = "1.16.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/3b/546a6f0bfe791bbb7f8d591613454d15097e53f906308ec6f7c1ce588e8e/scipy-1.16.2.tar.gz", hash = "sha256:af029b153d243a80afb6eabe40b0a07f8e35c9adc269c019f364ad747f826a6b", size = 30580599, upload-time = "2025-09-11T17:48:08.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/8d/6396e00db1282279a4ddd507c5f5e11f606812b608ee58517ce8abbf883f/scipy-1.16.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:89d6c100fa5c48472047632e06f0876b3c4931aac1f4291afc81a3644316bb0d", size = 36646259, upload-time = "2025-09-11T17:40:39.329Z" }, + { url = "https://files.pythonhosted.org/packages/3b/93/ea9edd7e193fceb8eef149804491890bde73fb169c896b61aa3e2d1e4e77/scipy-1.16.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ca748936cd579d3f01928b30a17dc474550b01272d8046e3e1ee593f23620371", size = 28888976, upload-time = 
"2025-09-11T17:40:46.82Z" }, + { url = "https://files.pythonhosted.org/packages/91/4d/281fddc3d80fd738ba86fd3aed9202331180b01e2c78eaae0642f22f7e83/scipy-1.16.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:fac4f8ce2ddb40e2e3d0f7ec36d2a1e7f92559a2471e59aec37bd8d9de01fec0", size = 20879905, upload-time = "2025-09-11T17:40:52.545Z" }, + { url = "https://files.pythonhosted.org/packages/69/40/b33b74c84606fd301b2915f0062e45733c6ff5708d121dd0deaa8871e2d0/scipy-1.16.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:033570f1dcefd79547a88e18bccacff025c8c647a330381064f561d43b821232", size = 23553066, upload-time = "2025-09-11T17:40:59.014Z" }, + { url = "https://files.pythonhosted.org/packages/55/a7/22c739e2f21a42cc8f16bc76b47cff4ed54fbe0962832c589591c2abec34/scipy-1.16.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ea3421209bf00c8a5ef2227de496601087d8f638a2363ee09af059bd70976dc1", size = 33336407, upload-time = "2025-09-11T17:41:06.796Z" }, + { url = "https://files.pythonhosted.org/packages/53/11/a0160990b82999b45874dc60c0c183d3a3a969a563fffc476d5a9995c407/scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f66bd07ba6f84cd4a380b41d1bf3c59ea488b590a2ff96744845163309ee8e2f", size = 35673281, upload-time = "2025-09-11T17:41:15.055Z" }, + { url = "https://files.pythonhosted.org/packages/96/53/7ef48a4cfcf243c3d0f1643f5887c81f29fdf76911c4e49331828e19fc0a/scipy-1.16.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5e9feab931bd2aea4a23388c962df6468af3d808ddf2d40f94a81c5dc38f32ef", size = 36004222, upload-time = "2025-09-11T17:41:23.868Z" }, + { url = "https://files.pythonhosted.org/packages/49/7f/71a69e0afd460049d41c65c630c919c537815277dfea214031005f474d78/scipy-1.16.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:03dfc75e52f72cf23ec2ced468645321407faad8f0fe7b1f5b49264adbc29cb1", size = 38664586, upload-time = "2025-09-11T17:41:31.021Z" }, + { url = "https://files.pythonhosted.org/packages/34/95/20e02ca66fb495a95fba0642fd48e0c390d0ece9b9b14c6e931a60a12dea/scipy-1.16.2-cp312-cp312-win_amd64.whl", hash = "sha256:0ce54e07bbb394b417457409a64fd015be623f36e330ac49306433ffe04bc97e", size = 38550641, upload-time = "2025-09-11T17:41:36.61Z" }, + { url = "https://files.pythonhosted.org/packages/92/ad/13646b9beb0a95528ca46d52b7babafbe115017814a611f2065ee4e61d20/scipy-1.16.2-cp312-cp312-win_arm64.whl", hash = "sha256:2a8ffaa4ac0df81a0b94577b18ee079f13fecdb924df3328fc44a7dc5ac46851", size = 25456070, upload-time = "2025-09-11T17:41:41.3Z" }, +] + +[[package]] +name = "setuptools" +version = "80.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, +] + +[[package]] +name = "shapely" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/4d/bc/0989043118a27cccb4e906a46b7565ce36ca7b57f5a18b78f4f1b0f72d9d/shapely-2.1.2.tar.gz", hash = "sha256:2ed4ecb28320a433db18a5bf029986aa8afcfd740745e78847e330d5d94922a9", size = 315489, upload-time = "2025-09-24T13:51:41.432Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/c0/f3b6453cf2dfa99adc0ba6675f9aaff9e526d2224cbd7ff9c1a879238693/shapely-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fe2533caae6a91a543dec62e8360fe86ffcdc42a7c55f9dfd0128a977a896b94", size = 1833550, upload-time = "2025-09-24T13:50:30.019Z" }, + { url = "https://files.pythonhosted.org/packages/86/07/59dee0bc4b913b7ab59ab1086225baca5b8f19865e6101db9ebb7243e132/shapely-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ba4d1333cc0bc94381d6d4308d2e4e008e0bd128bdcff5573199742ee3634359", size = 1643556, upload-time = "2025-09-24T13:50:32.291Z" }, + { url = "https://files.pythonhosted.org/packages/26/29/a5397e75b435b9895cd53e165083faed5d12fd9626eadec15a83a2411f0f/shapely-2.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0bd308103340030feef6c111d3eb98d50dc13feea33affc8a6f9fa549e9458a3", size = 2988308, upload-time = "2025-09-24T13:50:33.862Z" }, + { url = "https://files.pythonhosted.org/packages/b9/37/e781683abac55dde9771e086b790e554811a71ed0b2b8a1e789b7430dd44/shapely-2.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1e7d4d7ad262a48bb44277ca12c7c78cb1b0f56b32c10734ec9a1d30c0b0c54b", size = 3099844, upload-time = "2025-09-24T13:50:35.459Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f3/9876b64d4a5a321b9dc482c92bb6f061f2fa42131cba643c699f39317cb9/shapely-2.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e9eddfe513096a71896441a7c37db72da0687b34752c4e193577a145c71736fc", size = 3988842, upload-time = "2025-09-24T13:50:37.478Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a0/704c7292f7014c7e74ec84eddb7b109e1fbae74a16deae9c1504b1d15565/shapely-2.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:980c777c612514c0cf99bc8a9de6d286f5e186dcaf9091252fcd444e5638193d", size = 4152714, upload-time = "2025-09-24T13:50:39.9Z" }, + { url = "https://files.pythonhosted.org/packages/53/46/319c9dc788884ad0785242543cdffac0e6530e4d0deb6c4862bc4143dcf3/shapely-2.1.2-cp312-cp312-win32.whl", hash = "sha256:9111274b88e4d7b54a95218e243282709b330ef52b7b86bc6aaf4f805306f454", size = 1542745, upload-time = "2025-09-24T13:50:41.414Z" }, + { url = "https://files.pythonhosted.org/packages/ec/bf/cb6c1c505cb31e818e900b9312d514f381fbfa5c4363edfce0fcc4f8c1a4/shapely-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:743044b4cfb34f9a67205cee9279feaf60ba7d02e69febc2afc609047cb49179", size = 1722861, upload-time = "2025-09-24T13:50:43.35Z" }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, +] + +[[package]] +name = "stack-data" +version = "0.6.3" 
+source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707, upload-time = "2023-09-30T13:58:05.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, +] + +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + +[[package]] +name = "threadpoolctl" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, +] + +[[package]] +name = "torch" +version = "2.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and 
sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setuptools" }, + { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d3/3985739f3b8e88675127bf70f82b3a48ae083e39cda56305dbd90398fec0/torch-2.9.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e5f7af1dc4c0a7c4a260c2534f41ddaf209714f7c89145e644c44712fbd6b642", size = 104107898, upload-time = "2025-10-15T15:46:20.883Z" }, + { url = "https://files.pythonhosted.org/packages/a5/4b/f4bb2e6c25d0272f798cd6d7a04ed315da76cec68c602d87040c7847287f/torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:01cff95ecd9a212ea2f141db28acccdceb6a4c54f64e6c51091146f5e2a772c6", size = 899738273, upload-time = "2025-10-15T15:50:04.188Z" }, + { url = "https://files.pythonhosted.org/packages/66/11/c1c5ba6691cda6279087c35bd626536e4fd29521fe740abf5008377a9a02/torch-2.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:4582b162f541651f0cb184d3e291c05c2f556c7117c64a9873e2ee158d40062b", size = 109280887, upload-time = "2025-10-15T15:46:26.228Z" }, + { url = "https://files.pythonhosted.org/packages/dd/5f/b85bd8c05312d71de9402bf5868d217c38827cfd09d8f8514e5be128a52b/torch-2.9.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:33f58e9a102a91259af289d50525c30323b5c9ae1d31322b6447c0814da68695", size = 74478983, upload-time = "2025-10-15T15:46:39.406Z" }, +] + +[[package]] +name = "torchvision" +version = "0.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pillow" }, + { name = "torch" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/ef/81e4e69e02e2c4650b30e8c11c8974f946682a30e0ab7e9803a831beff76/torchvision-0.24.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c61d40bcd2e2451e932902a702ad495ba1ec6f279e90b1e15cef2bb55dc911e2", size = 1891726, upload-time = "2025-10-15T15:51:16.977Z" }, + { url = "https://files.pythonhosted.org/packages/00/7b/e3809b3302caea9a12c13f3adebe4fef127188438e719fd6c8dc93db1da6/torchvision-0.24.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b0531d1483fc322d7da0d83be52f0df860a75114ab87dbeeb9de765feaeda843", size = 2419495, upload-time = "2025-10-15T15:51:11.885Z" }, + { url = "https://files.pythonhosted.org/packages/7e/e6/7324ead6793075a8c75c56abeed1236d1750de16a5613cfe2ddad164a92a/torchvision-0.24.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:26b9dd9c083f8e5f7ac827de6d5b88c615d9c582dc87666770fbdf16887e4c25", size = 8050480, upload-time = "2025-10-15T15:51:24.012Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ad/3c56fcd2a0d6e8afa80e115b5ade4302232ec99655220a51d05709819523/torchvision-0.24.0-cp312-cp312-win_amd64.whl", hash = "sha256:060b7c50ed4b3fb0316b08e2e31bfd874ec2f63ef5ae02f81e54341ca4e88703", size = 4292225, upload-time = "2025-10-15T15:51:27.699Z" }, +] + +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621, upload-time = "2024-04-19T11:11:49.746Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, +] + +[[package]] +name = "triton" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/3a/e991574f3102147b642e49637e0281e9bb7c4ba254edb2bab78247c85e01/triton-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9e71db82261c4ffa3921cd050cd5faa18322d2d405c30eb56084afaff3b0833", size = 170476535, upload-time = "2025-10-13T16:38:05.18Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, +] + +[[package]] +name = "virtualenv" +version = "20.35.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a4/d5/b0ccd381d55c8f45d46f77df6ae59fbc23d19e901e2d523395598e5f4c93/virtualenv-20.35.3.tar.gz", hash = "sha256:4f1a845d131133bdff10590489610c98c168ff99dc75d6c96853801f7f67af44", size = 6002907, upload-time = "2025-10-10T21:23:33.178Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/73/d9a94da0e9d470a543c1b9d3ccbceb0f59455983088e727b8a1824ed90fb/virtualenv-20.35.3-py3-none-any.whl", hash = "sha256:63d106565078d8c8d0b206d48080f938a8b25361e19432d2c9db40d2899c810a", size = 5981061, upload-time = "2025-10-10T21:23:30.433Z" }, +] + +[[package]] +name = "wcwidth" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = 
"sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, +]