# grammar / app.py
# briankchan's picture
# Set initial max lines
# f7f4627
# based on https://github.com/hwchase17/langchain-gradio-template/blob/master/app.py
import collections
import os
from itertools import islice
from queue import Queue
from anyio.from_thread import start_blocking_portal
import gradio as gr
from diff_match_patch import diff_match_patch
from langchain.chains import LLMChain
from langchain.chat_models import PromptLayerChatOpenAI, ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage
from util import SyncStreamingLLMCallbackHandler, concatenate_generators
# Prompt sent for the Grammar/Spelling tab; {content} is the user's paragraph.
GRAMMAR_PROMPT = "Proofread for grammar and spelling without adding new paragraphs:\n{content}"
# Prompt for the Intro tab: asks the model to locate and score the four
# components of an introductory paragraph.
INTRO_PROMPT = """These are the parts of a good introductory paragraph:
1. Introductory information
2. The stage of human development of the main character
3. Summary of story
4. Thesis statement (this should also provide an overview of the essay structure or topics that may be covered in each paragraph)
For each part, put a quote of the sentences from the following paragraph that fulfil that part and say how confident you are (percentage). If you're not confident, explain why.
---
Example output format:
Thesis statement and outline:
"Sentence A. Sentence B"
Score: X%. Feedback goes here.
---
Intro paragraph:
{content}"""
# Body-paragraph analysis is split into three turns of the same conversation:
# BODY_PROMPT1 starts it, BODY_PROMPT2 and BODY_PROMPT3 are follow-ups.
BODY_PROMPT1 = """You are a university English teacher. Complete the following tasks for the following essay paragraph about a book:
1. Topic sentence: Identify the topic sentence and determine whether it introduces an argument
2. Key points: Outline a bullet list of key points
3. Supporting evidence: Give a bullet list of parts of the paragraph that use quotes or other textual evidence from the book
{content}"""
BODY_PROMPT2 = """4. Give advice on how the topic sentence could be made stronger or clearer
5. In a bullet list, state how each key point supports the topic (or if any doesn't support it)
6. In a bullet list for each supporting evidence, state which key point the evidence supports.
"""
# BUG FIX: "liisted" -> "listed" (typo in the text sent to the model).
BODY_PROMPT3 = """Briefly summarize "{title}". Then, in a bullet list for each supporting evidence you listed above, state if it describes an event/detail from the "{title}" or if it's from outside sources.
Use this output format:
[summary]
----
- [supporting evidence 1] - book
- [supporting evidence 2] - outside source"""
# Shown in the Body Paragraph tab as a user-facing description of the steps.
BODY_DESCRIPTION = """1. identifies the topic sentence
2. outlines key points
3. checks for supporting evidence (e.g., quotes, summaries, and concrete details)
4. suggests topic sentence improvements
5. checks that the key points match the paragraph topic
6. determines which key point each piece of evidence supports
7. checks whether each evidence is from the book or from an outside source"""
def is_empty(s: str):
    """Return True when *s* has no characters or only whitespace."""
    return not s or s.isspace()
def check_content(s: str):
    """Raise a user-visible Gradio error when *s* contains no real text."""
    if not is_empty(s):
        return
    raise gr.exceptions.Error('Please input some text before running.')
def load_chain(api_key, api_type):
    """Build the LLM and one chain per mode for the selected backend.

    Falls back to environment variables when *api_key* is blank.
    Returns (grammar_chain, llm, intro_chain, body_chain).
    Raises RuntimeError when *api_type* is unknown or no key is available.
    """
    # BUG FIX: validate the backend up front. Previously this was only
    # checked when the key was blank, so an unknown api_type with a
    # user-supplied key left `llm` unbound and crashed later.
    if api_type not in ("OpenAI", "Azure OpenAI"):
        raise RuntimeError("Unknown API type? " + api_type)
    if is_empty(api_key):
        # fall back to the server-side key for the chosen backend
        env_var = "OPENAI_API_KEY" if api_type == "OpenAI" else "AZURE_OPENAI_API_KEY"
        api_key = os.environ.get(env_var, None)
    if not api_key:
        # BUG FIX: previously fell through and raised UnboundLocalError at
        # the return statement instead of reporting the real problem.
        raise RuntimeError("No API key available for " + api_type)
    shared_args = {
        "temperature": 0,
        "model_name": "gpt-3.5-turbo",
        "api_key": api_key,  # deliberately not use "openai_api_key" and other openai args since those apply globally
        "pl_tags": ["grammar"],
        "streaming": True,
    }
    if api_type == "OpenAI":
        llm = PromptLayerChatOpenAI(**shared_args)
    else:  # Azure OpenAI
        llm = PromptLayerChatOpenAI(
            api_type="azure",
            api_base=os.environ.get("AZURE_OPENAI_API_BASE", None),
            api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
            engine=os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
            **shared_args
        )

    def make_chain(template):
        # each mode gets its own memory so follow-ups don't cross-contaminate
        return LLMChain(llm=llm,
                        memory=ConversationBufferMemory(),
                        prompt=PromptTemplate(
                            input_variables=["content"],
                            template=template
                        ))

    return make_chain(GRAMMAR_PROMPT), llm, make_chain(INTRO_PROMPT), make_chain(BODY_PROMPT1)
