Spaces:
Running
Running
from random import random | |
import gradio as gr | |
from datamanager import get_data_v3, models | |
from generation.words import get_next_word | |
def find_model(model_name):
    """Return the data for the model whose display name is ``model_name``.

    ``models`` is scanned in iteration order. The key of the first entry
    whose ``'name'`` field equals ``model_name`` is passed to
    ``get_data_v3``, and that call's result is returned.

    Raises:
        ValueError: if no entry in ``models`` has that name.
    """
    for model_key, entry in models.items():
        if entry['name'] == model_name:
            break
    else:
        # Loop finished without a match.
        raise ValueError('Model %s not found' % model_name)
    return get_data_v3(model_key)
def generate(user_message, word_count, model_name, stop_chance):
    """Generate fresh text from the selected model, yielding progressively.

    This is a generator: after each new word it yields the text produced so
    far. A caller (presumably a Gradio event handler) can then stream the
    output.

    Args:
        user_message: Prompt text. It is lower-cased and stripped, then
            passed to ``get_next_word`` as context.
        word_count: Target number of whitespace-separated words. It must be
            between 0 and 300 inclusive.
        model_name: Display name of the model, resolved via ``find_model``.
        stop_chance: Compared against ``random()`` (in [0, 1)) after any word
            containing ``'.'``. When ``random()`` is smaller, generation stops
            early.

    Yields:
        The accumulated text after each word, including a trailing space.
        On an early stop, the final yield is the stripped text.

    Raises:
        ValueError: propagated from ``find_model`` if ``model_name`` is
            unknown.
    """
    # Validate before resolving the model, so bad input does not trigger a
    # model-data load through find_model/get_data_v3.
    if word_count < 0 or word_count > 300:
        # gr.Warning is called for its side effect (a UI warning in Gradio).
        # A generator's return value is discarded, so nothing is returned.
        gr.Warning("Invalid word count. It must be between 0 and 300.")
        return
    db = find_model(model_name)
    message = user_message.lower().strip()
    text = ""
    curword = ""
    while len(text.split()) < word_count:
        prevword = curword
        curword = get_next_word(db, message, prevword, text, {})
        # The loop only terminates if every step adds at least one word.
        # An empty or whitespace-only result would otherwise spin forever.
        if not curword or curword.isspace():
            break
        text += curword + " "
        if '.' in curword and random() < stop_chance:
            yield text.strip()
            break
        yield text
def cont(user_message, word_count, model_name):
    """Continue the user's message with the selected model, yielding progressively.

    This is a generator: after each new word it yields the message followed
    by the words generated so far, stripped. If the message already has at
    least ``word_count`` words, nothing is yielded.

    Args:
        user_message: Text to continue. It is lower-cased and stripped, and
            must be non-empty after stripping.
        word_count: Target total number of whitespace-separated words,
            counting the message's own words. It must be between 0 and 450
            inclusive.
        model_name: Display name of the model, resolved via ``find_model``.

    Yields:
        The message plus the generated words so far, stripped.

    Raises:
        gr.Error: if ``word_count`` is outside 0..450.
        ValueError: propagated from ``find_model`` if ``model_name`` is
            unknown.
    """
    message = user_message.lower().strip()
    # Validate before resolving the model, so bad input does not trigger a
    # model-data load through find_model/get_data_v3.
    if not message:
        # Called for its side effect (a UI warning in Gradio). A generator's
        # return value is discarded, so nothing is returned.
        gr.Warning('No message')
        return
    if word_count < 0 or word_count > 450:
        raise gr.Error("Invalid word count. It must be between 0 and 450.")
    db = find_model(model_name)
    # Seed the context with the message; its last word is the first
    # "previous word" handed to get_next_word.
    curword = message.split()[-1]
    text = message + " "
    while len(text.split()) < word_count:
        prevword = curword
        curword = get_next_word(db, message, prevword, text, {})
        # The loop only terminates if every step adds at least one word.
        # An empty or whitespace-only result would otherwise spin forever.
        if not curword or curword.isspace():
            break
        text += curword + " "
        yield text.strip()