File size: 3,281 Bytes
a1ca2de
 
 
 
3332aa4
a1ca2de
3332aa4
936d161
a1ca2de
 
ceefdf5
 
936d161
ceefdf5
 
 
 
 
 
 
 
 
 
 
3332aa4
 
 
936d161
 
 
 
3332aa4
936d161
 
 
 
 
 
 
 
a1ca2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3332aa4
 
a1ca2de
 
3332aa4
a1ca2de
3332aa4
 
a1ca2de
936d161
 
3332aa4
 
 
936d161
3332aa4
 
936d161
3332aa4
936d161
 
a1ca2de
 
3332aa4
 
 
a1ca2de
936d161
a1ca2de
936d161
 
3332aa4
936d161
 
 
 
a1ca2de
 
 
 
 
 
 
 
3332aa4
 
 
 
a1ca2de
3332aa4
a1ca2de
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import copy
import json
import string
import random
import asyncio

from modules.llms import get_llm_factory

from pingpong.context import CtxLastWindowStrategy

def add_side_character_to_export(characters, enable, img, name, age, personality, job):
	"""Append an exportable side-character entry to ``characters`` when enabled.

	Only ``img`` and ``name`` are stored. ``age``, ``personality`` and ``job``
	are accepted but not used here. Presumably they exist so the signature
	lines up with the caller's argument list; verify against the caller.

	Args:
		characters: List that is mutated in place.
		enable: Truthy to append an entry; falsy to leave the list untouched.
		img: Value stored under the ``'img'`` key.
		name: Value stored under the ``'name'`` key.

	Returns:
		The same ``characters`` list object that was passed in.
	"""
	if not enable:
		return characters

	entry = {'img': img, 'name': name}
	characters.append(entry)
	return characters

def add_side_character(enable, name, age, personality, job, llm_type="PaLM"):
	"""Render the side-character section of the story-generation prompt.

	``enable``, ``name``, ``age``, ``personality`` and ``job`` are parallel
	sequences read at the same position. For every truthy ``enable`` entry, the
	``prompts['story_gen']['add_side_character']`` template is formatted. Its
	``cur_side_chars`` value is a 1-based counter over the enabled entries only.

	Returns:
		The concatenated sections prefixed with a newline, or ``""`` when no
		side character is enabled.
	"""
	prompts = get_llm_factory(llm_type).create_prompt_manager().prompts

	sections = []
	for pos, is_enabled in enumerate(enable):
		if not is_enabled:
			continue
		# The template is looked up only when needed, so a missing key matters
		# only if at least one side character is enabled.
		sections.append(
			prompts['story_gen']['add_side_character'].format(
				cur_side_chars=len(sections) + 1,
				name=name[pos],
				job=job[pos],
				age=age[pos],
				personality=personality[pos],
			)
		)

	rendered = "".join(sections)
	if not rendered:
		return ""
	return "\n" + rendered

def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
	"""Return a string of ``size`` characters, each picked via ``random.choice``.

	By default the alphabet is upper-case ASCII letters and digits.
	NOTE(review): the ``random`` module is not cryptographically secure. If
	these IDs must be unguessable, ``secrets.choice`` would be the safer source.
	"""
	picked = [random.choice(chars) for _ in range(size)]
	return "".join(picked)

def parse_first_json_code_snippet(code_snippet):
	"""Decode JSON from model output, falling back to the first ```json fence.

	The whole input is tried as JSON first. If that fails, the text between the
	first ``` ```json ``` marker and the next ``` ``` ``` fence is decoded
	instead. Both attempts use ``strict=False``, so literal control characters
	(such as raw newlines) inside JSON strings are accepted.

	Args:
		code_snippet: Raw text returned by the model.

	Returns:
		The decoded JSON value. It is never ``None``.

	Raises:
		ValueError: if nothing decodes as JSON, or if the decoded value is
			JSON ``null``.
	"""
	not_found = 'No JSON code snippet found in string.'
	fence_open = '```json'
	fence_close = '```'

	try:
		parsed = json.loads(code_snippet, strict=False)
	except (ValueError, TypeError) as err:
		# JSONDecodeError and UnicodeDecodeError subclass ValueError; TypeError
		# covers unsupported inputs such as None.
		if not isinstance(code_snippet, str):
			raise ValueError(not_found) from err

		start = code_snippet.find(fence_open)
		if start < 0:
			raise ValueError(not_found) from err
		body_start = start + len(fence_open)
		end = code_snippet.find(fence_close, body_start)
		if end < 0:
			raise ValueError(not_found) from err

		try:
			parsed = json.loads(code_snippet[body_start:end], strict=False)
		except ValueError as inner_err:
			raise ValueError(not_found) from inner_err

	# JSON ``null`` is reported as a failure so callers never receive None.
	if parsed is None:
		raise ValueError(not_found)
	return parsed