def run_diff(content, chain: LLMChain):
    """Run the grammar chain on *content* and return diff highlighting.

    Returns (before, after, changes, edited) matching the UI outputs.
    """
    check_content(content)
    chain.memory.clear()  # fresh conversation for every run
    revised = chain.run(content)
    before, after, changes = diff_words(content, revised)
    return before, after, changes, revised
# https://github.com/hwchase17/langchain/issues/2428#issuecomment-1512280045
def run(content, chain: LLMChain):
    """Run *chain* on *content*, yielding progressively longer output.

    A background task pushes streamed tokens onto a queue; this generator
    drains the queue and yields the accumulated text after each token.
    """
    check_content(content)
    chain.memory.clear()
    token_queue = Queue()
    sentinel = object()  # marks end-of-stream on the queue

    def produce():
        result = chain.run(content, callbacks=[SyncStreamingLLMCallbackHandler(token_queue)])
        token_queue.put(sentinel)
        return result

    with start_blocking_portal() as portal:
        portal.start_task_soon(produce)
        text = ""
        while True:
            # timeout guards against a hung backend leaving the UI spinning
            token = token_queue.get(True, timeout=10)
            if token is sentinel:
                break
            text += token
            yield text
# TODO share code with above
def run_followup(followup_question, input_vars, chain, chat: ChatOpenAI):
    """Stream the answer to a follow-up question, replaying chain history.

    Human turns from the chain's memory are re-wrapped in the chain's
    prompt template so the transcript matches what the model originally saw.
    """
    check_content(followup_question)
    history = []
    for message in chain.memory.chat_memory.messages:
        if isinstance(message, HumanMessage):
            message = HumanMessage(content=chain.prompt.format(content=message.content))
        history.append(message)
    prompt = ChatPromptTemplate.from_messages(
        history + [HumanMessagePromptTemplate.from_template(followup_question)])
    messages = prompt.format_prompt(**input_vars).to_messages()
    token_queue = Queue()
    sentinel = object()  # marks end-of-stream on the queue

    def produce():
        result = chat.generate([messages], callbacks=[SyncStreamingLLMCallbackHandler(token_queue)])
        token_queue.put(sentinel)
        return result.generations[0][0].message.content

    with start_blocking_portal() as portal:
        portal.start_task_soon(produce)
        answer = ""
        while True:
            token = token_queue.get(True, timeout=10)
            if token is sentinel:
                break
            answer += token
            yield answer
def run_body(content, title, chain, llm):
    """Run the three-stage body-paragraph analysis, streaming the combined output."""
    check_content(content)  # note: run() also checks, but the error doesn't get shown in the UI?
    if not title:
        # BUG FIX: this is a generator function, so `return "message"` only
        # set StopIteration.value and the user never saw the message;
        # yield it instead.
        yield "Please enter the book title."
        return
    yield from concatenate_generators(
        run(content, chain),
        "\n\n",
        run_followup(BODY_PROMPT2, {}, chain, llm),
        "\n\n7. Whether supporting evidence is from the book:",
        # keep only the part after the "----" separator (the evidence list)
        (output.split("----")[-1] for output in run_followup(BODY_PROMPT3, {"title": title}, chain, llm))
    )
