#include "utils.h"

#include <algorithm>
#include <codecvt>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <locale>
#include <map>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

// Replace every occurrence of `needle` in `str` with `replacement`, scanning
// left to right. Text inserted by a replacement is not searched again.
// An empty `needle` is a no-op: it would match at every position and the loop
// below would never terminate.
void utreplace(std::string & str, const std::string & needle, const std::string & replacement) {
    if (needle.empty()) {
        return;
    }
    size_t pos = 0;
    while ((pos = str.find(needle, pos)) != std::string::npos) {
        str.replace(pos, needle.length(), replacement);
        pos += replacement.length();
    }
}

// Minimal parser for a flat JSON object of the form {"key": value, ...}.
//
// - Only entries whose value is accepted by std::stoi are stored. Other
//   entries (non-numeric strings, out-of-range numbers) are silently skipped.
// - Inside keys, the escapes \u0120, \u010a and \" are decoded to ' ', '\n'
//   and '"'. All other escape sequences are kept verbatim (backslash included).
// - Nested objects/arrays are not supported.
//
// If the file cannot be opened, an error is printed and the process exits
// with status 1. An empty map is returned if the content does not start
// with '{'.
std::map<std::string, int32_t> json_parse(const std::string & fname) {
    std::map<std::string, int32_t> result;

    // read file into string
    std::string json;
    {
        std::ifstream ifs(fname);
        if (!ifs) {
            fprintf(stderr, "Failed to open %s\n", fname.c_str());
            exit(1);
        }

        json = std::string((std::istreambuf_iterator<char>(ifs)),
                           (std::istreambuf_iterator<char>()));
    }

    if (json.empty() || json[0] != '{') {
        return result;
    }

    // parse json
    {
        bool has_key  = false; // true once the closing quote of the current key was seen
        bool in_token = false; // true while inside a quoted string

        std::string str_key;
        std::string str_val;

        const size_t n = json.size();
        for (size_t i = 1; i < n; ++i) {
            if (!in_token) {
                // Outside a string only an opening quote matters; whitespace,
                // ',', '}' and newlines are skipped.
                if (json[i] == '"') {
                    in_token = true;
                }
                continue;
            }

            if (json[i] == '\\' && i + 1 < n) {
                // Keep the escape sequence verbatim: the backslash is appended
                // here and the escaped character by the append at the bottom.
                (has_key ? str_val : str_key) += json[i];
                ++i;
            } else if (json[i] == '"') {
                if (!has_key) {
                    // End of key: skip spaces, the ':' separator, and more spaces.
                    has_key = true;
                    ++i;
                    while (i < n && json[i] == ' ') ++i;
                    ++i; // ':'
                    while (i < n && json[i] == ' ') ++i;
                    if (i >= n) {
                        break; // truncated input: key without a value
                    }
                    if (json[i] != '"') {
                        // Unquoted value (e.g. a number): read up to ',' or '}'.
                        while (i < n && json[i] != ',' && json[i] != '}') {
                            str_val += json[i++];
                        }
                        has_key = false;
                    } else {
                        // Quoted value: keep collecting characters into str_val.
                        in_token = true;
                        continue;
                    }
                } else {
                    // Closing quote of a quoted value.
                    has_key = false;
                }

                ::utreplace(str_key, "\\u0120", " " ); // \u0120 -> space
                ::utreplace(str_key, "\\u010a", "\n"); // \u010a -> new line
                ::utreplace(str_key, "\\\"",    "\""); // \"     -> "

                try {
                    result[str_key] = std::stoi(str_val);
                } catch (const std::invalid_argument &) {
                    // value is not an integer: entry intentionally ignored
                } catch (const std::out_of_range &) {
                    // value does not fit in an int: entry intentionally ignored
                }

                str_key.clear();
                str_val.clear();
                in_token = false;
                continue;
            }

            (has_key ? str_val : str_key) += json[i];
        }
    }

    return result;
}

