# run_stok.py
import json
import random
def strip_prompt(prompt):  # normalize the prompt so its tokens are more likely to match the model data
    newprompt = str(prompt).lower()
    # strip punctuation that would otherwise stay attached to tokens
    for char in ".[]:,\"'();-_{}?!":
        newprompt = newprompt.replace(char, "")
    newprompt = " ".join(newprompt.split(sep=None))  # collapse runs of whitespace
    return newprompt
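# For example (illustrative inputs, not taken from any model file):
#   strip_prompt("Hello, world!")   ->  "hello world"
#   strip_prompt("What's 2 + 2?")   ->  "whats 2 + 2"   (note that "+" is kept)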
def strip_text(prompt):  # lighter variant: lowercase and collapse whitespace, but keep punctuation
    newprompt = str(prompt).lower()
    newprompt = " ".join(newprompt.split(sep=None))
    return newprompt
model = {"model_data": {}}
def load_model(filename: str):
model["model_data"] = json.loads(open(filename, "r").read())
def symbolize_prompt(prompt):  # checks if the prompt can be contextualized via a symbol (currently only math)
    symbols = ["+", "-", "/", "*", "x", "X"]
    numbers = [str(x) for x in range(10)]
    prompt = "".join(prompt.split(sep=None))  # remove whitespace
    for symbol in symbols:
        if symbol in prompt:
            listed_prompt = list(prompt)
            sym_index = listed_prompt.index(symbol)
            prompt_left = []
            prompt_right = []
            # walk right from the symbol, collecting digits and decimal points
            i = sym_index + 1
            while i < len(listed_prompt) and (listed_prompt[i] in numbers or listed_prompt[i] == "."):
                prompt_right.append(listed_prompt[i])
                i += 1
            # walk left from the symbol, stopping at the start of the string
            # instead of wrapping around to the end
            i = sym_index - 1
            while i >= 0 and (listed_prompt[i] in numbers or listed_prompt[i] == "."):
                prompt_left.append(listed_prompt[i])
                i -= 1
            prompt_left.reverse()  # the left-hand digits were collected right-to-left
            return f"{''.join(prompt_left)}{symbol}{''.join(prompt_right)}"
    return None
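# For example (illustrative inputs):
#   symbolize_prompt("what is 12+7")    ->  "12+7"
#   symbolize_prompt("tell me a joke")  ->  None   (no recognized symbol)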
def version_04_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 2):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    raw_outputs = model_data["raw_outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    topic = None
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    if topic is None:  # no topic matched: try the math path, otherwise use raw outputs
        save_prompt = symbolize_prompt(prompt)
        if save_prompt is not None:
            token_now = False
            for token in save_prompt.split(sep=None):
                if token in prompts:
                    token_now = True
                    break
            if token_now:
                # the symbolized prompt matches a known prompt token, so re-run inference on it
                for chunk in version_04_inference(prompt=save_prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty):
                    yield chunk
            else:
                outputs = raw_outputs
                topic = None
                start = split_prompt[-1]
                tokens_generated += 1
                running = True
                current_token = [start]
                while running:
                    token = current_token[0]
                    yield f"{token} "
                    if token in outputs:
                        next_token = max(outputs[token], key=outputs[token].get)
                        outputs[token][next_token] -= repetition_penalty
                    else:
                        next_token = random.choice(list(outputs.keys()))
                    current_token[0] = next_token
                    tokens_generated += 1
                    if max_tokens is not None:
                        if tokens_generated >= max_tokens:
                            running = False
                    if topic:
                        if token in ends[topic]:
                            running = False
        else:
            outputs = raw_outputs
            topic = None
            start = split_prompt[-1]
            tokens_generated += 1
            running = True
            current_token = [start]
            while running:
                token = current_token[0]
                yield f"{token} "
                if token in outputs:
                    next_token = max(outputs[token], key=outputs[token].get)
                    outputs[token][next_token] -= repetition_penalty
                else:
                    next_token = random.choice(list(outputs.keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None:
                    if tokens_generated >= max_tokens:
                        running = False
                if topic:
                    if token in ends[topic]:
                        running = False
    else:
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if outputs.get(topic) is not None:
                if token in outputs[topic]:
                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
                    outputs[topic][token][next_token] -= repetition_penalty
                else:
                    next_token = random.choice(list(outputs.keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None:
                    if tokens_generated >= max_tokens:
                        running = False
                if topic:
                    if token in ends[topic]:
                        running = False
            else:
                running = False  # single token responses seem to break things
def version_03_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 2):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    raw_outputs = model_data["raw_outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    topic = None
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    if topic is None:  # use raw outputs
        outputs = raw_outputs
        topic = None
        start = split_prompt[-1]
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if token in outputs:
                next_token = max(outputs[token], key=outputs[token].get)
                outputs[token][next_token] -= repetition_penalty
            else:
                next_token = random.choice(list(outputs.keys()))
            current_token[0] = next_token
            tokens_generated += 1
            if max_tokens is not None:
                if tokens_generated >= max_tokens:
                    running = False
            if topic:
                if token in ends[topic]:
                    running = False
    else:
        tokens_generated += 1
        running = True
        current_token = [start]
        while running:
            token = current_token[0]
            yield f"{token} "
            if outputs.get(topic) is not None:
                if token in outputs[topic]:
                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
                    outputs[topic][token][next_token] -= repetition_penalty
                else:
                    next_token = random.choice(list(outputs.keys()))
                current_token[0] = next_token
                tokens_generated += 1
                if max_tokens is not None:
                    if tokens_generated >= max_tokens:
                        running = False
                if topic:
                    if token in ends[topic]:
                        running = False
            else:
                running = False  # this is because single token responses seem to break things
def version_02_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 1):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    ends = model_data["ends"]
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
            topic = token
            break
    else:
        topic = random.choice(list(ends))
        start = random.choice(list(prompts.keys()))
    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None:
            if tokens_generated >= max_tokens:
                running = False
        if topic:
            if token in ends[topic]:
                running = False
def version_01_inference(prompt: str, max_tokens: int = None, repetition_penalty: int = 1):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)
    model_data = model["model_data"]
    outputs = model_data["outputs"]
    prompts = model_data["prompts"]
    start = ""
    for token in split_prompt:
        if token in prompts:
            start = max(prompts[token], key=prompts[token].get)
    tokens_generated += 1
    running = True
    current_token = [start]
    while running:
        token = current_token[0]
        yield f"{token} "
        if token in outputs:
            next_token = max(outputs[token], key=outputs[token].get)
            outputs[token][next_token] -= repetition_penalty
        else:
            next_token = random.choice(list(outputs.keys()))
        current_token[0] = next_token
        tokens_generated += 1
        if max_tokens is not None:
            if tokens_generated >= max_tokens:
                running = False
def run_model(prompt: str, max_tokens: int = None, repetition_penalty: int = 1, temperature: float = 0):
    # (temperature is unused here; it does not work on versions below 0.5)
    model_data = model["model_data"]
    model_format = model_data["format"]
    if model_format == "v0.1":
        response = version_01_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
    elif model_format == "v0.2":
        response = version_02_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
    elif model_format == "v0.3":
        response = version_03_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
    elif model_format == "v0.4":
        response = version_04_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
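# A minimal usage sketch, assuming a stok model file is available locally; the
# filename "model.json" and the prompt are placeholders, not files or data that
# ship with this script:
if __name__ == "__main__":
    load_model("model.json")  # hypothetical path to a stok model file
    for chunk in run_model("hello there", max_tokens=50, repetition_penalty=2):
        print(chunk, end="", flush=True)  # run_model yields tokens with trailing spaces
    print()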