Cat0125
add train tab, improve quality
8e637c7
raw
history blame contribute delete
No virus
2.85 kB
from random import choice, choices
# Rank-indexed sampling weights: entry 0 belongs to the best-scored candidate,
# entry 1 to the runner-up, and so on (five ranks in total).

# Steep profile that strongly favours the top-ranked candidate.
# NOTE(review): not referenced by the functions visible in this module;
# presumably used elsewhere - confirm before removing.
WEIGHTS_MAP_HARD = [100, 10, 8, 6, 2]

# Flatter profile; get_first_word() uses it to weight its top candidates.
WEIGHTS_MAP_SOFT = [80, 30, 10, 7, 2]
def get_next_word_results(db, message, prev_word, text, _):
    """Score the tokens stored under ``prev_word`` and return the useful ones.

    Scoring per token:
      * +2 for every entry of ``token.contexts`` found in ``message``;
      * +1 for every entry of ``token.contexts`` found in ``text``;
      * +10 when the token's word contains ")" while ``text`` holds more
        "(" than ")" (i.e. a parenthesis is still open).

    Side effect: ``token.score`` is overwritten on every token examined.

    Args:
        db: Mapping from a previous word to an iterable of token objects
            exposing ``word``, ``contexts`` and a writable ``score``.
        message: Container tested with ``in`` against each context.
        prev_word: Key to look up in ``db``.
        text: Text generated so far; tested with ``in`` and ``count``.
        _: Ignored.

    Returns:
        A list of the tokens whose score is greater than zero, in the order
        they appear in ``db[prev_word]``; empty if ``prev_word`` is unknown.
    """
    if prev_word not in db:
        return []

    scored = []
    for candidate in db[prev_word]:
        points = 0
        for ctx in candidate.contexts:
            points += 2 if ctx in message else 0
            points += 1 if ctx in text else 0

        # Favour a closing parenthesis while one is still open in the text.
        if ")" in candidate.word and text.count("(") > text.count(")"):
            points += 10

        candidate.score = points
        if points > 0:
            scored.append(candidate)
    return scored
def get_next_word(db, message, prevword, text, conf, repeat=0):
    """Pick the word that should follow ``prevword``.

    If ``prevword`` is empty or contains ".", "?" or "!", a new sentence is
    started by delegating to ``get_first_word``. Otherwise the candidates
    from ``get_next_word_results`` are ranked and one of the words sharing
    the highest score is chosen uniformly at random; at most the first five
    such words (in candidate order) take part in the draw. When no candidate
    scores, a random key of ``db`` is returned instead.

    Args:
        db: Mapping from a previous word to its candidate tokens.
        message: Forwarded to the scoring helpers.
        prevword: The word generated immediately before this one.
        text: The text generated so far.
        conf: Forwarded to the scoring helpers.
        repeat: Retained for backward compatibility and forwarded to
            ``get_first_word``. Scoring is deterministic, so re-running it
            cannot change the outcome and no retry is performed here.

    Returns:
        The selected word (or a key of ``db`` for the fallback case).

    Raises:
        IndexError: If the fallback is needed and ``db`` has no keys.
    """
    # An empty or sentence-terminating previous word begins a new sentence.
    if prevword == '' or any(mark in prevword for mark in '.?!'):
        return get_first_word(db, message, text, conf, repeat)

    results = get_next_word_results(db, message, prevword, text, conf)
    if not results:
        # Nothing relevant follows prevword: fall back to any known key.
        return choice(list(db.keys()))

    # Keep only the words tied for the best score. Taking the first five in
    # candidate order matches a stable descending sort followed by a top-5 cut.
    best_score = max(token.score for token in results)
    tied_words = [token.word for token in results if token.score == best_score]
    return choice(tied_words[:5])
def get_first_word_results(db, message, text, _):
    """Score the sentence-opening tokens stored under the ``''`` key.

    Scoring per token:
      * +2 for every entry of ``token.contexts`` found in ``message``;
      * +1 for every entry of ``token.contexts`` found in ``text``;
      * +15 when ``token.starter`` is truthy.

    Side effect: ``token.score`` is overwritten on every token examined.

    Args:
        db: Mapping whose ``''`` entry holds an iterable of token objects
            exposing ``contexts``, ``starter`` and a writable ``score``.
        message: Container tested with ``in`` against each context.
        text: Text generated so far; tested with ``in``.
        _: Ignored.

    Returns:
        A list of the tokens whose score is greater than zero, in the order
        they appear in ``db['']``; empty if ``db`` has no ``''`` key.
    """
    if '' not in db:
        return []

    scored = []
    for candidate in db['']:
        points = 0
        for ctx in candidate.contexts:
            points += 2 if ctx in message else 0
            points += 1 if ctx in text else 0

        # Tokens flagged as starters get a large head start.
        if candidate.starter:
            points += 15

        candidate.score = points
        if points > 0:
            scored.append(candidate)
    return scored
def get_first_word(db, message, text, conf, repeat=0):
    """Pick a word to open a new sentence.

    The candidates from ``get_first_word_results`` are sorted by descending
    score and one of the top-ranked words is drawn at random, weighted by
    its rank via ``WEIGHTS_MAP_SOFT`` (one weight per rank, so the number of
    words in the draw is capped by the table's length). When no candidate
    scores, a random key of ``db`` is returned instead.

    Args:
        db: Mapping whose ``''`` entry holds the sentence-opening tokens.
        message: Forwarded to the scoring helper.
        text: The text generated so far.
        conf: Forwarded to the scoring helper.
        repeat: Retained for backward compatibility only. Scoring is
            deterministic, so re-running it cannot change the outcome and
            no retry is performed.

    Returns:
        The selected word (or a key of ``db`` for the fallback case).

    Raises:
        IndexError: If the fallback is needed and ``db`` has no keys.
    """
    results = get_first_word_results(db, message, text, conf)
    if not results:
        # No scored opener available: fall back to any known key.
        return choice(list(db.keys()))

    ranked = sorted(results, key=lambda token: token.score, reverse=True)

    # zip() stops at the shorter input, pairing each rank with its weight.
    words = []
    weights = []
    for token, weight in zip(ranked, WEIGHTS_MAP_SOFT):
        words.append(token.word)
        weights.append(weight)

    # choices() with k=1 always yields exactly one element.
    return choices(words, weights=weights, k=1)[0]