File size: 3,832 Bytes
6f00721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from langchain import OpenAI
from langchain.prompts import PromptTemplate, FewShotPromptTemplate
from langchain.chains import LLMChain

# Formats a single few-shot example as an "Input: ...\nOutput: ..." pair.
# Used as ``example_prompt`` by both few-shot prompts in this module.
EXAMPLES_PROMPT_TEMPLATE = PromptTemplate(
    template="Input: {input}\nOutput: {output}",
    input_variables=["input", "output"],
)

# Few-shot examples for the move-planning prompt. Each pairs a textual game
# state with a natural-language move. Two entries share the same input state to
# show that a position can admit more than one legal move. The outputs are shown
# to the model verbatim, so their wording should be grammatical.
PLAN_MOVE_PROMPT_EXAMPLES = [
    {"input": "The piles contain 3, 5, 7 sticks", "output": "I'll take one stick from pile A"},
    {"input": "The piles contain 2, 5, 7 sticks", "output": "I'll take one stick from pile B"},
    {"input": "The piles contain 2, 5, 7 sticks", "output": "I'll take five sticks from pile B"},
    {"input": "The piles contain 1, 2, 3 sticks", "output": "I'll take two sticks from pile C"},
    {"input": "The piles contain 0, 2, 3 sticks", "output": "I'll take one stick from pile C"},
    {"input": "The piles contain 0, 2, 0 sticks", "output": "I'll take two sticks from pile B"},
]

# Few-shot prompt asking the model to propose a move for the game state given
# in ``text_game_state``. The prefix explains the rules of Nim, the examples
# come from PLAN_MOVE_PROMPT_EXAMPLES, and the suffix leaves "Output:" open for
# the model to complete.
PLAN_MOVE_PROMPT_FROM_STRING_EXAMPLES = FewShotPromptTemplate(
    examples=PLAN_MOVE_PROMPT_EXAMPLES,
    example_prompt=EXAMPLES_PROMPT_TEMPLATE,
    prefix="Nim is a two-player game of strategy in which players take turns removing objects from separate piles. "
           "The goal of the game is to remove the last sticks from a pile when the other piles contain 0 sticks. Each "
           "of these inputs represents a game state. For each of these game states please express a logical move that "
           "consists of taking some number of sticks from a pile. "
           "You may not take any sticks from a pile that contains 0 sticks. "
           "You may not take more sticks from a pile than it contains. "
           "You may only take sticks from one pile. ",
    suffix="Input: {text_game_state}\nOutput:",
    input_variables=["text_game_state"],
    example_separator="\n\n"
)

# Few-shot examples for the move-execution prompt. Each maps a free-text move
# to "<zero-based pile index>,<number of sticks to remove>".
EXEC_MOVE_PROMPT_EXAMPLES = [
    dict(input="I'll take two sticks from pile A", output="0,2"),
    dict(input="I'll take 3 sticks from the first pile", output="0,3"),
    dict(input="I'll take two sticks from pile C", output="2,2"),
    dict(input="I'll take one stick from the third pile", output="2,1"),
    dict(input="From pile B remove 2 sticks", output="1,2"),
    dict(input="I'll take the last stick from pile C", output="2,1"),
]

# Few-shot prompt that converts a free-text move (``move_to_express``) into a
# comma-separated "pile,count" answer, which execute_move splits on ',' and
# converts to ints.
EXEC_MOVE_PROMPT_FROM_STRING_EXAMPLES = FewShotPromptTemplate(
    prefix=(
        "Express every input as two numbers separated by a comma, "
        "where the first number is the zero index pile number and "
        "the second number is the number of sticks to remove."
    ),
    examples=EXEC_MOVE_PROMPT_EXAMPLES,
    example_prompt=EXAMPLES_PROMPT_TEMPLATE,
    example_separator="\n\n",
    suffix="Input: {move_to_express}\nOutput:",
    input_variables=["move_to_express"],
)