def run_custom(content, llm, prompt):
    """Run an ad-hoc prompt template over *content*; returns (output, chain)."""
    template = PromptTemplate(
        input_variables=["content"],
        template=prompt
    )
    chain = LLMChain(llm=llm, prompt=template, memory=ConversationBufferMemory())
    output = chain.run(content)
    return output, chain
# not currently used
def split_paragraphs(text):
    """Split *text* into lines, flagging which hold real (non-heading) content."""
    def is_content(line):
        return bool(line) and not line.startswith("#") and not line.isspace()
    return [(line, is_content(line)) for line in text.split("\n")]
def sliding_window(iterable, n):
    """Yield overlapping n-tuples: sliding_window('ABCDEFG', 4) -> ABCD BCDE CDEF DEFG."""
    iterator = iter(iterable)
    window = collections.deque(maxlen=n)
    window.extend(islice(iterator, n))
    if len(window) == n:
        yield tuple(window)
    for item in iterator:
        window.append(item)  # deque drops the oldest element automatically
        yield tuple(window)
# Shared differ instance; diff_words is the only user.
dmp = diff_match_patch()
def diff_words(content, edited):
    """Diff *content* against *edited* for gr.HighlightedText display.

    Returns (before, after, changes):
    - before: (text, label) pairs for the original side; label is "-" for a
      pure deletion, a numeric string for a replacement, or None if unchanged
    - after: same for the edited side, with "+" for pure insertions
    - changes: (original, replacement) pairs, indexed by the numeric labels
    """
    before = []
    after = []
    changes = []
    change_count = 0
    changed = False
    diff = dmp.diff_main(content, edited)
    dmp.diff_cleanupSemantic(diff)
    diff += [(None, None)]  # sentinel so the last real entry still gets a lookahead
    for [(change, text), (next_change, next_text)] in sliding_window(diff, 2):
        if change == 0:
            before.append((text, None))
            after.append((text, None))
        else:
            if change == -1 and next_change == 1:
                # deletion immediately followed by insertion = replacement;
                # label both sides with the same 1-based change number
                change_count += 1
                before.append((text, str(change_count)))
                after.append((next_text, str(change_count)))
                changes.append((text, next_text))
                changed = True
            elif change == -1:
                before.append((text, "-"))
            elif change == 1:
                if changed:
                    # this insertion was already consumed by the replacement above
                    changed = False
                else:
                    after.append((text, "+"))
            else:
                # BUG FIX: `change` is an int; "str" + int raised TypeError
                # and masked the intended error message
                raise Exception("Unknown change type: " + str(change))
    return before, after, changes
def get_parts(arr, start, end):
    """Concatenate the string elements arr[start:end] into one string."""
    selected = arr[start:end]
    return "".join(selected)
# Maps a diff marker (as produced by diff_words) to the verb used when
# phrasing a follow-up question in select_diff; numeric replacement markers
# are handled separately there.
CHANGES = {
    "-": "remove",
    "+": "add",
    # "→": "change"
}
def select_diff(evt: gr.SelectData, changes):
    """Build a follow-up question about the diff segment the user clicked.

    Returns None (leaving the textbox untouched) for unchanged segments.
    """
    text, marker = evt.value
    if not marker:
        return
    action = CHANGES.get(marker, None)
    if action:
        # simple add/remove marker
        return f"Why is it better to {action} [{text}]?"
    # otherwise the marker is a 1-based replacement index into `changes`
    original, edited = changes[int(marker) - 1]
    return f"Why is it better to change [{original}] to [{edited}]?"
