import traceback
from typing import List, Tuple

from loguru import logger
from pydantic import BaseModel, Field
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.output_parsers import (
    PydanticOutputParser,
    OutputFixingParser,
)
# LLM.py is my own LLM wrapper; you can use the OpenAI one from langchain directly instead.
# from langchain.llms.openai import OpenAIChat
from LLM import OpenAIChat

TEMPLATE = """\
Please write a story at least 5 sentences long, using the words [{words}].

{format}

Attention! The given words must be highlighted by surrounding them with "`", so the story should be in the following format.

English: ... `word1` ... `word2` ...

Chinese: ... `单词1` ... `单词2` ...
"""

class Story(BaseModel):
    story: str = Field(description="the story")
    translated_story: str = Field(description="the translated story")

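# Build the generation pipeline: prompt template -> LLM -> Pydantic parser.
# The parser is wrapped in an OutputFixingParser so malformed model output is
# sent back to the LLM for repair instead of failing immediately.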
llm = OpenAIChat(model_name="gpt-3.5-turbo", temperature=0.3)
parser = PydanticOutputParser(pydantic_object=Story)
prompt_template = PromptTemplate(
    template=TEMPLATE,
    input_variables=["words"],
    partial_variables={
        "format": parser.get_format_instructions(),
    }
)
fixing_parser = OutputFixingParser.from_llm(parser=parser, llm=llm)
chain = LLMChain(
    llm=llm,
    prompt=prompt_template,
    output_parser=fixing_parser,
    verbose=False,
)

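# Run the chain with up to 10 attempts; retry on parse errors or when either
# story field comes back empty, and return an empty Story if all attempts fail.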
def tell_story(words: List[str]) -> Story:
    count = 0
    while count < 10:
        count += 1
        try:
            resp: Story = chain.run(", ".join(words))
            if len(resp.story.strip()) == 0:
                continue
            if len(resp.translated_story.strip()) == 0:
                continue
            return resp
        except Exception as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            logger.error("retrying...")
            continue
    return Story(story="", translated_story="")

def generate_story_and_translated_story(words: List[str]) -> Tuple[str, str]:
    resp = tell_story(words)
    return resp.story, resp.translated_story
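
# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes the OpenAIChat wrapper picks up a valid API key from the environment;
# the word list below is an arbitrary example.
if __name__ == "__main__":
    story, translated_story = generate_story_and_translated_story(
        ["apple", "river", "lantern"]
    )
    if story:
        logger.info("English story: {}", story)
        logger.info("Chinese translation: {}", translated_story)
    else:
        logger.warning("Story generation failed after all retries.")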