# based on https://github.com/hwchase17/langchain-gradio-template/blob/master/app.py
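# Gradio app that generates LLM feedback on essay paragraphs: grammar/spelling
# revisions (shown as a clickable diff), intro-paragraph structure checks, and
# body-paragraph analysis against a given book title.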
import collections
import os
from itertools import islice
from queue import Queue

from anyio.from_thread import start_blocking_portal
import gradio as gr
from diff_match_patch import diff_match_patch
from langchain.chains import LLMChain
from langchain.chat_models import PromptLayerChatOpenAI, ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage

from util import SyncStreamingLLMCallbackHandler, concatenate_generators
GRAMMAR_PROMPT = "Proofread for grammar and spelling without adding new paragraphs:\n{content}"

INTRO_PROMPT = """These are the parts of a good introductory paragraph:
1. Introductory information
2. The stage of human development of the main character
3. Summary of story
4. Thesis statement (this should also provide an overview of the essay structure or the topics that may be covered in each paragraph)
For each part, quote the sentences from the following paragraph that fulfil that part and say how confident you are (percentage). If you're not confident, explain why.
---
Example output format:
Thesis statement and outline:
"Sentence A. Sentence B"
Score: X%. Feedback goes here.
---
Intro paragraph:
{content}"""
BODY_PROMPT1 = """You are a university English teacher. Complete the following tasks for the following essay paragraph about a book:
1. Topic sentence: Identify the topic sentence and determine whether it introduces an argument
2. Key points: Outline a bullet list of key points
3. Supporting evidence: Give a bullet list of parts of the paragraph that use quotes or other textual evidence from the book
{content}"""

BODY_PROMPT2 = """4. Give advice on how the topic sentence could be made stronger or clearer
5. In a bullet list, state how each key point supports the topic (or if any doesn't support it)
6. In a bullet list, state which key point each piece of supporting evidence supports.
"""

BODY_PROMPT3 = """Briefly summarize "{title}". Then, in a bullet list for each piece of supporting evidence you listed above, state whether it describes an event/detail from "{title}" or comes from outside sources.
Use this output format:
[summary]
----
- [supporting evidence 1] - book
- [supporting evidence 2] - outside source"""
BODY_DESCRIPTION = """1. identifies the topic sentence
2. outlines key points
3. checks for supporting evidence (e.g., quotes, summaries, and concrete details)
4. suggests topic sentence improvements
5. checks that the key points match the paragraph topic
6. determines which key point each piece of evidence supports
7. checks whether each piece of evidence is from the book or from an outside source"""

def is_empty(s: str):
    return len(s) == 0 or s.isspace()


def check_content(s: str):
    if is_empty(s):
        raise gr.Error('Please input some text before running.')

def load_chain(api_key, api_type):
    """Build the grammar, intro, and body chains (plus the raw LLM) for the chosen API."""
    if api_key == "" or api_key.isspace():
        if api_type == "OpenAI":
            api_key = os.environ.get("OPENAI_API_KEY", None)
        elif api_type == "Azure OpenAI":
            api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
        else:
            raise RuntimeError("Unknown API type: " + api_type)
    if api_key:
        shared_args = {
            "temperature": 0,
            "model_name": "gpt-3.5-turbo",
            # Deliberately not using "openai_api_key" and the other openai_* constructor
            # args, since those apply globally; these are passed through per-request instead.
            "api_key": api_key,
            "pl_tags": ["grammar"],
            "streaming": True,
        }
        if api_type == "OpenAI":
            llm = PromptLayerChatOpenAI(**shared_args)
        elif api_type == "Azure OpenAI":
            llm = PromptLayerChatOpenAI(
                api_type="azure",
                api_base=os.environ.get("AZURE_OPENAI_API_BASE", None),
                api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
                engine=os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
                **shared_args
            )
        prompt1 = PromptTemplate(
            input_variables=["content"],
            template=GRAMMAR_PROMPT
        )
        chain = LLMChain(llm=llm,
                         prompt=prompt1,
                         memory=ConversationBufferMemory())
        chain_intro = LLMChain(llm=llm,
                               prompt=PromptTemplate(
                                   input_variables=["content"],
                                   template=INTRO_PROMPT
                               ),
                               memory=ConversationBufferMemory())
        chain_body1 = LLMChain(llm=llm,
                               prompt=PromptTemplate(
                                   input_variables=["content"],
                                   template=BODY_PROMPT1
                               ),
                               memory=ConversationBufferMemory())
        return chain, llm, chain_intro, chain_body1
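
# Environment variables read above (when no key is supplied in the UI):
# OPENAI_API_KEY or, for Azure, AZURE_OPENAI_API_KEY plus AZURE_OPENAI_API_BASE,
# AZURE_OPENAI_API_VERSION, and AZURE_OPENAI_DEPLOYMENT_NAME.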

def run_diff(content, chain: LLMChain):
    check_content(content)
    chain.memory.clear()
    edited = chain.run(content)
    # Returns (before, after, changes, edited) for the two HighlightedText views.
    return diff_words(content, edited) + (edited,)

# Streaming workaround:
# https://github.com/hwchase17/langchain/issues/2428#issuecomment-1512280045
def run(content, chain: LLMChain):
    check_content(content)
    chain.memory.clear()
    q = Queue()
    job_done = object()  # sentinel marking the end of the stream

    def task():
        # The callback handler pushes each generated token onto the queue.
        result = chain.run(content, callbacks=[SyncStreamingLLMCallbackHandler(q)])
        q.put(job_done)
        return result

    with start_blocking_portal() as portal:
        portal.start_task_soon(task)
        output = ""
        while True:
            next_token = q.get(True, timeout=10)
            if next_token is job_done:
                break
            output += next_token
            yield output

# TODO: share code with run() above
def run_followup(followup_question, input_vars, chain, chat: ChatOpenAI):
    check_content(followup_question)
    # Rebuild the history, substituting each human message with the fully
    # formatted prompt it was originally sent as.
    history = [HumanMessage(content=chain.prompt.format(content=m.content)) if isinstance(m, HumanMessage) else m
               for m in chain.memory.chat_memory.messages]
    prompt = ChatPromptTemplate.from_messages([
        *history,
        HumanMessagePromptTemplate.from_template(followup_question)])
    messages = prompt.format_prompt(**input_vars).to_messages()
    q = Queue()
    job_done = object()

    def task():
        result = chat.generate([messages], callbacks=[SyncStreamingLLMCallbackHandler(q)])
        q.put(job_done)
        return result.generations[0][0].message.content

    with start_blocking_portal() as portal:
        portal.start_task_soon(task)
        output = ""
        while True:
            next_token = q.get(True, timeout=10)
            if next_token is job_done:
                break
            output += next_token
            yield output

def run_body(content, title, chain, llm):
    check_content(content)  # note: run() also checks, but that error doesn't get shown in the UI?
    if not title:
        # This is a generator, so the message must be yielded, not returned,
        # for it to reach the output textbox.
        yield "Please enter the book title."
        return
    yield from concatenate_generators(
        run(content, chain),
        "\n\n",
        run_followup(BODY_PROMPT2, {}, chain, llm),
        "\n\n7. Whether supporting evidence is from the book:",
        (output.split("----")[-1] for output in run_followup(BODY_PROMPT3, {"title": title}, chain, llm))
    )
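
# Note: the BODY_PROMPT2/3 follow-ups replay the conversation memory from the
# first call, so each sees the paragraph and the model's initial analysis
# (but not each other's answers) as context.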

def run_custom(content, llm, prompt):
    chain = LLMChain(llm=llm,
                     memory=ConversationBufferMemory(),
                     prompt=PromptTemplate(
                         input_variables=["content"],
                         template=prompt
                     ))
    return chain.run(content), chain

# not currently used
def split_paragraphs(text):
    return [(x, x != "" and not x.startswith("#") and not x.isspace()) for x in text.split("\n")]


def sliding_window(iterable, n):
    # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
    # (the sliding_window recipe from the itertools documentation)
    it = iter(iterable)
    window = collections.deque(islice(it, n), maxlen=n)
    if len(window) == n:
        yield tuple(window)
    for x in it:
        window.append(x)
        yield tuple(window)

dmp = diff_match_patch()


def diff_words(content, edited):
    """Diff two texts into highlight lists for the Before/After components."""
    before = []
    after = []
    changes = []
    change_count = 0
    changed = False
    diff = dmp.diff_main(content, edited)
    dmp.diff_cleanupSemantic(diff)
    # Append a sentinel so the sliding window pairs each chunk with its successor.
    diff += [(None, None)]
    for [(change, text), (next_change, next_text)] in sliding_window(diff, 2):
        if change == 0:
            before.append((text, None))
            after.append((text, None))
        else:
            if change == -1 and next_change == 1:
                # A deletion followed by an insertion is shown as one numbered change.
                change_count += 1
                before.append((text, str(change_count)))
                after.append((next_text, str(change_count)))
                changes.append((text, next_text))
                changed = True
            elif change == -1:
                before.append((text, "-"))
            elif change == 1:
                if changed:
                    # Already consumed as the second half of a numbered change.
                    changed = False
                else:
                    after.append((text, "+"))
            else:
                raise Exception("Unknown change type: " + str(change))
    return before, after, changes
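
# Illustrative example (assuming diff-match-patch isolates the differing word):
# diff_words("the cat sat", "the dog sat") yields
#   before  = [("the ", None), ("cat", "1"), (" sat", None)]
#   after   = [("the ", None), ("dog", "1"), (" sat", None)]
#   changes = [("cat", "dog")]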

def get_parts(arr, start, end):
    return "".join(arr[start:end])


CHANGES = {
    "-": "remove",
    "+": "add",
    # "→": "change"
}

def select_diff(evt: gr.SelectData, changes):
    text, change = evt.value
    if not change:
        return
    change_text = CHANGES.get(change, None)
    if change_text:
        return f"Why is it better to {change_text} [{text}]?"
    # if change == "→":
    else:
        # clicked = evt.target
        # if clicked.label == "Before":
        #     original = text
        # else:
        #     edited = text
        original, edited = changes[int(change) - 1]
        # original, edited = text.split("→")
        return f"Why is it better to change [{original}] to [{edited}]?"

demo = gr.Blocks(css="""
.diff-component {
    white-space: pre-wrap;
}
.diff-component .textspan.hl {
    white-space: normal;
}
""")

with demo:
    # api_key = gr.Textbox(
    #     placeholder="Paste your OpenAI API key here (sk-...)",
    #     show_label=False,
    #     lines=1,
    #     type="password"
    # )
    api_key = gr.State("")
    gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="./file=thinkcol-logo.png" alt="ThinkCol" width="357" height="87" /></a></div>""")
    gr.Markdown("""Paste a paragraph below, and then choose one of the modes to generate feedback.""")
    content = gr.Textbox(
        label="Paragraph"
    )
with gr.Tab("Grammar/Spelling"): | |
gr.Markdown("Suggests grammar and spelling revisions.") | |
submit = gr.Button( | |
value="Revise", | |
).style(full_width=False) | |
with gr.Row(): | |
output_before = gr.HighlightedText( | |
label="Before", | |
combine_adjacent=True, | |
elem_classes="diff-component" | |
).style(color_map={ | |
"-": "red", | |
# "→": "yellow", | |
}) | |
output_after = gr.HighlightedText( | |
label="After", | |
combine_adjacent=True, | |
elem_classes="diff-component" | |
).style(color_map={ | |
"+": "green", | |
# "→": "yellow", | |
}) | |
followup_question = gr.Textbox( | |
label="Follow-up Question", | |
) | |
followup_submit = gr.Button( | |
value="Ask" | |
).style(full_width=False) | |
followup_answer = gr.Textbox( | |
label="Answer" | |
) | |
with gr.Tab("Intro"): | |
gr.Markdown("Checks for the key components of an introductory paragraph.") | |
submit_intro = gr.Button( | |
value="Run" | |
).style(full_width=False) | |
output_intro = gr.Textbox( | |
label="Output", | |
lines=1000, | |
max_lines=1000 | |
) | |
with gr.Tab("Body Paragraph"): | |
gr.Markdown(BODY_DESCRIPTION) | |
title = gr.Textbox( | |
label="Book Title" | |
) | |
submit_body = gr.Button( | |
value="Run" | |
).style(full_width=False) | |
output_body = gr.Textbox( | |
label="Output", | |
lines=1000, | |
max_lines=1000 | |
) | |
    # with gr.Tab("Custom prompt"):
    #     gr.Markdown("This mode is for testing and debugging.")
    #     prompt = gr.Textbox(
    #         label="Prompt",
    #         value=GRAMMAR_PROMPT,
    #         lines=2
    #     )
    #     submit_custom = gr.Button(
    #         value="Run"
    #     ).style(full_width=False)
    #     output_custom = gr.Textbox(
    #         label="Output"
    #     )
    #     followup_custom = gr.Textbox(
    #         label="Follow-up Question"
    #     )
    #     followup_answer_custom = gr.Textbox(
    #         label="Answer"
    #     )
with gr.Tab("Settings"): | |
api_type = gr.Radio( | |
["OpenAI", "Azure OpenAI"], | |
value="OpenAI", | |
label="Server", | |
info="You can try changing this if responses are slow." | |
) | |
    changes = gr.State()
    edited = gr.State()
    chain = gr.State()
    llm = gr.State()
    chain_intro = gr.State()
    chain_body1 = gr.State()
    chain_custom = gr.State()
    # api_key.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
    api_type.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
    inputs = [content, chain]
    outputs = [output_before, output_after, changes, edited]
    # content.submit(run_diff, inputs=inputs, outputs=outputs)
    submit.click(run_diff, inputs=inputs, outputs=outputs)
    output_before.select(select_diff, changes, followup_question)
    output_after.select(select_diff, changes, followup_question)
    empty_input = gr.State({})
    inputs2 = [followup_question, empty_input, chain, llm]
    outputs2 = followup_answer
    followup_question.submit(run_followup, inputs2, outputs2)
    followup_submit.click(run_followup, inputs2, outputs2)
    submit_intro.click(run, [content, chain_intro], output_intro)
    submit_body.click(run_body, [content, title, chain_body1, llm], output_body)  # body part A only
    # submit_custom.click(run_custom, [content, llm, prompt], [output_custom, chain_custom])  # TODO standardize api--return memory instead of using chain?
    # followup_custom.submit(run_followup, [followup_custom, empty_input, chain_custom, llm], followup_answer_custom)
    demo.load(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])

port = os.environ.get("SERVER_PORT", None)
if port:
    port = int(port)
demo.queue()
demo.launch(debug=True, server_port=port, prevent_thread_lock=True)