Spaces:
Sleeping
Sleeping
File size: 2,109 Bytes
ee00a52 24510fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda
from operator import itemgetter
from langchain.output_parsers import PydanticOutputParser
from .output_parser import SongDescriptions
from langchain.llms.base import LLM
class LLMChain:
    """Runnable pipeline that expands a playlist description into four songs.

    Wraps a LangChain LLM in an LCEL chain: the Pydantic parser's format
    instructions are injected into the prompt, the model is invoked, and the
    raw text output is parsed into a ``SongDescriptions`` object.
    """

    def __init__(self, llm_model: LLM) -> None:
        """Build the chain around ``llm_model``.

        Args:
            llm_model: Any LangChain LLM used to generate the descriptions.
        """
        self.llm_model = llm_model
        self.parser = PydanticOutputParser(pydantic_object=SongDescriptions)
        self.full_chain = self._create_llm_chain()

    def _get_output_format(self, _):
        """Return the parser's format instructions.

        The unused positional argument exists because ``RunnableLambda``
        always passes the chain input to its callable.
        """
        return self.parser.get_format_instructions()

    def _create_llm_chain(self):
        """Assemble the input-mapping | prompt | model runnable pipeline."""
        prompt_response = ChatPromptTemplate.from_messages([
            ("system", "You are an AI assistant, helping the user to turn a music playlist text description into four separate song descriptions that are probably contained in the playlist. Try to be specific with descriptions. Make sure all 4 song descriptions are similar.\n"),
            ("system", "{format_instructions}\n"),
            ("human", "Playlist description: {description}.\n"),
        ])
        full_chain = (
            {
                # Format instructions are computed lazily at invoke time.
                "format_instructions": RunnableLambda(self._get_output_format),
                "description": itemgetter("description"),
            }
            | prompt_response
            | self.llm_model
        )
        return full_chain

    def process_user_description(self, user_input):
        """Run the chain on ``user_input`` and parse the model output.

        Args:
            user_input: Free-text playlist description from the user.

        Returns:
            SongDescriptions: Parsed structured song descriptions.
        """
        raw_output = self.full_chain.invoke({"description": user_input})
        # Strip backslashes so stray model-emitted escapes don't break the
        # Pydantic parse. NOTE(review): this also removes legitimate JSON
        # escape sequences — confirm the model never emits them.
        cleaned = raw_output.replace("\\", '')
        return self.parser.parse(cleaned)
|