| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include <fstream> |
| #include <iostream> |
| #include <string> |
| #include <vector> |
|
|
| #include "compression/io.h" |
| #include "gemma/benchmark_helper.h" |
| #include "gemma/gemma.h" |
| #include "util/args.h" |
| #include "hwy/base.h" |
| #include "nlohmann/json.hpp" |
|
|
| using json = nlohmann::json; |
|
|
| namespace gcpp { |
|
|
| class PromptArgs : public ArgsBase<PromptArgs> { |
| public: |
| PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } |
|
|
| Path layers_output; |
| std::string prompt; |
|
|
| |
| const char* Validate() const { |
| if (prompt.empty()) return "Must specify --prompt"; |
| return nullptr; |
| } |
|
|
| template <class Visitor> |
| void ForEach(const Visitor& visitor) { |
| visitor(layers_output, "layers_output", Path(""), |
| "Path to store layers output", 2); |
| visitor(prompt, "prompt", std::string(""), "Prompt to the model", 2); |
| } |
| }; |
|
|
| int Run(int argc, char** argv) { |
| PromptArgs prompt_args(argc, argv); |
| AbortIfInvalidArgs(prompt_args); |
|
|
| json json_output; |
| GemmaEnv env(argc, argv); |
| env.MutableConfig().layers_output = |
| prompt_args.layers_output.Empty() |
| ? LayersOutputFunc() |
| : [&json_output](int pos, const std::string& key, const float* values, |
| size_t values_len) { |
| std::vector<float> v{values, values + values_len}; |
| json_output[std::to_string(pos)][key] = v; |
| }; |
|
|
| const auto [answer, token_count] = env.QueryModel(prompt_args.prompt); |
| std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush; |
|
|
| if (env.MutableConfig().layers_output) { |
| std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out); |
| if (!output_f) HWY_ABORT("Opening layer output file failed"); |
| output_f << json_output.dump(); |
| if (!output_f) HWY_ABORT("Writing to layer output file failed"); |
| output_f.close(); |
| } |
| return 0; |
| } |
|
|
| } |
|
|
| int main(int argc, char** argv) { return gcpp::Run(argc, argv); } |
|
|