| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include <math.h> |
|
|
| #include <pocketsphinx/prim_type.h> |
| #include <pocketsphinx/err.h> |
|
|
| #include "util/ckd_alloc.h" |
| #include "util/byteorder.h" |
|
|
| #include "ngram_model_internal.h" |
| #include "lm_trie_quant.h" |
|
|
| |
/* Positive floating-point infinity (C99 INFINITY from <math.h>).
 * The integer literal 0x7f800000 is only the IEEE-754 *bit pattern* of
 * +inf; used as a value it converts to 2139095040.0f, so -FLOAT_INF was a
 * large finite number rather than negative infinity. */
#define FLOAT_INF (INFINITY)
|
|
| typedef struct bins_s { |
| float32 *begin; |
| const float32 *end; |
| } bins_t; |
|
|
| struct lm_trie_quant_s { |
| bins_t tables[NGRAM_MAX_ORDER - 1][2]; |
| bins_t *longest; |
| float32 *values; |
| size_t nvalues; |
| uint8 prob_bits; |
| uint8 bo_bits; |
| uint32 prob_mask; |
| uint32 bo_mask; |
| }; |
|
|
| static void |
| bins_create(bins_t * bins, uint8 bits, float32 *begin) |
| { |
| bins->begin = begin; |
| bins->end = bins->begin + (1ULL << bits); |
| } |
|
|
| static float32 * |
| lower_bound(float32 *first, const float32 *last, float32 val) |
| { |
| int count, step; |
| float32 *it; |
|
|
| count = last - first; |
| while (count > 0) { |
| it = first; |
| step = count / 2; |
| it += step; |
| if (*it < val) { |
| first = ++it; |
| count -= step + 1; |
| } |
| else { |
| count = step; |
| } |
| } |
| return first; |
| } |
|
|
| static uint64 |
| bins_encode(bins_t * bins, float32 value) |
| { |
| float32 *above = lower_bound(bins->begin, bins->end, value); |
| if (above == bins->begin) |
| return 0; |
| if (above == bins->end) |
| return bins->end - bins->begin - 1; |
| return above - bins->begin - (value - *(above - 1) < *above - value); |
| } |
|
|
| static float32 |
| bins_decode(bins_t * bins, size_t off) |
| { |
| return bins->begin[off]; |
| } |
|
|
/* Number of float32 centers needed for an `order`-gram model.
 * Each of the (order - 2) middle orders gets a prob codebook plus a backoff
 * codebook, and the highest order gets a prob codebook only.  Both kinds
 * use 16-bit indices. */
