Cat0125
add train tab, improve quality
8e637c7
from random import random
import gradio as gr
from datamanager import get_data_v3, models
from generation.words import get_next_word
def find_model(model_name):
    """Load the data of the registered model whose display name is *model_name*.

    The ``models`` registry (imported from ``datamanager``) is scanned in
    iteration order for the first entry whose ``'name'`` field equals
    *model_name*; that entry's key is passed to ``get_data_v3`` and its result
    returned.

    Raises:
        ValueError: if no entry in ``models`` carries that name.
    """
    # Local sentinel so that any key value (even None) counts as "found".
    missing = object()
    found_key = next(
        (key for key, entry in models.items() if entry['name'] == model_name),
        missing,
    )
    if found_key is missing:
        raise ValueError('Model %s not found' % model_name)
    return get_data_v3(found_key)
def generate(user_message, word_count, model_name, stop_chance):
    """Stream freshly generated text, word by word, from the named model.

    This is a generator: after each new word it yields the whole text produced
    so far, so a streaming output component can update progressively.

    Args:
        user_message: Prompt text. It is lower-cased and stripped, then passed
            to ``get_next_word`` as context; it is not copied into the output.
        word_count: Target number of whitespace-separated words; must lie in
            [0, 300], otherwise a warning is issued and nothing is yielded.
        model_name: Display name of the model, resolved via ``find_model``.
        stop_chance: Probability (compared against ``random()``) of stopping
            early right after a word that contains ``'.'``.

    Yields:
        The accumulated generated text after each word, stripped of
        surrounding whitespace.

    Raises:
        ValueError: propagated from ``find_model`` for an unknown model name.
    """
    # Validate before loading the model so bad input doesn't pay for a load.
    if word_count < 0 or word_count > 300:
        # gr.Warning is called for its UI side effect; a generator's return
        # value is discarded, so just end the stream without output.
        gr.Warning("Invalid word count. It must be between 0 and 300.")
        return

    db = find_model(model_name)
    message = user_message.lower().strip()

    text = ""
    curword = ""
    # Running count of whitespace-separated words in `text`; equals
    # len(text.split()) because every word is followed by a space separator,
    # but avoids re-tokenizing the whole text on each iteration.
    words_so_far = 0
    while words_so_far < word_count:
        prevword = curword
        curword = get_next_word(db, message, prevword, text, {})
        new_words = len(curword.split())
        if new_words == 0:
            # NOTE(review): an empty/whitespace-only word never advances the
            # count, so the loop would spin forever; presumably it means the
            # model has no continuation — stop here.
            break
        text += curword + " "
        words_so_far += new_words
        # Decide on early stop (same random() draw as before) before yielding.
        stop = '.' in curword and random() < stop_chance
        yield text.strip()
        if stop:
            break
def cont(user_message, word_count, model_name):
    """Continue the user's text with the named model, streaming the result.

    This is a generator: after each new word it yields the prompt plus the
    continuation produced so far.

    Args:
        user_message: Text to continue. It is lower-cased and stripped; its
            words count toward *word_count* and it stays at the start of the
            output. If it is empty, a warning is issued and nothing is yielded.
        word_count: Target total number of whitespace-separated words
            (prompt + continuation); must lie in [0, 450].
        model_name: Display name of the model, resolved via ``find_model``.

    Yields:
        The prompt plus continuation so far, stripped of surrounding
        whitespace. Nothing is yielded if the prompt already has at least
        *word_count* words.

    Raises:
        gr.Error: if *word_count* is outside [0, 450].
        ValueError: propagated from ``find_model`` for an unknown model name.
    """
    message = user_message.lower().strip()
    # Cheap input checks first, so invalid input doesn't pay for a model load.
    if not message:
        # gr.Warning is called for its UI side effect; a generator's return
        # value is discarded, so just end the stream without output.
        gr.Warning('No message')
        return
    if word_count < 0 or word_count > 450:
        raise gr.Error("Invalid word count. It must be between 0 and 450.")

    db = find_model(model_name)

    text = message
    # `message` is non-empty after strip(), so it has at least one word.
    curword = text.split()[-1]
    # Running word count; avoids re-splitting the whole text every iteration.
    words_so_far = len(text.split())
    text += " "
    while words_so_far < word_count:
        prevword = curword
        curword = get_next_word(db, message, prevword, text, {})
        new_words = len(curword.split())
        if new_words == 0:
            # NOTE(review): an empty/whitespace-only word never advances the
            # count, so the loop would spin forever; presumably it means the
            # model has no continuation — stop here.
            break
        text += curword + " "
        words_so_far += new_words
        yield text.strip()