// DriveBert / bert / cpp_postprocess / bert_postprocess.cpp
// Uploaded by jva96160 ("Upload 64 files", commit 67068c5).
#include <algorithm>
#include <cctype>
#include <cmath>
#include <exception>
#include <fstream>
#include <initializer_list>
#include <iostream>
#include <map>
#include <memory>
#include <new>
#include <regex>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "json.hpp" // nlohmann/json library

// using nlohmann/json
using json = nlohmann::json;
// Define Postprocessor class to encapsulate all logic
class Postprocessor {
private:
// All original global variables are now private members of the class
std::map<std::string, std::string> cls_map;
std::map<std::string, std::vector<std::string>> area2NL;
std::map<std::string, std::string> NL2area;
std::map<std::string, std::map<std::string, std::string>> maxmin_val_map;
std::set<std::string> uncommon_area_NL;
std::set<std::string> st_open;
std::set<std::string> st_close;
std::map<std::string, std::string> default_value;
std::vector<std::string> switch_keywords;
std::vector<std::string> level_keywords;
std::vector<std::string> str2remove;
std::vector<std::string> max_level_keys;
std::vector<std::string> min_level_keys;
std::string exclude_pattern_str;
/**
* @brief Utility function to split a UTF-8 string into individual characters (or bytes for ASCII).
* @param text The input UTF-8 string.
* @return A vector of strings, where each string is a single character.
*/
std::vector<std::string> split_utf8_string(const std::string& text) {
std::vector<std::string> utf8_chars;
size_t i = 0;
while (i < text.length()) {
size_t len = 1;
unsigned char c = (unsigned char)text[i];
if (c >= 0xE0) len = 3; // 3-byte character (most CJK)
else if (c >= 0xC0) len = 2; // 2-byte character
// Ensure we don't read past the end of the string
if (i + len > text.length()) {
len = 1; // Treat as a single byte if incomplete
}
utf8_chars.push_back(text.substr(i, len));
i += len;
}
return utf8_chars;
}
std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) {
size_t start_pos = 0;
while((start_pos = str.find(from, start_pos)) != std::string::npos) {
str.replace(start_pos, from.length(), to);
start_pos += to.length(); // Handles case where 'to' is a substring of 'from'
}
return str;
}
/**
* @brief Converts English number words (like "one hundred twenty-three") into an Arabic number.
* This is a simplified implementation supporting numbers up to one billion.
* @param words The vector of English number words (tokens).
* @param start_index The starting index in the vector to begin parsing (always 0 here since it's a dedicated word sequence).
* @return A pair: the converted numerical value and the index of the next unprocessed word.
*/
std::pair<long long, size_t> parse_english_numbers(const std::vector<std::string>& words, size_t start_index) {
// Maps for English number words
static const std::map<std::string, long long> ONES = {
{"zero", 0}, {"one", 1}, {"two", 2}, {"three", 3}, {"four", 4},
{"five", 5}, {"six", 6}, {"seven", 7}, {"eight", 8}, {"nine", 9}
};
static const std::map<std::string, long long> TEENS = {
{"ten", 10}, {"eleven", 11}, {"twelve", 12}, {"thirteen", 13},
{"fourteen", 14}, {"fifteen", 15}, {"sixteen", 16},
{"seventeen", 17}, {"eighteen", 18}, {"nineteen", 19}
};
static const std::map<std::string, long long> TENS = {
{"twenty", 20}, {"thirty", 30}, {"forty", 40}, {"fifty", 50},
{"sixty", 60}, {"seventy", 70}, {"eighty", 80}, {"ninety", 90}
};
static const std::map<std::string, long long> MAGNITUDES = {
{"hundred", 100}, {"thousand", 1000}, {"million", 1000000}, {"billion", 1000000000}
};
long long result = 0;
long long current_magnitude_val = 0; // Value within the current 'thousand' segment
long long current_hundred_val = 0; // Value within the current 'hundred' segment
size_t i = start_index;
while (i < words.size()) {
std::string word = words[i];
// Convert to lowercase for case-insensitive matching
std::transform(word.begin(), word.end(), word.begin(),
[](unsigned char c){ return std::tolower(c); });
// Handle hyphenated numbers (e.g., twenty-one)
if (word.find('-') != std::string::npos) {
std::string part1 = word.substr(0, word.find('-'));
std::string part2 = word.substr(word.find('-') + 1);
if (TENS.count(part1) && ONES.count(part2)) {
current_hundred_val += TENS.at(part1) + ONES.at(part2);
i++;
continue;
}
}
if (ONES.count(word)) {
current_hundred_val += ONES.at(word);
} else if (TEENS.count(word)) {
current_hundred_val += TEENS.at(word);
} else if (TENS.count(word)) {
current_hundred_val += TENS.at(word);
} else if (MAGNITUDES.count(word)) {
long long mag = MAGNITUDES.at(word);
if (mag == 100) {
// If 'hundred' is encountered, multiply the current hundred value by 100.
if (current_hundred_val == 0) current_hundred_val = 1;
current_hundred_val *= mag;
} else {
// For 'thousand', 'million', etc.
current_magnitude_val += current_hundred_val;
if (current_magnitude_val == 0) current_magnitude_val = 1; // Simplification for "a thousand"
result += current_magnitude_val * mag;
current_magnitude_val = 0;
current_hundred_val = 0;
}
} else if (word == "and" && i + 1 < words.size()) {
// 'and' is often used to connect numbers, ignore it for calculation.
i++;
continue;
} else {
// Not a number word, conversion ends.
break;
}
i++;
}
// Final accumulation of the current segment
result += current_magnitude_val;
result += current_hundred_val;
// Check for 'zero' as a single word
if (start_index < words.size() && i == start_index + 1 && words[start_index] == "zero") {
return {0, start_index + 1};
}
// Only return a result if at least one number word was consumed
if (i == start_index) {
return {0, start_index}; // No number words found
}
return {result, i};
}
/**
* @brief Converts Chinese and English number words in a string to Arabic numerals.
* Chinese logic is preserved as requested. English logic is implemented separately.
* @param text The input string containing mixed Chinese/English text.
* @return The modified string with numbers converted to Arabic numerals.
*/
std::string cn2an_transform(const std::string& text) {
// --- Chinese Mappings (Preserved) ---
// Digit map for Chinese characters (e.g., "一" -> 1)
std::map<std::string, int> digit_map = {
{"零", 0}, {"一", 1}, {"二", 2}, {"三", 3}, {"四", 4}, {"五", 5}, {"六", 6}, {"七", 7}, {"八", 8}, {"九", 9},
{"〇", 0}, {"壹", 1}, {"貳", 2}, {"參", 3}, {"肆", 4}, {"伍", 5}, {"陸", 6}, {"柒", 7}, {"捌", 8}, {"玖", 9},
{"兩", 2}
};
// Unit map for Chinese magnitude characters (e.g., "十" -> 10)
std::map<std::string, int> unit_map = {
{"十", 10}, {"百", 100}, {"千", 1000}, {"萬", 10000}, {"億", 100000000}
};
// --- Tokenization and Conversion Prep ---
// Split the string into UTF-8 characters for Chinese processing
std::vector<std::string> utf8_chars = split_utf8_string(text);
std::string result_string = "";
// Variables for Chinese number accumulation
long long total_val = 0;
long long section_val = 0;
long long current_val = 0;
// State flags and storage for English number parsing
bool in_english_sequence = false;
std::vector<std::string> english_words;
// Helper lambda to flush and reset Chinese numbers
auto flush_chinese = [&](std::string& res_str) {
if (total_val > 0 || current_val > 0 || section_val > 0) {
res_str += std::to_string(total_val + current_val + section_val);
total_val = current_val = section_val = 0;
}
};
// Helper lambda to flush and reset English words
auto flush_english = [&](std::string& res_str) {
if (in_english_sequence && !english_words.empty()) {
auto [num, end_idx] = parse_english_numbers(english_words, 0);
if (end_idx > 0) {
// Number converted successfully
res_str += std::to_string(num);
// Append any remaining non-number words
for (size_t k = end_idx; k < english_words.size(); ++k) {
res_str += english_words[k];
}
} else {
// No number found, just append the words back
for (const auto& word : english_words) {
res_str += word;
}
}
english_words.clear();
in_english_sequence = false;
}
};
// --- Main Loop: Iterate through UTF-8 characters ---
for (size_t j = 0; j < utf8_chars.size(); ++j) {
std::string c = utf8_chars[j];
// 1. Check for Chinese characters
if (digit_map.count(c) || unit_map.count(c)) {
flush_english(result_string); // Chinese interrupts English sequence
// Apply preserved Chinese logic
if (digit_map.count(c)) {
// Chinese digit found
section_val = digit_map.at(c);
// If it's the last character or the next char is not a unit
if (j == utf8_chars.size() - 1 || !unit_map.count(utf8_chars[j+1])) {
current_val += section_val;
section_val = 0;
}
} else if (unit_map.count(c)) {
// Chinese unit found
int unit_val = unit_map.at(c);
// Handle cases like "十" (ten) where the leading "一" is implied (simplified)
if (unit_val < 10000 && (j == 0 || !digit_map.count(utf8_chars[j-1])) && section_val == 0 && current_val == 0) {
section_val = 1;
}
if (unit_val < 10000) {
// Ten, Hundred, Thousand units
current_val += section_val * unit_val;
} else {
// Ten Thousand, Hundred Million units (magnitude change)
total_val += (current_val + section_val) * unit_val;
current_val = 0;
}
section_val = 0;
}
}
// 2. Check for ASCII characters (potential English word or Arabic numeral)
else if (c.length() == 1) {
char ch = c[0];
// Check for existing Arabic numerals
if (std::isdigit(ch)) {
flush_chinese(result_string); // Arabic numeral interrupts Chinese number conversion
flush_english(result_string); // Arabic numeral interrupts English word sequence
// Collect contiguous Arabic numerals
std::string num_str = "";
while (j < utf8_chars.size() && utf8_chars[j].length() == 1 && std::isdigit(utf8_chars[j][0])) {
num_str += utf8_chars[j];
j++;
}
result_string += num_str;
j--; // Decrement to re-check the character after the number sequence
continue;
}
// Check for English words (alphabetic characters)
if (std::isalpha(ch) || ch == '-') {
flush_chinese(result_string); // English word interrupts Chinese number conversion
// Collect contiguous alphabetic/hyphen characters as one token
std::string token = "";
while (j < utf8_chars.size() && utf8_chars[j].length() == 1 && (std::isalpha(utf8_chars[j][0]) || utf8_chars[j][0] == '-')) {
token += utf8_chars[j];
j++;
}
j--; // Stay on the last character of the token for the next loop iteration
// Add the token to the list of English words
english_words.push_back(token);
in_english_sequence = true;
} else {
// Not a Chinese char, Arabic digit, or English word part (e.g., space, punctuation)
flush_chinese(result_string); // Flush pending Chinese number
flush_english(result_string); // Flush pending English words
// Append the non-number character (space, punctuation, etc.)
result_string += c;
}
}
// 3. Any other non-Chinese, non-ASCII characters (e.g., symbols, other scripts)
else {
flush_chinese(result_string);
flush_english(result_string);
// Append the character
result_string += c;
}
} // End of loop
// --- Final Flush ---
flush_chinese(result_string);
flush_english(result_string);
return result_string;
}
// A utility function to extract numbers from a string.
std::vector<std::string> extract_numbers_from_string(const std::string& text) {
// Use the regex string read from the JSON file
std::regex exclude_pattern(this->exclude_pattern_str);
std::string cleaned_text = std::regex_replace(text, exclude_pattern, "");
std::regex num_pattern(R"(\d+)");
std::sregex_iterator next(cleaned_text.begin(), cleaned_text.end(), num_pattern);
std::sregex_iterator end;
std::vector<std::string> numbers;
while (next != end) {
numbers.push_back(next->str());
++next;
}
return numbers;
}
// A utility function to get keywords from a string.
json get_keywords(const std::string& text) {
json keywords;
keywords["switch"] = json::array();
keywords["level"] = json::array();
keywords["num"] = json::array();
keywords["area_id"] = json::array();
for (const auto& pair : area2NL) {
for (const auto& nl_phrase : pair.second) {
if (text.find(nl_phrase) != std::string::npos) {
keywords["area_id"].push_back(nl_phrase);
}
}
}
auto find_keywords = [&](const std::vector<std::string>& kw_list, const std::string& key) {
for (const auto& keyword : kw_list) {
if (text.find(keyword) != std::string::npos) {
keywords[key].push_back(keyword);
}
}
};
find_keywords(switch_keywords, "switch");
find_keywords(level_keywords, "level");
keywords["num"] = extract_numbers_from_string(text);
return keywords;
}
// A utility function to find the longest string in a JSON array.
std::string find_longest_string(const json& arr) {
if (arr.empty()) {
return "";
}
std::string longest_str = arr[0].get<std::string>();
for (const auto& val : arr) {
std::string current_str = val.get<std::string>();
if (current_str.length() > longest_str.length()) {
longest_str = current_str;
}
}
return longest_str;
}
public:
// This is the main public method of the Postprocessor
void load_data() {
// Load response_template.json
std::ifstream cls_file("response_template.json");
if (!cls_file.is_open()) {
std::cerr << "[ERROR] response_template.json file not found. Please ensure it's in the same directory." << std::endl;
return;
}
json temp_json;
try {
cls_file >> temp_json;
for (json::iterator it = temp_json.begin(); it != temp_json.end(); ++it) {
cls_map[it.key()] = it.value().get<std::string>();
}
} catch (const json::parse_error& e) {
std::cerr << "[ERROR] JSON parse error in response_template.json: " << e.what() << std::endl;
}
// Load keywords_data.json
std::ifstream data_file("keywords_data.json");
if (!data_file.is_open()) {
std::cerr << "[ERROR] keywords_data.json file not found. Please ensure it's in the same directory." << std::endl;
return;
}
json data_json;
try {
data_file >> data_json;
// Read and populate area2NL
for (const auto& [key, value] : data_json["area2NL"].items()) {
area2NL[key] = value.get<std::vector<std::string>>();
}
// Read and populate maxmin_val_map
for (const auto& [key, value] : data_json["maxmin_val_map"].items()) {
maxmin_val_map[key] = value.get<std::map<std::string, std::string>>();
}
// Read and populate uncommon_area_NL
for (const auto& val : data_json["uncommon_area_NL"]) {
uncommon_area_NL.insert(val.get<std::string>());
}
// Read and populate st_open
for (const auto& val : data_json["st_open"]) {
st_open.insert(val.get<std::string>());
}
// Read and populate st_close
for (const auto& val : data_json["st_close"]) {
st_close.insert(val.get<std::string>());
}
// Read and populate default_value
for (const auto& [key, value] : data_json["default_value"].items()) {
default_value[key] = value.get<std::string>();
}
// Read and populate switch_keywords
switch_keywords = data_json["switch_keywords"].get<std::vector<std::string>>();
// Read and populate level_keywords
level_keywords = data_json["level_keywords"].get<std::vector<std::string>>();
// Read and populate exclude_pattern_str
exclude_pattern_str = data_json["exclude_pattern_str"].get<std::string>();
str2remove = data_json["str2remove"].get<std::vector<std::string>>();
max_level_keys = data_json["max_level"].get<std::vector<std::string>>();
min_level_keys = data_json["min_level"].get<std::vector<std::string>>();
// Convert and populate NL2area based on area2NL
for (const auto& [area_id, nl_phrases] : area2NL) {
for (const auto& nl_phrase : nl_phrases) {
NL2area[nl_phrase] = area_id;
}
}
} catch (const json::parse_error& e) {
std::cerr << "[ERROR] JSON parse error in keywords_data.json: " << e.what() << std::endl;
}
}
// Set verbose variable to public
bool verbose = false;
// This is the main public method of the Postprocessor
std::string postprocess(const std::string& query, const std::string& pred_class) {
if (verbose) {
std::cerr << "[DEBUG] input query: " << query << ", pred_class: " << pred_class << std::endl;
}
std::string ori_func_name = pred_class.substr(0, pred_class.find('%'));
if (cls_map.find(pred_class) == cls_map.end()) {
// std::cerr << "[ERROR] Key not found in cls_map: " << pred_class << std::endl;
json empty_json;
return empty_json.dump();
}
std::string func_tmp_str = cls_map[pred_class];
std::string ori_query = query;
std::string new_query = cn2an_transform(query);
// query transform
std::transform(new_query.begin(), new_query.end(), new_query.begin(),
[](unsigned char c){ return std::tolower(c); });
for (auto s:str2remove){
size_t pos = new_query.find(s);
if (pos != std::string::npos) {
new_query.erase(pos, s.length());
}
}
std::replace(new_query.begin(), new_query.end(), '-', ' ');
new_query = ReplaceAll(new_query, std::string("first level"), std::string("1"));
new_query = ReplaceAll(new_query, std::string("second level"), std::string("2"));
new_query = ReplaceAll(new_query, std::string("third level"), std::string("3"));
new_query = ReplaceAll(new_query, std::string("fourth level"), std::string("4"));
new_query = ReplaceAll(new_query, std::string("fifth level"), std::string("5"));
new_query = ReplaceAll(new_query, std::string("sixth level"), std::string("6"));
new_query = ReplaceAll(new_query, std::string("seventh level"), std::string("7"));
new_query = ReplaceAll(new_query, std::string("eighth level"), std::string("8"));
if (verbose) {
std::cerr << "[DEBUG] new_query: " << new_query << std::endl;
}
json keywords = get_keywords(new_query);
if (verbose) {
std::cerr << "[DEBUG] keywords: " << keywords.dump() << std::endl;
}
json func_tmp = json::parse(func_tmp_str);
std::set<std::string> set_keywords;
for (const auto& kw : keywords["area_id"]) {
set_keywords.insert(kw.get<std::string>());
}
if (!keywords["area_id"].empty()) {
std::string area_kw = find_longest_string(keywords["area_id"]);
if (verbose) {
std::cerr << "[DEBUG] area_kw: " << area_kw << std::endl;
}
if (pred_class.find("SLIDING_DOOR") != std::string::npos) {
if (set_keywords.count("左邊") || set_keywords.count("左側") || set_keywords.count("左") || set_keywords.count("left")) {
func_tmp["areaId"] = NL2area["左邊"];
} else if (set_keywords.count("右邊") || set_keywords.count("右側") || set_keywords.count("右") || set_keywords.count("right")) {
func_tmp["areaId"] = NL2area["右邊"];
}
} else if (pred_class.find("HVAC_DEFROSTER") != std::string::npos) {
if (set_keywords.count("前除霜") || set_keywords.count("front defroster")) {
func_tmp["areaId"] = NL2area["前除霜"];
} else if (set_keywords.count("後除霜") || set_keywords.count("rear defroster")) {
func_tmp["areaId"] = NL2area["後除霜"];
}
} else if (pred_class.find("POWER_SUNSHADE") != std::string::npos) {
if (set_keywords.count("頂棚") || set_keywords.count("roof")) {
func_tmp["areaId"] = NL2area["頂棚"];
} else if (set_keywords.count("右邊") || set_keywords.count("右側") || set_keywords.count("右") || set_keywords.count("right")) {
func_tmp["areaId"] = "SEAT_ROW_2_RIGHT";
} else if (set_keywords.count("左邊") || set_keywords.count("左側") || set_keywords.count("左") || set_keywords.count("left")) {
func_tmp["areaId"] = "SEAT_ROW_2_LEFT";
}
}
bool is_uncommon = false;
for (const auto& uc_kw : uncommon_area_NL) {
if (area_kw == uc_kw) {
is_uncommon = true;
break;
}
}
if (!is_uncommon) {
if (NL2area.count(area_kw) && func_tmp["areaId"]=="") {
func_tmp["areaId"] = NL2area[area_kw];
}
}
if (func_tmp["areaId"]=="SEAT_ROW_1" || func_tmp["areaId"]=="SEAT_ROW_2" || func_tmp["areaId"]=="SEAT_ROW_3"){
if (new_query.find("left") != std::string::npos){
func_tmp["areaId"] = std::string(func_tmp["areaId"])+"_LEFT";
}
else if (new_query.find("right") != std::string::npos){
func_tmp["areaId"] = std::string(func_tmp["areaId"])+"_RIGHT";
}
}
}
if (!keywords["num"].empty() && func_tmp.count("value") && func_tmp["value"].is_string() && func_tmp["value"].get<std::string>() == "") {
func_tmp["value"] = keywords["num"][0].get<std::string>();
} else if (!keywords["level"].empty() && func_tmp.count("value") && func_tmp["value"].is_string() && func_tmp["value"].get<std::string>() == "" && maxmin_val_map.count(ori_func_name)) {
for (auto &m:max_level_keys){
if(std::find(keywords["level"].begin(), keywords["level"].end(), m) != keywords["level"].end())
func_tmp["value"] = maxmin_val_map[ori_func_name]["max"];
}
for (auto &m:min_level_keys) {
if(std::find(keywords["level"].begin(), keywords["level"].end(), m) != keywords["level"].end())
func_tmp["value"] = maxmin_val_map[ori_func_name]["min"];
}
} else if (!keywords["switch"].empty() && ori_func_name == "POWER_SUNSHADE" && func_tmp.count("value") && func_tmp["value"].is_string() && func_tmp["value"].get<std::string>() == "") {
bool open_found = false;
bool close_found = false;
for (const auto& op : keywords["switch"]) {
if (st_open.count(op)) open_found = true;
if (st_close.count(op)) close_found = true;
}
if (open_found && new_query.find("開大") == std::string::npos) {
func_tmp["value"] = "100";
} else if (close_found && new_query.find("關小") == std::string::npos) {
func_tmp["value"] = "0";
}
}
if (func_tmp.count("value") && func_tmp["value"].is_string() && func_tmp["value"].get<std::string>() == "") {
if (pred_class.find("increas") != std::string::npos || pred_class.find("decreas") != std::string::npos || pred_class.find("reduc") != std::string::npos || pred_class.find("reduc") != std::string::npos) {
func_tmp["value"] = "1";
} else if (ori_func_name == "HVAC_TEMPERATURE_SET") {
if (new_query.find("熱") != std::string::npos || new_query.find("hot") != std::string::npos) {
func_tmp["value"] = maxmin_val_map[ori_func_name]["min"];
} else if (new_query.find("冷") != std::string::npos || new_query.find("凍") != std::string::npos || new_query.find("cold") != std::string::npos) {
func_tmp["value"] = maxmin_val_map[ori_func_name]["max"];
}
} else if (default_value.count(ori_func_name)) {
func_tmp["value"] = default_value[ori_func_name];
}
}
json final_json;
if (ori_func_name == "set_seat_mode") {
final_json.push_back({{"name", ori_func_name}, {"arguments", func_tmp}});
} else if (ori_func_name == "get_hhtd_info" || ori_func_name == "get_vehicle_info") {
func_tmp["query"] = ori_query;
final_json.push_back({{"name", ori_func_name}, {"arguments", func_tmp}});
} else {
final_json.push_back({{"name", "control_car_properties"}, {"arguments", func_tmp}});
}
std::string result = final_json.dump(4);
if (verbose) {
std::cerr << "[DEBUG] Final output: " << result << std::endl;
}
return result;
}
};
// C-style wrapper function for ctypes calls
extern "C" {
const char* postprocess_c(const char* query_c, const char* pred_class_c) {
static Postprocessor processor;
static bool data_loaded = false;
// Load data only on the first call
if (!data_loaded) {
processor.load_data();
data_loaded = true;
}
std::string query(query_c);
std::string pred_class(pred_class_c);
std::string result = processor.postprocess(query, pred_class);
char* c_str = new char[result.length() + 1];
if (c_str == nullptr) {
std::cerr << "ERROR: Memory allocation failed in postprocess_c." << std::endl;
return nullptr;
}
std::copy(result.begin(), result.end(), c_str);
c_str[result.length()] = '\0';
return c_str;
}
}