def plan_move(text_game_state, temperature, api_key, model_name='text-davinci-003'):
    """Ask the LLM to propose a Nim move for a textual game state.

    Args:
        text_game_state: The game state as text. It is substituted into the
            ``{text_game_state}`` slot of PLAN_MOVE_PROMPT_FROM_STRING_EXAMPLES
            (the few-shot examples use the form "The piles contain 3, 5, 7 sticks").
        temperature: Sampling temperature for the OpenAI LLM.
        api_key: OpenAI API key.
        model_name: Completion model to query. It defaults to the previously
            hard-coded 'text-davinci-003' for backward compatibility.
            NOTE(review): OpenAI has retired that model, so callers will likely
            need to pass a current completions model (e.g.
            'gpt-3.5-turbo-instruct').

    Returns:
        The model's completion with surrounding whitespace stripped. This is
        presumably a move phrased like the few-shot outputs (e.g. "I'll take
        one stick from pile A"). Nothing here validates that.
    """
    llm = OpenAI(model_name=model_name, temperature=temperature, max_tokens=100,
                 openai_api_key=api_key)
    llm_chain = LLMChain(llm=llm, prompt=PLAN_MOVE_PROMPT_FROM_STRING_EXAMPLES, verbose=False)
    return llm_chain.run({'text_game_state': text_game_state}).strip()


def _parse_step_tuple(text):
    """Parse an LLM answer like "0,2" into a ``(pile_index, count)`` tuple of ints.

    Only the first line is used, because a completion may continue past the
    answer. Raises ValueError if that line is not exactly two comma-separated
    integers.
    """
    first_line = text.strip().split('\n', 1)[0]
    parts = first_line.split(',')
    if len(parts) != 2:
        raise ValueError(f"expected 'pile,count', got {text!r}")
    return tuple(int(part) for part in parts)


def execute_move(move_to_express, nim_game_env, api_key, model_name='text-davinci-003'):
    """Translate a free-text move into a step and apply it to the game env.

    The LLM is prompted (temperature 0) with EXEC_MOVE_PROMPT_FROM_STRING_EXAMPLES
    to turn ``move_to_express`` into "pile,count". That answer is parsed into a
    tuple of two ints and passed to ``nim_game_env.step``.

    Args:
        move_to_express: Natural-language move, e.g. "I'll take two sticks from pile A".
        nim_game_env: Environment whose ``step(tuple)`` returns a sequence whose
            items 0-3 are used as (observation, reward, done, info). It may
            raise ValueError for a rejected move.
        api_key: OpenAI API key.
        model_name: Completion model to query. It defaults to the previously
            hard-coded 'text-davinci-003'. NOTE(review): that model has been
            retired by OpenAI, so callers may need to pass a current one.

    Returns:
        A 5-tuple ``(text_observation, observation, reward, done, info)``.
        ``text_observation`` reads "The piles contain a, b, c sticks.", and the
        remaining items are taken from the env's step result. If the model's
        answer cannot be parsed, or the env raises ValueError, the result is
        ``("Invalid move!", [0, 0, 0], 0, True, None)``.
    """
    llm = OpenAI(model_name=model_name, temperature=0.0, max_tokens=10,
                 openai_api_key=api_key)
    llm_chain = LLMChain(llm=llm, prompt=EXEC_MOVE_PROMPT_FROM_STRING_EXAMPLES, verbose=False)
    step_tuple_str = llm_chain.run({'move_to_express': move_to_express})
    try:
        # Parsing sits inside the try so that a malformed model answer is
        # reported as an invalid move rather than escaping as an uncaught
        # ValueError.
        step_tuple = _parse_step_tuple(step_tuple_str)
        step_result = nim_game_env.step(step_tuple)
    except ValueError:
        # NOTE(review): the [0, 0, 0] placeholder assumes a three-pile game.
        return "Invalid move!", [0, 0, 0], 0, True, None

    observation = step_result[0]
    text_observation = "The piles contain " + ", ".join(str(x) for x in observation) + " sticks."
    return text_observation, observation, step_result[1], step_result[2], step_result[3]