static size_t
quant_size(int order)
{
    const int prob_bits = 16, bo_bits = 16;
    const size_t prob_entries = (size_t) (1U << prob_bits);
    const size_t bo_entries = (size_t) (1U << bo_bits);

    return (order - 2) * (prob_entries + bo_entries) + prob_entries;
}
|
|
| lm_trie_quant_t * |
| lm_trie_quant_create(int order) |
| { |
| float32 *start; |
| int i; |
| lm_trie_quant_t *quant = |
| (lm_trie_quant_t *) ckd_calloc(1, sizeof(*quant)); |
| quant->nvalues = quant_size(order); |
| quant->values = |
| (float32 *) ckd_calloc(quant->nvalues, sizeof(*quant->values)); |
|
|
| quant->prob_bits = 16; |
| quant->bo_bits = 16; |
| quant->prob_mask = (1U << quant->prob_bits) - 1; |
| quant->bo_mask = (1U << quant->bo_bits) - 1; |
|
|
| start = (float32 *) (quant->values); |
| for (i = 0; i < order - 2; i++) { |
| bins_create(&quant->tables[i][0], quant->prob_bits, start); |
| start += (1ULL << quant->prob_bits); |
| bins_create(&quant->tables[i][1], quant->bo_bits, start); |
| start += (1ULL << quant->bo_bits); |
| } |
| bins_create(&quant->tables[order - 2][0], quant->prob_bits, start); |
| quant->longest = &quant->tables[order - 2][0]; |
| return quant; |
| } |
|
|
|
|
| lm_trie_quant_t * |
| lm_trie_quant_read_bin(FILE * fp, int order) |
| { |
| int dummy; |
| lm_trie_quant_t *quant; |
|
|
| fread(&dummy, sizeof(dummy), 1, fp); |
| quant = lm_trie_quant_create(order); |
| if (fread(quant->values, sizeof(*quant->values), |
| quant->nvalues, fp) != quant->nvalues) { |
| E_ERROR("Failed to read %d quantization values\n", |
| quant->nvalues); |
| lm_trie_quant_free(quant); |
| return NULL; |
| } |
| if (SWAP_LM_TRIE) { |
| size_t i; |
| for (i = 0; i < quant->nvalues; ++i) |
| SWAP_FLOAT32(&quant->values[i]); |
| } |
|
|
| return quant; |
| } |
|
|
| void |
| lm_trie_quant_write_bin(lm_trie_quant_t * quant, FILE * fp) |
| { |
| |
| int dummy = 1; |
| fwrite(&dummy, sizeof(dummy), 1, fp); |
| if (SWAP_LM_TRIE) { |
| size_t i; |
| for (i = 0; i < quant->nvalues; ++i) { |
| float32 value = quant->values[i]; |
| SWAP_FLOAT32(&value); |
| if (fwrite(&value, sizeof(value), 1, fp) != 1) { |
| E_ERROR("Failed to write quantization value\n"); |
| return; |
| } |
| } |
| } |
| else { |
| if (fwrite(quant->values, sizeof(*quant->values), |
| quant->nvalues, fp) != quant->nvalues) { |
| E_ERROR("Failed to write %d quantization values\n", |
| quant->nvalues); |
| } |
| } |
| } |
|
|
| void |
| lm_trie_quant_free(lm_trie_quant_t * quant) |
| { |
| if (quant->values) |
| ckd_free(quant->values); |
| ckd_free(quant); |
| } |
|
|
| uint8 |
| lm_trie_quant_msize(lm_trie_quant_t * quant) |
| { |
| (void)quant; |
| return 32; |
| } |
|
|
| uint8 |
| lm_trie_quant_lsize(lm_trie_quant_t * quant) |
| { |
| (void)quant; |
| return 16; |
| } |
|
|
| static int |
| weights_comparator(const void *a, const void *b) |
| { |
| return (int) (*(float32 *) a - *(float32 *) b); |
| } |
|
|
| static void |
| make_bins(float32 *values, uint32 values_num, float32 *centers, uint32 bins) |
| { |
| float32 *finish, *start; |
| uint32 i; |
|
|
| qsort(values, values_num, sizeof(*values), &weights_comparator); |
| start = values; |
| for (i = 0; i < bins; i++, centers++, start = finish) { |
| finish = values + (size_t) ((uint64) values_num * (i + 1) / bins); |
| if (finish == start) { |
| |
| *centers = i ? *(centers - 1) : -FLOAT_INF; |
| } |
| else { |
| float32 sum = 0.0f; |
| float32 *ptr; |
| for (ptr = start; ptr != finish; ptr++) { |
| sum += *ptr; |
| } |
| *centers = sum / (float32) (finish - start); |
| } |
| } |
| } |
|
|
| void |
| lm_trie_quant_train(lm_trie_quant_t * quant, int order, uint32 counts, |
| ngram_raw_t * raw_ngrams) |
| { |
| float32 *probs; |
| float32 *backoffs; |
| float32 *centers; |
| uint32 backoff_num; |
| uint32 prob_num; |
| ngram_raw_t *raw_ngrams_end; |
|
|
| probs = (float32 *) ckd_calloc(counts, sizeof(*probs)); |
| backoffs = (float32 *) ckd_calloc(counts, sizeof(*backoffs)); |
| raw_ngrams_end = raw_ngrams + counts; |
|
|
| for (backoff_num = 0, prob_num = 0; raw_ngrams != raw_ngrams_end; |
| raw_ngrams++) { |
| probs[prob_num++] = raw_ngrams->prob; |
| backoffs[backoff_num++] = raw_ngrams->backoff; |
| } |
|
|
| make_bins(probs, prob_num, quant->tables[order - 2][0].begin, |
| 1ULL << quant->prob_bits); |
| centers = quant->tables[order - 2][1].begin; |
| make_bins(backoffs, backoff_num, centers, (1ULL << quant->bo_bits)); |
| ckd_free(probs); |
| ckd_free(backoffs); |
| } |
|
|
| void |
| lm_trie_quant_train_prob(lm_trie_quant_t * quant, int order, uint32 counts, |
| ngram_raw_t * raw_ngrams) |
| { |
| float32 *probs; |
| uint32 prob_num; |
| ngram_raw_t *raw_ngrams_end; |
|
|
| probs = (float32 *) ckd_calloc(counts, sizeof(*probs)); |
| raw_ngrams_end = raw_ngrams + counts; |
|
|
| for (prob_num = 0; raw_ngrams != raw_ngrams_end; raw_ngrams++) { |
| probs[prob_num++] = raw_ngrams->prob; |
| } |
|
|
| make_bins(probs, prob_num, quant->tables[order - 2][0].begin, |
| 1ULL << quant->prob_bits); |
| ckd_free(probs); |
| } |
|
|
| void |
| lm_trie_quant_mwrite(lm_trie_quant_t * quant, bitarr_address_t address, |
| int order_minus_2, float32 prob, float32 backoff) |
| { |
| bitarr_write_int57(address, quant->prob_bits + quant->bo_bits, |
| (uint64) ((bins_encode |
| (&quant->tables[order_minus_2][0], |
| prob) << quant-> |
| bo_bits) | bins_encode(&quant-> |
| tables |
| [order_minus_2] |
| [1], |
| backoff))); |
| } |
|
|
| void |
| lm_trie_quant_lwrite(lm_trie_quant_t * quant, bitarr_address_t address, |
| float32 prob) |
| { |
| bitarr_write_int25(address, quant->prob_bits, |
| (uint32) bins_encode(quant->longest, prob)); |
| } |
|
|
| float32 |
| lm_trie_quant_mboread(lm_trie_quant_t * quant, bitarr_address_t address, |
| int order_minus_2) |
| { |
| return bins_decode(&quant->tables[order_minus_2][1], |
| bitarr_read_int25(address, quant->bo_bits, |
| quant->bo_mask)); |
| } |
|
|
| float32 |
| lm_trie_quant_mpread(lm_trie_quant_t * quant, bitarr_address_t address, |
| int order_minus_2) |
| { |
| address.offset += quant->bo_bits; |
| return bins_decode(&quant->tables[order_minus_2][0], |
| bitarr_read_int25(address, quant->prob_bits, |
| quant->prob_mask)); |
| } |
|
|
| float32 |
| lm_trie_quant_lpread(lm_trie_quant_t * quant, bitarr_address_t address) |
| { |
| return bins_decode(quant->longest, |
| bitarr_read_int25(address, quant->prob_bits, |
| quant->prob_mask)); |
| } |
|
|