import json
import random


def strip_prompt(prompt):
    # Lowercase the prompt and strip punctuation so its tokens match the
    # model's whitespace-separated vocabulary.
    newprompt = str(prompt).lower()
    for char in ".[]:,\"'();-_{}?!":
        newprompt = newprompt.replace(char, "")
    newprompt = " ".join(newprompt.split(sep=None))
    return newprompt


def strip_text(prompt):
    # Lowercase and collapse whitespace without removing punctuation.
    newprompt = str(prompt).lower()
    newprompt = " ".join(newprompt.split(sep=None))
    return newprompt


# Loaded model data lives in a module-level dict so the inference functions
# below can all read from it.
model = {"model_data": {}}


def load_model(filename: str):
    with open(filename, "r") as model_file:
        model["model_data"] = json.load(model_file)


def symbolize_prompt(prompt):
    # Try to pull a compact arithmetic expression such as "12+34" out of a
    # free-form prompt. Returns the expression, or None if no operator is found.
    symbols = ["+", "-", "/", "*", "x", "X"]
    numbers = [str(x) for x in range(0, 10)]
    prompt_left = []
    prompt_right = []
    prompt = "".join(prompt.split(sep=None))
    for symbol in symbols:
        if symbol in prompt:
            listed_prompt = list(prompt)
            sym_index = listed_prompt.index(symbol)
            # Collect digits (and decimal points) to the right of the operator.
            i = sym_index
            nochar = True
            while nochar:
                i += 1
                try:
                    if listed_prompt[i] in numbers or listed_prompt[i] == ".":
                        prompt_right.append(listed_prompt[i])
                    else:
                        nochar = False
                except IndexError:
                    nochar = False
            # Collect digits to the left of the operator, walking backwards.
            # Stop at index 0 so a negative index does not wrap around to the
            # end of the string.
            i = sym_index
            nochar = True
            while nochar:
                i -= 1
                if i < 0:
                    nochar = False
                elif listed_prompt[i] in numbers or listed_prompt[i] == ".":
                    prompt_left.append(listed_prompt[i])
                else:
                    nochar = False
            # The left-hand digits were gathered right-to-left, so restore
            # their original order before joining.
            prompt_left.reverse()
            new_prompt = f"{''.join(prompt_left)}{symbol}{''.join(prompt_right)}"
            return new_prompt
    return None


def version_04_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 2):
    # v0.4 models: topic-aware generation with a fallback that tries to extract
    # an arithmetic expression from the prompt before giving up on a topic.
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    raw_outputs = model_data["raw_outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    topic = None
    # Use the first prompt token the model knows as the topic, and that
    # topic's most likely starting token as the first output.
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    if topic is None:
        # No known topic: see whether the prompt contains an arithmetic
        # expression the model was trained on.
        save_prompt = symbolize_prompt(prompt)
        if save_prompt is not None:
            token_now = False
            for token in save_prompt.split(sep=None):
                if token in prompts:
                    token_now = True
                    break
            if token_now:
                # The symbolized prompt hits a known topic, so re-run inference on it.
                for chunk in version_04_inference(prompt=save_prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty):
                    yield chunk
            else:
                # Fall back to the topic-free table, starting from the last
                # prompt token. Without a topic there is no end-token check,
                # so generation only stops at max_tokens.
                outputs = raw_outputs
                topic = None
                start = split_prompt[-1]
                tokens_generated += 1
                running = True
                current_token = [start]
                while running:
                    token = current_token[0]
                    yield f"{token} "
                    if token in outputs:
                        next_token = max(outputs[token], key=outputs[token].get)
                        # Penalise the chosen continuation so it is less likely
                        # to be picked again.
                        outputs[token][next_token] -= repetition_penalty
                    else:
                        next_token = random.choice(list(outputs.keys()))
                    current_token[0] = next_token
                    tokens_generated += 1
                    if max_tokens is not None:
                        if tokens_generated >= max_tokens:
                            running = False
                    if topic:
                        if token in ends[topic]:
                            running = False
        else:
            # No topic and no arithmetic expression: same topic-free fallback.
            outputs = raw_outputs
            topic = None
            start = split_prompt[-1]
            tokens_generated += 1
            running = True
            current_token = [start]
            while running:
                token = current_token[0]
                yield f"{token} "
                if token in outputs:
                    next_token = max(outputs[token], key=outputs[token].get)
                    outputs[token][next_token] -= repetition_penalty
                else:
                    next_token = random.choice(list(outputs.keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None:
                    if tokens_generated >= max_tokens:
                        running = False
                if topic:
                    if token in ends[topic]:
                        running = False
    else:
        # Topic found: walk that topic's transition table until an end token,
        # max_tokens, or a missing table stops generation.
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if outputs.get(topic) is not None:
                if token in outputs[topic]:
                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
                    outputs[topic][token][next_token] -= repetition_penalty
                else:
                    next_token = random.choice(list(outputs.keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None:
                    if tokens_generated >= max_tokens:
                        running = False
                if topic:
                    if token in ends[topic]:
                        running = False
            else:
                running = False


def version_03_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 2):
    # v0.3 models: topic-aware generation; with no known topic, fall back to
    # the topic-free "raw_outputs" table.
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    raw_outputs = model_data["raw_outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    topic = None
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    if topic is None:
        # Topic-free fallback: start from the last prompt token and walk the
        # raw transition table; only max_tokens stops generation here.
        outputs = raw_outputs
        topic = None
        start = split_prompt[-1]
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if token in outputs:
                next_token = max(outputs[token], key=outputs[token].get)
                outputs[token][next_token] -= repetition_penalty
            else:
                next_token = random.choice(list(outputs.keys()))
            current_token[0] = next_token
            tokens_generated += 1
            if max_tokens is not None:
                if tokens_generated >= max_tokens:
                    running = False
            if topic:
                if token in ends[topic]:
                    running = False
    else:
        # Topic found: walk the topic's transition table until an end token,
        # max_tokens, or a missing table stops generation.
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if outputs.get(topic) is not None:
                if token in outputs[topic]:
                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
                    outputs[topic][token][next_token] -= repetition_penalty
                else:
                    next_token = random.choice(list(outputs.keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None:
                    if tokens_generated >= max_tokens:
                        running = False
                if topic:
                    if token in ends[topic]:
                        running = False
            else:
                running = False


def version_02_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 1):
    # v0.2 models: pick a topic from the prompt; if none matches, fall back to
    # a random topic and starting token (the for/else below runs only when the
    # loop finishes without a break).
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    else:
        topic = random.choice(list(ends))
        start = random.choice(list(prompts.keys()))
    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None:
            if tokens_generated >= max_tokens:
                running = False
        if topic:
            if token in ends[topic]:
                running = False


def version_01_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 1):
    # v0.1 models: no topics or end tokens. Pick a starting token from the
    # prompt (the last matching prompt token wins) and walk the output table
    # until max_tokens is reached.
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None:
            if tokens_generated >= max_tokens:
                running = False


def run_model(prompt: str, max_tokens: int = None, repetition_penalty: int = 1, temperature: float = 0):
    # Dispatch to the inference function matching the loaded model's format.
    # temperature is currently unused by every format.
    model_data = model["model_data"]
    model_format = model_data["format"]
    if model_format == "v0.1":
        response = version_01_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
    elif model_format == "v0.2":
        response = version_02_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
    elif model_format == "v0.3":
        response = version_03_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
    elif model_format == "v0.4":
        response = version_04_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
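

# A minimal usage sketch: load a model file and stream a completion. The
# filename and prompt below are placeholders, and a model file in one of the
# supported formats (v0.1-v0.4) is assumed to exist on disk.
if __name__ == "__main__":
    load_model("model.json")  # placeholder path, not a file shipped with this module
    for chunk in run_model("hello there", max_tokens=32, repetition_penalty=1):
        print(chunk, end="", flush=True)
    print()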