| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_ |
| #define THIRD_PARTY_GEMMA_CPP_UTIL_ARGS_H_ |
|
|
#include <stdio.h>

#include <algorithm>
#include <cctype>
#include <cerrno>
#include <cstdlib>
#include <string>
#include <type_traits>

#include "compression/io.h"
#include "hwy/base.h"
|
|
| namespace gcpp { |
|
|
| |
| |
| |
| |
| template <class Args> |
| class ArgsBase { |
| struct InitVisitor { |
| template <typename T> |
| void operator()(T& t, const char* , const T& init, |
| const char* , int = 0) const { |
| t = init; |
| } |
| }; |
|
|
| struct HelpVisitor { |
| template <typename T> |
| void operator()(T&, const char* name, T , const char* help, |
| int = 0) const { |
| fprintf(stderr, " --%s : %s\n", name, help); |
| } |
| }; |
|
|
| class PrintVisitor { |
| public: |
| explicit PrintVisitor(int verbosity) : verbosity_(verbosity) {} |
|
|
| template <typename T> |
| void operator()(const T& t, const char* name, const T& , |
| const char* , int print_verbosity = 0) const { |
| if (verbosity_ >= print_verbosity) { |
| fprintf(stderr, "%-30s: %s\n", name, std::to_string(t).c_str()); |
| } |
| } |
|
|
| void operator()(const std::string& t, const char* name, |
| const std::string& , const char* , |
| int print_verbosity = 0) const { |
| if (verbosity_ >= print_verbosity) { |
| fprintf(stderr, "%-30s: %s\n", name, t.c_str()); |
| } |
| } |
| void operator()(const Path& t, const char* name, const Path& , |
| const char* , int print_verbosity = 0) const { |
| if (verbosity_ >= print_verbosity) { |
| fprintf(stderr, "%-30s: %s\n", name, t.Shortened().c_str()); |
| } |
| } |
|
|
| private: |
| int verbosity_; |
| }; |
|
|
| |
| |
| |
| class ParseVisitor { |
| public: |
| ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {} |
|
|
| template <typename T> |
| void operator()(T& t, const char* name, const T& , |
| const char* , int = 0) const { |
| const std::string prefixed = std::string("--") + name; |
| for (int i = 1; i < argc_; ++i) { |
| if (std::string(argv_[i]) == prefixed) { |
| if (i + 1 >= argc_) { |
| HWY_ABORT("Missing value for %s\n", name); |
| } |
| if (!SetValue(argv_[i + 1], t)) { |
| HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]); |
| } |
| return; |
| } |
| } |
| } |
|
|
| private: |
| |
| template <typename T, HWY_IF_NOT_FLOAT(T)> |
| static bool SetValue(const char* string, T& t) { |
| t = std::stoi(string); |
| return true; |
| } |
|
|
| template <typename T, HWY_IF_FLOAT(T)> |
| static bool SetValue(const char* string, T& t) { |
| t = std::stof(string); |
| return true; |
| } |
|
|
| static bool SetValue(const char* string, std::string& t) { |
| t = string; |
| return true; |
| } |
| static bool SetValue(const char* string, Path& t) { |
| t.path = string; |
| return true; |
| } |
|
|
| static bool SetValue(const char* string, bool& t) { |
| std::string value(string); |
| |
| std::transform(value.begin(), value.end(), value.begin(), [](char c) { |
| return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c; |
| }); |
|
|
| if (value == "true" || value == "on" || value == "1") { |
| t = true; |
| return true; |
| } else if (value == "false" || value == "off" || value == "0") { |
| t = false; |
| return true; |
| } else { |
| return false; |
| } |
| } |
|
|
| int argc_; |
| char** argv_; |
| }; |
|
|
| template <class Visitor> |
| void ForEach(Visitor& visitor) { |
| static_cast<Args*>(this)->ForEach(visitor); |
| } |
|
|
| public: |
| |
| void Init() { |
| InitVisitor visitor; |
| ForEach(visitor); |
| } |
|
|
| void Help() { |
| HelpVisitor visitor; |
| ForEach(visitor); |
| } |
|
|
| void Print(int verbosity = 0) { |
| PrintVisitor visitor(verbosity); |
| ForEach(visitor); |
| } |
|
|
| void Parse(int argc, char* argv[]) { |
| ParseVisitor visitor(argc, argv); |
| ForEach(visitor); |
| } |
|
|
| |
| void InitAndParse(int argc, char* argv[]) { |
| Init(); |
| Parse(argc, argv); |
| } |
| }; |
|
|
| static inline HWY_MAYBE_UNUSED bool HasHelp(int argc, char* argv[]) { |
| |
| if (argc == 1) { |
| |
| return true; |
| } |
| for (int i = 1; i < argc; ++i) { |
| if (std::string(argv[i]) == "--help") { |
| return true; |
| } |
| } |
| return false; |
| } |
|
|
| template <class TArgs> |
| static inline HWY_MAYBE_UNUSED void AbortIfInvalidArgs(TArgs& args) { |
| if (const char* err = args.Validate()) { |
| args.Help(); |
| HWY_ABORT("Problem with args: %s\n", err); |
| } |
| } |
|
|
| } |
|
|
| #endif |
|
|