import re
import gradio as gr
from interfaces import utils
from modules import get_llm_factory

async def plot_gen(
	temperature,
	genre, place, mood,
	side_char_enable1, side_char_enable2, side_char_enable3,
	main_char_name, main_char_age, main_char_personality, main_char_job,
	side_char_name1, side_char_age1, side_char_personality1, side_char_job1,
	side_char_name2, side_char_age2, side_char_personality2, side_char_job2,
	side_char_name3, side_char_age3, side_char_personality3, side_char_job3,
	llm_type="PaLM"
):
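	"""Generate a story plot outline with the selected LLM.

	Returns the title (raw and as a markdown heading) followed by the five
	outline stages parsed from the model's JSON response.
	"""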
	factory = get_llm_factory(llm_type)
	prompts = factory.create_prompt_manager().prompts
	llm_service = factory.create_llm_service()

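	# Build the side-character portion of the prompt from the enabled slots only.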
	side_char_prompt = utils.add_side_character(
		[side_char_enable1, side_char_enable2, side_char_enable3],
		[side_char_name1, side_char_name2, side_char_name3],
		[side_char_job1, side_char_job2, side_char_job3],
		[side_char_age1, side_char_age2, side_char_age3],
		[side_char_personality1, side_char_personality2, side_char_personality3],
	)
	prompt = prompts['plot_gen']['main_plot_gen'].format(
		genre=genre, place=place, mood=mood,
		main_char_name=main_char_name,
		main_char_job=main_char_job,
		main_char_age=main_char_age,
		main_char_personality=main_char_personality,
		side_char_placeholder=side_char_prompt,
	)
	
	print(f"generated prompt:\n{prompt}")
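	# Sample with the caller-supplied temperature and retry until the model returns parseable JSON.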
	parameters = llm_service.make_params(mode="text", temperature=temperature, top_k=40, top_p=1.0, max_output_tokens=4096)
	response_json = await utils.retry_until_valid_json(prompt, parameters=parameters)

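	# Map the parsed JSON onto the Gradio outputs: title, markdown title, and the five outline stages.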
	return (
		response_json['title'],
		f"## {response_json['title']}",
		response_json['outline']['rising action'],
		response_json['outline']['crisis'],
		response_json['outline']['climax'],
		response_json['outline']['falling action'],
		response_json['outline']['denouement'],
	)


async def first_story_gen(
	title, 
	rising_action, crisis, climax, falling_action, denouement,
	genre, place, mood,
	side_char_enable1, side_char_enable2, side_char_enable3,
	main_char_name, main_char_age, main_char_personality, main_char_job,
	side_char_name1, side_char_age1, side_char_personality1, side_char_job1,
	side_char_name2, side_char_age2, side_char_personality2, side_char_job2,
	side_char_name3, side_char_age3, side_char_personality3, side_char_job3,
	cursors, cur_cursor,
	llm_type="PaLM"
):
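	"""Generate the opening ("rising action") chapter from the plot outline.

	Appends the new chapter to `cursors` and returns the chapter heading,
	story text, updated state, and Gradio component updates, including the
	three next-action suggestions from the model.
	"""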
	factory = get_llm_factory(llm_type)
	prompts = factory.create_prompt_manager().prompts
	llm_service = factory.create_llm_service()

	side_char_prompt = utils.add_side_character(
		[side_char_enable1, side_char_enable2, side_char_enable3],
		[side_char_name1, side_char_name2, side_char_name3],
		[side_char_job1, side_char_job2, side_char_job3],
		[side_char_age1, side_char_age2, side_char_age3],
		[side_char_personality1, side_char_personality2, side_char_personality3],
	)
	prompt = prompts['plot_gen']['first_story_gen'].format(
		genre=genre, place=place, mood=mood,
		main_char_name=main_char_name,
		main_char_job=main_char_job,
		main_char_age=main_char_age,
		main_char_personality=main_char_personality,
		side_char_placeholder=side_char_prompt,
		title=title,
		rising_action=rising_action,
		crisis=crisis,
		climax=climax,
		falling_action=falling_action,
		denouement=denouement,
	)

	print(f"generated prompt:\n{prompt}")
	parameters = llm_service.make_params(mode="text", temperature=1.0, top_k=40, top_p=1.0, max_output_tokens=4096)
	response_json = await utils.retry_until_valid_json(prompt, parameters=parameters)

	chapter_title = response_json["chapter_title"]
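	# Drop any leading "Chapter N:" / "Chapter N." prefix the model may prepend.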
	pattern = r"Chapter\s+\d+\s*[:.]"
	chapter_title = re.sub(pattern, "", chapter_title).strip()

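	# Record the generated chapter in the running story state.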
	story = "\n\n".join(response_json["paragraphs"])
	cursors.append({
		"title": chapter_title,
		"plot_type": "rising action",
		"story": story
	})

	return (
		f"### {chapter_title} (\"rising action\")",
		story,
		cursors,
		cur_cursor,
		gr.update(interactive=True),
		gr.update(interactive=True),
		gr.update(value=response_json["actions"][0], interactive=True),
		gr.update(value=response_json["actions"][1], interactive=True),
		gr.update(value=response_json["actions"][2], interactive=True),
	)