async def retry_until_valid_json(prompt, parameters=None, llm_type="PaLM", max_retries=3, timeout=10):
	"""Query the LLM until its reply contains non-empty, parseable JSON.

	Each attempt calls ``llm_service.gen_text(prompt, mode="text", ...)``. The
	raw text is printed and then parsed with ``parse_first_json_code_snippet``.
	The loop stops at the first attempt that yields a truthy JSON value.

	Args:
		prompt: Prompt passed to ``gen_text``.
		parameters: Generation parameters forwarded unchanged to ``gen_text``.
		llm_type: Backend key passed to ``get_llm_factory``.
		max_retries: Maximum number of generation attempts.
		timeout: Seconds allowed for each individual ``gen_text`` call.

	Returns:
		The parsed JSON value from the first successful attempt.

	Raises:
		TimeoutError: if a single call exceeds ``timeout``. This is not retried.
		ValueError: if the last response reports content filters, or if no
			attempt produced usable JSON (including when every call errored).
	"""
	factory = get_llm_factory(llm_type)
	llm_service = factory.create_llm_service()

	# Initialised up front so the post-loop checks work even when every
	# attempt raised before assigning them.
	response = None
	response_txt = None
	response_json = None

	for _ in range(max_retries):
		try:
			response, response_txt = await asyncio.wait_for(
				llm_service.gen_text(prompt, mode="text", parameters=parameters),
				timeout=timeout,
			)
			print(response_txt)
		except asyncio.TimeoutError as err:
			raise TimeoutError(f"The response time for {llm_type} API exceeded the limit.") from err
		except Exception:
			print(f"{llm_type} API has encountered an error. Retrying...")
			continue

		try:
			parsed = parse_first_json_code_snippet(response_txt)
		except Exception:
			# Best-effort retry boundary: any parse failure just triggers another attempt.
			parsed = None
		if not parsed:
			print("Parsing JSON failed. Retrying...")
			continue

		response_json = parsed
		break

	# NOTE(review): ``filters`` is read via getattr because not every backend's
	# response is confirmed to expose it; an empty or missing value means unfiltered.
	if getattr(response, "filters", None):
		raise ValueError(f"{llm_type} API has withheld a response due to content safety concerns.")
	if response_json is None:
		print("=== Failed to generate valid JSON response. ===")
		print(response_txt)
		raise ValueError("Failed to generate valid JSON response.")

	return response_json

def build_prompts(ppm, win_size=3):
	"""Apply ``CtxLastWindowStrategy(win_size)`` to a deep copy of ``ppm``.

	The strategy receives only the copy, so the caller's ``ppm`` object is
	never handed to it. Whatever the strategy returns is returned unchanged.
	"""
	snapshot = copy.deepcopy(ppm)
	window_strategy = CtxLastWindowStrategy(win_size)
	built = window_strategy(snapshot)
	return built

async def get_chat_response(prompt, ctx=None, llm_type="PaLM"):
	"""Generate a chat reply for ``prompt`` and return only its text.

	Sampling parameters are built with ``make_params(mode="chat",
	temperature=1.0, top_k=50, top_p=0.9)``. ``gen_text`` is expected to return a
	2-item result whose second item is the text.

	NOTE(review): ``ctx`` is accepted but not used here. Also, ``gen_text`` is
	called without an explicit ``mode=`` argument; confirm its default agrees
	with the chat-mode parameters.
	"""
	service = get_llm_factory(llm_type).create_llm_service()
	chat_params = service.make_params(
		mode="chat",
		temperature=1.0,
		top_k=50,
		top_p=0.9,
	)

	_raw_response, reply_text = await service.gen_text(prompt, parameters=chat_params)
	return reply_text