Spaces:
Running
Running
from random import choice, choices | |
# Rank-indexed sampling weights: index 0 is the weight of the best-scoring
# candidate, index 4 the weight of the fifth-best.
# NOTE(review): WEIGHTS_MAP_HARD is not referenced in the visible code;
# presumably used elsewhere in the file or reserved — confirm before removing.
WEIGHTS_MAP_HARD = [100, 10, 8, 6, 2]

# Flatter rank weights, used by get_first_word when drawing an opening word.
WEIGHTS_MAP_SOFT = [80, 30, 10, 7, 2]
def get_next_word_results(db, message, prev_word, text, _):
    """Score the candidate tokens that may follow ``prev_word``.

    Every token in ``db[prev_word]`` gets its ``score`` attribute
    overwritten with:

    * +2 for each of its context strings found in ``message``;
    * +1 for each of its context strings found in ``text``;
    * +10 if the token's word contains ``")"`` while ``text`` has more
      ``"("`` than ``")"`` characters (i.e. it would close an open
      parenthesis).

    Args:
        db: Mapping from a previous word to its candidate tokens. Tokens
            must expose ``word``, ``contexts`` and a writable ``score``.
        message: Primary scoring context (presumably the user's message —
            confirm against callers).
        prev_word: The word just emitted.
        text: The text generated so far.
        _: Ignored; kept so the signature matches the other helpers.

    Returns:
        The tokens whose score is positive, in ``db`` order. An empty list
        when ``prev_word`` is not a key of ``db``.
    """
    if prev_word not in db:
        return []
    # Loop-invariant: whether the generated text still has an unclosed "(".
    # Computing it once avoids re-scanning ``text`` for every token.
    has_open_paren = text.count("(") > text.count(")")
    results = []
    for token in db[prev_word]:
        score = 0
        for context in token.contexts:
            if context in message:
                score += 2
            if context in text:
                score += 1
        if ")" in token.word and has_open_paren:
            score += 10
        # The score is written back onto the token stored in ``db``;
        # callers rank candidates by this attribute.
        token.score = score
        if score > 0:
            results.append(token)
    return results
def get_next_word(db, message, prevword, text, conf, repeat=0):
    """Pick the word to emit after ``prevword``.

    If ``prevword`` is empty or contains ``.``, ``?`` or ``!`` (anywhere in
    the word, so e.g. ``"e.g."`` also qualifies), a new sentence is started
    via ``get_first_word``. Otherwise the candidates returned by
    ``get_next_word_results`` are examined and one of the top-scoring
    candidates (at most five tied words, in ``db`` order) is chosen
    uniformly at random.

    Args:
        db: Mapping from a previous word to candidate tokens.
        message: Scoring context forwarded to the helpers.
        prevword: The word just emitted.
        text: The text generated so far.
        conf: Forwarded unchanged to the helpers.
        repeat: Retry counter. When no candidate scores above zero and
            ``repeat >= 1``, a random key of ``db`` is returned instead.

    Returns:
        The chosen word.

    Raises:
        IndexError: if the random fallback is reached and ``db`` is empty.
    """
    if prevword == '' or '.' in prevword or '?' in prevword or '!' in prevword:
        return get_first_word(db, message, text, conf, repeat)
    results = get_next_word_results(db, message, prevword, text, conf)
    if not results:
        if repeat >= 1:
            return choice(list(db.keys()))
        # NOTE(review): scoring is deterministic, so this retry recomputes
        # the same empty result before falling back; kept for compatibility.
        return get_next_word(db, message, prevword, text, conf, repeat + 1)
    # Every returned token has score > 0, so a best score always exists and
    # the candidate list below is never empty.
    best = max(token.score for token in results)
    # Same as taking the first five entries of a stable descending sort and
    # keeping those tied with the best score: ties retain their db order.
    candidates = [token.word for token in results if token.score == best][:5]
    return choice(candidates)
def get_first_word_results(db, message, text, _):
    """Score the sentence-opening candidates stored under ``db['']``.

    Each candidate's ``score`` attribute is overwritten with
    2 per context found in ``message`` + 1 per context found in ``text``
    + 15 when ``candidate.starter`` is truthy.

    Returns the candidates with a positive score, in their ``db`` order,
    or an empty list when ``db`` has no ``''`` entry. The last argument is
    ignored.
    """
    def _score(candidate):
        # Booleans count as 0/1, so each context contributes 0..3 points.
        total = sum(2 * (ctx in message) + (ctx in text) for ctx in candidate.contexts)
        return total + (15 if candidate.starter else 0)

    if '' not in db:
        return []
    kept = []
    for candidate in db['']:
        candidate.score = _score(candidate)
        if candidate.score > 0:
            kept.append(candidate)
    return kept
def get_first_word(db, message, text, conf, repeat=0):
    """Pick the opening word of a new sentence.

    Candidates come from ``get_first_word_results``. They are ranked by
    score (stable, highest first), and one of the best
    ``len(WEIGHTS_MAP_SOFT)`` candidates is drawn at random. The draw uses
    rank-based weights from ``WEIGHTS_MAP_SOFT``, so the best-ranked
    candidate gets weight ``WEIGHTS_MAP_SOFT[0]``.

    Args:
        db: Mapping whose ``''`` entry holds the sentence-opening tokens.
        message: Scoring context forwarded to the helper.
        text: The text generated so far.
        conf: Forwarded unchanged to the helper.
        repeat: Retry counter. When no candidate scores above zero and
            ``repeat >= 1``, a random key of ``db`` is returned instead.

    Returns:
        The chosen word.

    Raises:
        IndexError: if the random fallback is reached and ``db`` is empty.
    """
    results = get_first_word_results(db, message, text, conf)
    if not results:
        if repeat >= 1:
            return choice(list(db.keys()))
        # NOTE(review): scoring is deterministic, so this retry recomputes
        # the same empty result before falling back; kept for compatibility.
        return get_first_word(db, message, text, conf, repeat + 1)
    ranked = sorted(results, key=lambda token: token.score, reverse=True)
    # One weight per rank, so the number of ranks considered is bounded by
    # the weight table and the two lists can never differ in length.
    top = ranked[:len(WEIGHTS_MAP_SOFT)]
    words = [token.word for token in top]
    weights = WEIGHTS_MAP_SOFT[:len(words)]
    # ``words`` is non-empty here, so choices(..., k=1) yields exactly one item.
    return choices(words, weights=weights, k=1)[0]