# Page-level CSS: preserve whitespace inside the diff views, but let the
# highlighted spans wrap normally.
demo = gr.Blocks(css="""
.diff-component {
white-space: pre-wrap;
}
.diff-component .textspan.hl {
white-space: normal;
}
""")
with demo:
    # The key textbox is disabled for now; an empty State makes load_chain
    # fall back to the server-side environment variables.
    # api_key = gr.Textbox(
    #     placeholder="Paste your OpenAPI API key here (sk-...)",
    #     show_label=False,
    #     lines=1,
    #     type="password"
    # )
    api_key = gr.State("")
    gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="./file=thinkcol-logo.png" alt="ThinkCol" width="357" height="87" /></a></div>""")
    gr.Markdown("""Paste a paragraph below, and then choose one of the modes to generate feedback.""")
    # shared input paragraph used by every tab
    content = gr.Textbox(
        label="Paragraph"
    )
    with gr.Tab("Grammar/Spelling"):
        gr.Markdown("Suggests grammar and spelling revisions.")
        submit = gr.Button(
            value="Revise",
        ).style(full_width=False)
        # side-by-side diff: deletions marked on the left, additions on the right
        with gr.Row():
            output_before = gr.HighlightedText(
                label="Before",
                combine_adjacent=True,
                elem_classes="diff-component"
            ).style(color_map={
                "-": "red",
                # "→": "yellow",
            })
            output_after = gr.HighlightedText(
                label="After",
                combine_adjacent=True,
                elem_classes="diff-component"
            ).style(color_map={
                "+": "green",
                # "→": "yellow",
            })
        # clicking a highlighted span pre-fills this box (see select_diff)
        followup_question = gr.Textbox(
            label="Follow-up Question",
        )
        followup_submit = gr.Button(
            value="Ask"
        ).style(full_width=False)
        followup_answer = gr.Textbox(
            label="Answer"
        )
    with gr.Tab("Intro"):
        gr.Markdown("Checks for the key components of an introductory paragraph.")
        submit_intro = gr.Button(
            value="Run"
        ).style(full_width=False)
        output_intro = gr.Textbox(
            label="Output",
            lines=1000,
            max_lines=1000
        )
    with gr.Tab("Body Paragraph"):
        gr.Markdown(BODY_DESCRIPTION)
        title = gr.Textbox(
            label="Book Title"
        )
        submit_body = gr.Button(
            value="Run"
        ).style(full_width=False)
        output_body = gr.Textbox(
            label="Output",
            lines=1000,
            max_lines=1000
        )
    # Disabled debugging tab for running arbitrary prompts.
    # with gr.Tab("Custom prompt"):
    #     gr.Markdown("This mode is for testing and debugging.")
    #     prompt = gr.Textbox(
    #         label="Prompt",
    #         value=GRAMMAR_PROMPT,
    #         lines=2
    #     )
    #     submit_custom = gr.Button(
    #         value="Run"
    #     ).style(full_width=False)
    #     output_custom = gr.Textbox(
    #         label="Output"
    #     )
    #     followup_custom = gr.Textbox(
    #         label="Follow-up Question"
    #     )
    #     followup_answer_custom = gr.Textbox(
    #         label="Answer"
    #     )
    with gr.Tab("Settings"):
        api_type = gr.Radio(
            ["OpenAI", "Azure OpenAI"],
            value="OpenAI",
            label="Server",
            info="You can try changing this if responses are slow."
        )
    # per-session state: diff metadata plus the chains built by load_chain
    changes = gr.State()
    edited = gr.State()
    chain = gr.State()
    llm = gr.State()
    chain_intro = gr.State()
    chain_body1 = gr.State()
    chain_custom = gr.State()
    # rebuild the chains whenever the backend selection changes
    # api_key.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
    api_type.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
    inputs = [content, chain]
    outputs = [output_before, output_after, changes, edited]
    # content.submit(run_diff, inputs=inputs, outputs=outputs)
    submit.click(run_diff, inputs=inputs, outputs=outputs)
    # clicking either diff view turns the selection into a follow-up question
    output_before.select(select_diff, changes, followup_question)
    output_after.select(select_diff, changes, followup_question)
    empty_input = gr.State({})
    inputs2 = [followup_question, empty_input, chain, llm]
    outputs2 = followup_answer
    followup_question.submit(run_followup, inputs2, outputs2)
    followup_submit.click(run_followup, inputs2, outputs2)
    submit_intro.click(run, [content, chain_intro], output_intro)
    submit_body.click(run_body, [content, title, chain_body1, llm], output_body)  # body part A only
    # submit_custom.click(run_custom, [content, llm, prompt], [output_custom, chain_custom]) # TODO standardize api--return memory instead of using chain?
    # followup_custom.submit(run_followup, [followup_custom, empty_input, chain_custom, llm], followup_answer_custom)
    # build the chains on page load using whatever key the environment provides
    demo.load(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
# Honour SERVER_PORT when set; prevent_thread_lock keeps the interpreter
# responsive after launch.
raw_port = os.environ.get("SERVER_PORT", None)
port = int(raw_port) if raw_port else raw_port
demo.queue()
demo.launch(debug=True, server_port=port, prevent_thread_lock=True)