styles-scribble-demo / generate_text.py
daniellefranca96's picture
Add application file
1c86ad8
raw
history blame
No virus
3.1 kB
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain, SequentialChain
from langchain.chat_models import ChatAnthropic
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFaceHub
from langchain.prompts import (
PromptTemplate,
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
class GenerateStyleText:
    """Generate text for a user request, then rewrite it in an example's style.

    ``process`` runs three LLM chains in sequence, all backed by ``self.llm``:

    1. ``get_extract_tone_chain``: turns ``example`` (input ``"text"``) into a
       style guide (output ``"style"``).
    2. ``get_generate_text_chain``: answers ``prompt`` (output
       ``"generated_text"``).
    3. ``get_apply_style_chain``: rewrites ``generated_text`` using ``style``
       (output ``"result"``).
    """

    # Seed text whose tone/style should be imitated (None until provided).
    example: str
    # The user's request describing what text to generate; must be a str
    # by the time process() runs (None until provided).
    prompt: str
    # Model shared by every chain; set directly or through set_imp_llm().
    llm: BaseLanguageModel

    def __init__(self, example=None, prompt=None, llm=None):
        self.example = example
        self.prompt = prompt
        self.llm = llm

    def set_imp_llm(self, model):
        """Select the language model used by all chains.

        Args:
            model: ``"GPT3"`` selects OpenAI ``gpt-3.5-turbo-16k``, ``"GPT4"``
                selects OpenAI ``gpt-4``, ``"Claude"`` selects ``ChatAnthropic``
                with its default arguments; any other string is used as a
                Hugging Face Hub ``repo_id``.
        """
        if model == 'GPT3':
            self.llm = ChatOpenAI(model_name="gpt-3.5-turbo-16k")
        elif model == "GPT4":
            self.llm = ChatOpenAI(model_name="gpt-4")
        elif model == "Claude":
            self.llm = ChatAnthropic()
        else:
            self.llm = HuggingFaceHub(repo_id=model)

    def run(self):
        """Run the full pipeline; alias for ``process``."""
        return self.process()

    def process(self):
        """Build and execute the three-step sequential chain.

        Returns:
            The final rewritten text (the ``"result"`` output, passed
            through ``str()``).
        """
        seq_chain = SequentialChain(
            chains=[
                self.get_extract_tone_chain(),
                self.get_generate_text_chain(self.prompt),
                self.get_apply_style_chain(),
            ],
            input_variables=["text"],
            verbose=True,
        )
        # "text" is the only declared input. "style" is produced by the first
        # chain, so no placeholder value for it is passed in here.
        result = seq_chain({"text": self.example})
        return str(result.get("result"))

    def create_chain(self, chat_prompt, output_key):
        """Wrap a prompt template in an ``LLMChain`` bound to ``self.llm``.

        Args:
            chat_prompt: The prompt template (chat or plain) for the chain.
            output_key: Name under which the chain stores its output, so that
                later chains in the sequence can consume it.
        """
        return LLMChain(llm=self.llm, prompt=chat_prompt, output_key=output_key)

    def get_extract_tone_chain(self):
        """Build the chain that derives a style guide from the seed text.

        The seed text arrives as the ``text`` variable in the human message;
        the output is stored under ``"style"``.
        """
        template = """Based on the tone and writing style in the seed text, create a style guide for a blog or
publication that captures the essence of the seed’s tone. Emphasize engaging techniques that help readers
feel connected to the content.
"""
        system_message_prompt = SystemMessagePromptTemplate.from_template(template)
        human_template = "{text}"
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        chat_prompt = ChatPromptTemplate.from_messages(
            [system_message_prompt, human_message_prompt])
        return self.create_chain(chat_prompt, "style")

    def get_generate_text_chain(self, prompt):
        """Build the chain that generates text answering the user's request.

        The request is embedded literally in the template, so this chain
        needs no input variables. Its output is stored under
        ``"generated_text"``.

        Args:
            prompt: The user's request.

        Raises:
            TypeError: If ``prompt`` is not a string (e.g. still ``None``).
        """
        if not isinstance(prompt, str):
            raise TypeError(
                f"prompt must be a str, got {type(prompt).__name__}")
        # The template is parsed as an f-string. Braces in the user's text must
        # be doubled, or e.g. "return JSON like {name}" would become a required
        # input variable and a lone "{" or "}" would raise a format error.
        escaped_prompt = prompt.replace("{", "{{").replace("}", "}}")
        template = """Generate a text following the user_request(use same language of the request):
{user_request}
""".replace("{user_request}", escaped_prompt)
        return self.create_chain(PromptTemplate.from_template(template),
                                 "generated_text")

    def get_apply_style_chain(self):
        """Build the chain that rewrites ``generated_text`` in ``style``.

        Both template variables are filled from earlier chains' outputs in
        the sequence; the rewritten text is stored under ``"result"``.
        """
        template = """STYLE:
{style}
REWRITE THE TEXT BELLOW APPLYING THE STYLE ABOVE(use same language of the request),
ONLY GENERATE NEW TEXT BASED ON THE STYLE CONTEXT, DO NOT COPY STYLE EXACT PARTS:
{generated_text}
"""
        # Do not bind "style" via prompt.partial(...): partial() returns a new
        # template, and fixing style to a constant would prevent the extracted
        # style guide from ever reaching this chain.
        prompt = PromptTemplate.from_template(template=template)
        return self.create_chain(prompt, "result")