// Register `token` so that gpt_tokenize matches it literally before applying
// the default word-splitting pattern.
void gpt_vocab::add_special_token(const std::string & token) {
    special_tokens.push_back(token);
}

// Encode a wide string as UTF-8.
// NOTE: std::wstring_convert/std::codecvt_utf8 are deprecated since C++17 but
// have no standard replacement yet. On platforms with a 16-bit wchar_t this
// handles UCS-2 only.
std::string convert_to_utf8(const std::wstring & input) {
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
    return converter.to_bytes(input);
}

// Decode a UTF-8 string into a wide string.
// Throws std::range_error on invalid UTF-8 (default std::wstring_convert
// behaviour when no error string is supplied).
std::wstring convert_to_wstring(const std::string & input) {
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
    return converter.from_bytes(input);
}

// Tokenize `text` against `vocab`.
//
// 1. Split the text into words using a GPT-2 style pre-tokenization regex.
//    Entries of vocab.special_tokens are tried first and matched literally:
//    regex metacharacters such as the '|' in "<|endoftext|>" are escaped.
// 2. Convert each word greedily, always taking the longest vocabulary entry
//    that matches at the current position. A byte that starts no vocabulary
//    entry is reported on stderr and skipped.
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
    std::vector<std::string> words;

    // first split the text into words
    {
        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";

        // Prepend the special tokens as literal alternatives so they take
        // precedence over the default word pattern.
        if (!vocab.special_tokens.empty()) {
            // Any ECMAScript regex metacharacter; each is rewritten as "\<char>".
            static const std::regex re_meta(R"([\[\]\\\^\$\.\|\?\*\+\(\)\{\}])");

            std::string special_tokens_subpattern;
            for (const auto & token : vocab.special_tokens) {
                if (token.empty()) {
                    // An empty alternative matches the empty string and would
                    // make the word-splitting loop spin forever.
                    continue;
                }
                if (!special_tokens_subpattern.empty()) {
                    special_tokens_subpattern += "|";
                }
                special_tokens_subpattern += std::regex_replace(token, re_meta, R"(\$&)");
            }

            if (!special_tokens_subpattern.empty()) {
                pat = special_tokens_subpattern + "|" + pat;
            }
        }

        const std::regex re(pat);
        // Walk the matches in place rather than repeatedly copying the
        // remaining suffix of the text. The pattern has no capture groups, so
        // each match contributes exactly one word.
        for (auto it = std::sregex_iterator(text.begin(), text.end(), re); it != std::sregex_iterator(); ++it) {
            words.push_back(it->str());
        }
    }

    // find the longest token that forms each word in words:
    std::vector<gpt_vocab::id> tokens;
    for (const auto & word : words) {
        size_t i = 0;
        while (i < word.size()) {
            size_t len = word.size() - i;
            for (; len > 0; --len) {
                auto it = vocab.token_to_id.find(word.substr(i, len));
                if (it != vocab.token_to_id.end()) {
                    tokens.push_back(it->second);
                    break;
                }
            }
            if (len == 0) {
                // word.substr(i, 1) has no matching vocabulary entry
                fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).c_str());
                len = 1;
            }
            i += len;
        }
    }

    return tokens;
}

// Return true if the tensor `name` contains one of the listed layer-name
// fragments (the "."-separated and "/"-separated naming schemes are both
// checked). What "transpose" means is up to the caller.
// The parameter stays by-value to match the existing declaration.
bool should_transpose_layer(std::string name) {
    static const char * const transposed_fragments[] = {
        ".mlp.fc_in.weight",
        ".attn.out_proj.weight",
        ".attn.q_proj.weight",
        ".attn.k_proj.weight",
        ".attn.v_proj.weight",
        "/attn/c_attn/w",
        "/attn/c_proj/w",
        "/mlp/c_fc/w",
        "/mlp/c_proj/w",
    };
    return std::any_of(std::begin(transposed_fragments), std::end(transposed_fragments),
                       [&name](const char * fragment) { return name.find(fragment) != std::string::npos; });
}