Spaces:
Runtime error
Runtime error
#!python | |
# -*- coding: utf-8 -*- | |
# @author: Kun | |
import torch | |
import random | |
from sentence_transformers import util | |
from utils import get_content_between_a_b, get_api_response | |
from prompts.chatgpt_query import get_input_text | |
from global_config import lang_opt | |
class RecurrentGPT: | |
def __init__(self, input, short_memory, long_memory, memory_index, embedder): | |
print("AIWriter loaded by RecurrentGPT") | |
self.input = input | |
self.short_memory = short_memory | |
self.long_memory = long_memory | |
self.embedder = embedder | |
if self.long_memory and not memory_index: | |
self.memory_index = self.embedder.encode( | |
self.long_memory, convert_to_tensor=True) | |
self.output = {} | |
def prepare_input(self, new_character_prob=0.1, top_k=2): | |
input_paragraph = self.input["output_paragraph"] | |
input_instruction = self.input["output_instruction"] | |
instruction_embedding = self.embedder.encode( | |
input_instruction, convert_to_tensor=True) | |
# get the top 3 most similar paragraphs from memory | |
memory_scores = util.cos_sim( | |
instruction_embedding, self.memory_index)[0] | |
top_k_idx = torch.topk(memory_scores, k=top_k)[1] | |
top_k_memory = [self.long_memory[idx] for idx in top_k_idx] | |
# combine the top 3 paragraphs | |
input_long_term_memory = '\n'.join( | |
[f"Related Paragraphs {i+1} :" + selected_memory for i, selected_memory in enumerate(top_k_memory)]) | |
# randomly decide if a new character should be introduced | |
if random.random() < new_character_prob: | |
new_character_prompt = f"If it is reasonable, you can introduce a new character in the output paragrah and add it into the memory." | |
else: | |
new_character_prompt = "" | |
input_text = get_input_text(lang_opt, self.short_memory, input_paragraph, input_instruction, input_long_term_memory, new_character_prompt) | |
return input_text | |
def parse_output(self, output): | |
try: | |
output_paragraph = get_content_between_a_b( | |
'Output Paragraph:', 'Output Memory', output) | |
output_memory_updated = get_content_between_a_b( | |
'Updated Memory:', 'Output Instruction:', output) | |
self.short_memory = output_memory_updated | |
ins_1 = get_content_between_a_b( | |
'Instruction 1:', 'Instruction 2', output) | |
ins_2 = get_content_between_a_b( | |
'Instruction 2:', 'Instruction 3', output) | |
lines = output.splitlines() | |
# content of Instruction 3 may be in the same line with I3 or in the next line | |
if lines[-1] != '\n' and lines[-1].startswith('Instruction 3'): | |
ins_3 = lines[-1][len("Instruction 3:"):] | |
elif lines[-1] != '\n': | |
ins_3 = lines[-1] | |
output_instructions = [ins_1, ins_2, ins_3] | |
assert len(output_instructions) == 3 | |
output = { | |
"input_paragraph": self.input["output_paragraph"], | |
"output_memory": output_memory_updated, # feed to human | |
"output_paragraph": output_paragraph, | |
"output_instruction": [instruction.strip() for instruction in output_instructions] | |
} | |
return output | |
except: | |
return None | |
def step(self, response_file=None): | |
prompt = self.prepare_input() | |
print(prompt+'\n'+'\n') | |
response = get_api_response(prompt) | |
self.output = self.parse_output(response) | |
while self.output == None: | |
response = get_api_response(prompt) | |
self.output = self.parse_output(response) | |
if response_file: | |
with open(response_file, 'a', encoding='utf-8') as f: | |
f.write(f"Writer's output here:\n{response}\n\n") | |
self.long_memory.append(self.input["output_paragraph"]) | |
self.memory_index = self.embedder.encode( | |
self.long_memory, convert_to_tensor=True) | |