File size: 13,940 Bytes
3070581
ed9313b
73f9c5b
ed9313b
4aaf91a
ed9313b
4aaf91a
3070581
ed9313b
0605bc2
4aaf91a
0605bc2
 
 
 
4aaf91a
2d94716
0605bc2
22cc962
 
 
f2608fd
22cc962
 
f2608fd
 
 
 
 
 
 
 
 
22cc962
33741fc
 
 
bfd1e8f
22cc962
 
 
 
bfd1e8f
33741fc
dc6b2e4
280e63f
 
 
bfd1e8f
 
12e8ac9
 
4c32c48
 
 
 
 
 
 
 
 
4aab74e
 
 
 
 
 
 
 
12e8ac9
 
 
 
 
 
 
 
 
 
0605bc2
12e8ac9
 
 
dd1d609
12e8ac9
4aaf91a
12e8ac9
 
 
 
dd1d609
 
 
 
 
12e8ac9
 
 
0605bc2
 
22cc962
0605bc2
 
 
 
22cc962
 
 
 
 
 
 
 
 
 
 
 
0605bc2
22cc962
0605bc2
 
4aaf91a
4aab74e
22cc962
4aaf91a
dfd1845
0605bc2
4aaf91a
 
4aab74e
22cc962
 
4aaf91a
 
2d94716
 
4aaf91a
 
 
 
 
 
10b507e
4aaf91a
 
 
 
10b507e
 
4aaf91a
 
 
4aab74e
12e8ac9
280e63f
 
 
 
 
 
4aaf91a
 
 
2d94716
 
4aaf91a
 
 
 
 
 
10b507e
4aaf91a
 
 
 
10b507e
 
4aaf91a
280e63f
 
4aab74e
280e63f
 
4aaf91a
 
 
 
 
 
 
 
22cc962
 
280e63f
 
 
 
 
 
 
22cc962
4aaf91a
0605bc2
4aaf91a
0605bc2
22cc962
dfd1845
 
 
 
 
 
 
 
 
 
0605bc2
dfd1845
 
 
 
 
 
 
 
4aaf91a
dfd1845
 
 
 
0605bc2
dfd1845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0605bc2
 
 
 
 
 
 
 
 
dfd1845
0605bc2
dfd1845
0605bc2
 
 
dfd1845
 
 
 
 
 
 
 
 
 
 
 
 
 
3070581
12e8ac9
22cc962
 
 
 
 
 
 
12e8ac9
d2ff475
 
 
 
 
 
12e8ac9
8d2dd70
d2ff475
0605bc2
22cc962
0605bc2
22cc962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfd1845
22cc962
 
 
 
 
dfd1845
22cc962
 
 
 
 
0605bc2
22cc962
ecec7b8
f7f4627
ecec7b8
22cc962
 
4c32c48
280e63f
 
 
22cc962
 
 
 
 
ecec7b8
f7f4627
ecec7b8
22cc962
d2ff475
 
 
 
 
 
 
 
 
 
 
 
 
 
0605bc2
280e63f
 
 
 
 
 
12e8ac9
 
 
 
 
 
 
280e63f
dfd1845
0605bc2
 
 
22cc962
 
0605bc2
280e63f
 
12e8ac9
 
 
0605bc2
 
dfd1845
22cc962
 
0605bc2
dfd1845
 
0605bc2
280e63f
 
 
0605bc2
280e63f
 
0605bc2
22cc962
280e63f
 
12e8ac9
280e63f
d2ff475
12e8ac9
22cc962
fc039dd
 
 
4aaf91a
8d2dd70
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
# based on https://github.com/hwchase17/langchain-gradio-template/blob/master/app.py
import collections
import os
from itertools import islice
from queue import Queue

from anyio.from_thread import start_blocking_portal
import gradio as gr
from diff_match_patch import diff_match_patch
from langchain.chains import LLMChain
from langchain.chat_models import PromptLayerChatOpenAI, ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage

from util import SyncStreamingLLMCallbackHandler, concatenate_generators

# Prompt sent for the Grammar/Spelling tab.
GRAMMAR_PROMPT = "Proofread for grammar and spelling without adding new paragraphs:\n{content}"

# Prompt for the Intro tab: asks the model to locate and score the four
# components of an introductory paragraph.
INTRO_PROMPT = """These are the parts of a good introductory paragraph:
1. Introductory information
2. The stage of human development of the main character
3. Summary of story
4. Thesis statement (this should also provide an overview of the essay structure or topics that may be covered in each paragraph)
For each part, put a quote of the sentences from the following paragraph that fulfil that part and say how confident you are (percentage). If you're not confident, explain why.
---
Example output format:
Thesis statement and outline:
"Sentence A. Sentence B"
Score: X%. Feedback goes here.
---
Intro paragraph:
{content}"""
# Body-paragraph analysis is split into three prompts sent in sequence
# (the second and third are follow-up questions on the same chat history).
BODY_PROMPT1 = """You are a university English teacher. Complete the following tasks for the following essay paragraph about a book:
1. Topic sentence: Identify the topic sentence and determine whether it introduces an argument
2. Key points: Outline a bullet list of key points
3. Supporting evidence: Give a bullet list of parts of the paragraph that use quotes or other textual evidence from the book

{content}"""
BODY_PROMPT2 = """4. Give advice on how the topic sentence could be made stronger or clearer
5. In a bullet list, state how each key point supports the topic (or if any doesn't support it)
6. In a bullet list for each supporting evidence, state which key point the evidence supports.
"""
# BUG FIX: "liisted" -> "listed" (typo in the prompt text sent to the model).
BODY_PROMPT3 = """Briefly summarize "{title}". Then, in a bullet list for each supporting evidence you listed above, state if it describes an event/detail from the "{title}" or if it's from outside sources.
Use this output format:
[summary]
----
- [supporting evidence 1] - book
- [supporting evidence 2] - outside source"""


# User-facing description shown on the Body Paragraph tab.
BODY_DESCRIPTION = """1. identifies the topic sentence
2. outlines key points
3. checks for supporting evidence (e.g., quotes, summaries, and concrete details)
4. suggests topic sentence improvements
5. checks that the key points match the paragraph topic
6. determines which key point each piece of evidence supports
7. checks whether each evidence is from the book or from an outside source"""


def is_empty(s: str):
  """Return True when *s* contains no characters or only whitespace."""
  return not s.strip()

def check_content(s: str):
  """Raise a user-visible Gradio error when *s* is blank or whitespace."""
  if not is_empty(s):
    return
  raise gr.exceptions.Error('Please input some text before running.')


def load_chain(api_key, api_type):
  """Build the LLM and the per-tab LLMChains.

  Args:
    api_key: user-supplied key; when blank, falls back to the
      OPENAI_API_KEY / AZURE_OPENAI_API_KEY environment variable
      depending on *api_type*.
    api_type: "OpenAI" or "Azure OpenAI" (anything else raises).

  Returns:
    (grammar_chain, llm, intro_chain, body_chain) when a key is
    available. NOTE(review): if no key is found anywhere, the function
    implicitly returns None — downstream unpacking would then fail;
    confirm this is the intended "unconfigured" behavior.
  """
  if api_key == "" or api_key.isspace():
    if api_type == "OpenAI":
      api_key = os.environ.get("OPENAI_API_KEY", None)
    elif api_type == "Azure OpenAI":
      api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
    else:
      raise RuntimeError("Unknown API type? " + api_type)


  if api_key:
    # Common constructor arguments for both the OpenAI and Azure variants.
    shared_args = {
        "temperature": 0,
        "model_name": "gpt-3.5-turbo",
        "api_key": api_key, # deliberately not use "openai_api_key" and other openai args since those apply globally
        "pl_tags": ["grammar"],
        "streaming": True,
    }
    # NOTE(review): if api_type were neither branch here (only possible when
    # an explicit api_key was passed), `llm` would be unbound below — the
    # earlier RuntimeError only guards the blank-key path.
    if api_type == "OpenAI":
      llm = PromptLayerChatOpenAI(**shared_args)
    elif api_type == "Azure OpenAI":
      llm = PromptLayerChatOpenAI(
        api_type = "azure",
        api_base = os.environ.get("AZURE_OPENAI_API_BASE", None),
        api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
        engine = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
        **shared_args
      )

    # One chain per tab, each with its own conversation memory so
    # follow-up questions stay scoped to that tab's last run.
    prompt1 = PromptTemplate(
        input_variables=["content"],
        template=GRAMMAR_PROMPT
    )
    chain = LLMChain(llm=llm,
                     prompt=prompt1,
                     memory=ConversationBufferMemory())
    chain_intro = LLMChain(llm=llm,
                           prompt=PromptTemplate(
                               input_variables=["content"],
                               template=INTRO_PROMPT
                           ),
                           memory=ConversationBufferMemory())
    chain_body1 = LLMChain(llm=llm,
                           prompt=PromptTemplate(
                               input_variables=["content"],
                               template=BODY_PROMPT1
                           ),
                           memory=ConversationBufferMemory())

    return chain, llm, chain_intro, chain_body1


def run_diff(content, chain: LLMChain):
  """Proofread *content* with the chain and return a word-level diff.

  Returns (before_spans, after_spans, changes, edited_text) suitable for
  the two HighlightedText components plus the follow-up state.
  """
  check_content(content)
  chain.memory.clear()
  revised = chain.run(content)
  before, after, changes = diff_words(content, revised)
  return before, after, changes, revised

# https://github.com/hwchase17/langchain/issues/2428#issuecomment-1512280045
def run(content, chain: LLMChain):
  """Run the chain on *content*, yielding the growing response text.

  Streams tokens through a queue fed by a callback handler running in an
  anyio portal thread; each yield is the full output accumulated so far.
  """
  check_content(content)
  chain.memory.clear()

  token_queue = Queue()
  done = object()

  def task():
    result = chain.run(content, callbacks=[SyncStreamingLLMCallbackHandler(token_queue)])
    token_queue.put(done)
    return result

  with start_blocking_portal() as portal:
    portal.start_task_soon(task)

    accumulated = ""
    # iter() with a sentinel pulls tokens until the worker signals completion.
    for token in iter(lambda: token_queue.get(True, timeout=10), done):
      accumulated += token
      yield accumulated

# TODO share code with above
def run_followup(followup_question, input_vars, chain, chat: ChatOpenAI):
  """Ask *followup_question* in the context of *chain*'s chat history.

  Rebuilds the prior conversation (re-applying the chain's prompt
  template to each human message so the model sees the full original
  prompts, not just the raw inputs), appends the follow-up as a new
  human message, and streams the answer token by token — each yield is
  the full accumulated output so far.

  Args:
    followup_question: prompt template string for the new question.
    input_vars: dict of variables to format into *followup_question*.
    chain: the LLMChain whose memory supplies the conversation so far.
    chat: the chat model used to generate the answer.
  """
  check_content(followup_question)

  # Human messages in memory hold only the raw inputs; wrap them back in
  # the chain's prompt template so the history matches what was sent.
  history = [HumanMessage(content=chain.prompt.format(content=m.content)) if isinstance(m, HumanMessage) else m
             for m in chain.memory.chat_memory.messages]
  prompt = ChatPromptTemplate.from_messages([
      *history,
      HumanMessagePromptTemplate.from_template(followup_question)])
  messages = prompt.format_prompt(**input_vars).to_messages()

  # Same streaming pattern as run(): a worker thread feeds tokens into a
  # queue, terminated by the job_done sentinel.
  q = Queue()
  job_done = object()
  def task():
    result = chat.generate([messages], callbacks=[SyncStreamingLLMCallbackHandler(q)])
    q.put(job_done)
    return result.generations[0][0].message.content

  with start_blocking_portal() as portal:
    portal.start_task_soon(task)

    output = ""
    while True:
      next_token = q.get(True, timeout=10)
      if next_token is job_done:
        break
      output += next_token
      yield output


def run_body(content, title, chain, llm):
  """Stream the full body-paragraph analysis (tasks 1-7) for *content*.

  Runs BODY_PROMPT1 through the chain, then BODY_PROMPT2/3 as follow-up
  questions on the same history, concatenating the streams.
  """
  check_content(content) # note: run() also checks, but the error doesn't get shown in the UI?
  if not title:
    # BUG FIX: this is a generator function, so `return "..."` only set
    # StopIteration.value and the message never reached the output
    # textbox. Yield it so the user actually sees the validation error.
    yield "Please enter the book title."
    return

  yield from concatenate_generators(
    run(content, chain),
    "\n\n",
    run_followup(BODY_PROMPT2, {}, chain, llm),
    "\n\n7. Whether supporting evidence is from the book:",
    (output.split("----")[-1] for output in run_followup(BODY_PROMPT3, {"title": title}, chain, llm))
    )

def run_custom(content, llm, prompt):
  """Run *content* through a one-off chain built from *prompt*.

  Returns (model_output, chain) so the caller can keep the chain for
  follow-up questions.
  """
  template = PromptTemplate(
      input_variables=["content"],
      template=prompt
  )
  chain = LLMChain(llm=llm, prompt=template, memory=ConversationBufferMemory())
  return chain.run(content), chain

# not currently used
def split_paragraphs(text):
  """Split *text* into lines, flagging which ones are real paragraphs.

  Returns a list of (line, is_paragraph) pairs; empty lines, headings
  ("#..."), and whitespace-only lines are flagged False.
  """
  def _is_paragraph(line):
    return bool(line) and not line.startswith("#") and not line.isspace()

  return [(line, _is_paragraph(line)) for line in text.split("\n")]

def sliding_window(iterable, n):
    """Yield successive overlapping n-tuples from *iterable*.

    sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
    Yields nothing when the iterable has fewer than n items.
    """
    window = collections.deque(maxlen=n)
    for item in iterable:
        window.append(item)
        # Start emitting only once the window is full; from then on every
        # append produces the next overlapping tuple.
        if len(window) == n:
            yield tuple(window)

dmp = diff_match_patch()
def diff_words(content, edited):
  """Compute a word-level diff between *content* and *edited*.

  Returns (before, after, changes):
    before:  (text, label) spans of the original; label is None for
             unchanged text, "-" for deletions, or a 1-based change
             number (as a string) for replacements.
    after:   (text, label) spans of the edited text; "+" marks pure
             insertions, replacement spans carry the matching number.
    changes: list of (original, replacement) pairs indexed by the
             change numbers above.
  """
  before = []
  after = []
  changes = []
  change_count = 0
  changed = False
  diff = dmp.diff_main(content, edited)
  dmp.diff_cleanupSemantic(diff)
  # Sentinel pair so the last real diff entry still gets a "next" element
  # in the sliding window below; it never appears as the current element.
  diff += [(None, None)]

  for [(change, text), (next_change, next_text)] in sliding_window(diff, 2):
    if change == 0:
      before.append((text, None))
      after.append((text, None))
    else:
      if change == -1 and next_change == 1:
        # A deletion immediately followed by an insertion is a
        # replacement: number it so both sides highlight as one change.
        change_count += 1
        before.append((text, str(change_count)))
        after.append((next_text, str(change_count)))
        changes.append((text, next_text))
        changed = True
      elif change == -1:
        before.append((text, "-"))
      elif change == 1:
        if changed:
          # This insertion was already consumed as the second half of the
          # replacement pair above; skip it once.
          changed = False
        else:
          after.append((text, "+"))
      else:
        # BUG FIX: `change` is an int, so the original string
        # concatenation raised TypeError instead of this message.
        raise Exception(f"Unknown change type: {change}")

  return before, after, changes

def get_parts(arr, start, end):
  """Concatenate the string pieces arr[start:end] into one string."""
  selected = arr[start:end]
  return "".join(selected)



# Maps a highlight label to the verb used in the follow-up question.
CHANGES = {
   "-": "remove",
   "+": "add",
}
def select_diff(evt: gr.SelectData, changes):
  """Build a follow-up question about the diff span the user clicked.

  *evt.value* is a (text, label) pair from a HighlightedText component;
  labels are "-", "+", a 1-based change number (string), or falsy for
  unchanged spans (which produce no question).
  """
  text, label = evt.value
  if not label:
    return
  verb = CHANGES.get(label, None)
  if verb:
    return f"Why is it better to {verb} [{text}]?"
  else:
    # Otherwise the label is a change number pointing into `changes`,
    # which holds the (original, replacement) pair for a substitution.
    original, edited = changes[int(label) - 1]
    return f"Why is it better to change [{original}] to [{edited}]?"

# Top-level UI: one shared paragraph input, one tab per feedback mode.
# The CSS keeps the diff highlights wrapping like normal prose.
demo = gr.Blocks(css="""
.diff-component {
  white-space: pre-wrap;
}
.diff-component .textspan.hl {
  white-space: normal;
}
""")
with demo:
  # api_key = gr.Textbox(
  #     placeholder="Paste your OpenAPI API key here (sk-...)",
  #     show_label=False,
  #     lines=1,
  #     type="password"
  # )
  # API key is kept as empty session state so load_chain falls back to
  # environment variables (the textbox above is disabled).
  api_key = gr.State("")
  gr.HTML("""<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="./file=thinkcol-logo.png" alt="ThinkCol" width="357" height="87" /></a></div>""")
  gr.Markdown("""Paste a paragraph below, and then choose one of the modes to generate feedback.""")
  # Shared input used by all tabs.
  content = gr.Textbox(
      label="Paragraph"
  )

  # --- Grammar/Spelling tab: side-by-side diff plus follow-up Q&A ---
  with gr.Tab("Grammar/Spelling"):
    gr.Markdown("Suggests grammar and spelling revisions.")
    submit = gr.Button(
        value="Revise",
    ).style(full_width=False)

    with gr.Row():
      output_before = gr.HighlightedText(
          label="Before",
          combine_adjacent=True,
          elem_classes="diff-component"
      ).style(color_map={
          "-": "red",
          # "β†’": "yellow",
      })
      output_after = gr.HighlightedText(
          label="After",
          combine_adjacent=True,
          elem_classes="diff-component"
      ).style(color_map={
          "+": "green",
          # "β†’": "yellow",
      })

    followup_question = gr.Textbox(
        label="Follow-up Question",
    )
    followup_submit = gr.Button(
        value="Ask"
    ).style(full_width=False)
    followup_answer = gr.Textbox(
        label="Answer"
    )
  # --- Intro tab: streamed analysis of introductory-paragraph parts ---
  with gr.Tab("Intro"):
    gr.Markdown("Checks for the key components of an introductory paragraph.")
    submit_intro = gr.Button(
        value="Run"
    ).style(full_width=False)

    output_intro = gr.Textbox(
        label="Output",
        lines=1000,
        max_lines=1000
    )
  # --- Body Paragraph tab: streamed 7-task analysis (needs book title) ---
  with gr.Tab("Body Paragraph"):
    gr.Markdown(BODY_DESCRIPTION)
    title = gr.Textbox(
        label="Book Title"
    )
    submit_body = gr.Button(
        value="Run"
    ).style(full_width=False)

    output_body = gr.Textbox(
        label="Output",
        lines=1000,
        max_lines=1000
    )
  # with gr.Tab("Custom prompt"):
  #   gr.Markdown("This mode is for testing and debugging.")
  #   prompt = gr.Textbox(
  #       label="Prompt",
  #       value=GRAMMAR_PROMPT,
  #       lines=2
  #   )
  #   submit_custom = gr.Button(
  #       value="Run"
  #   ).style(full_width=False)

  #   output_custom = gr.Textbox(
  #       label="Output"
  #   )

  #   followup_custom = gr.Textbox(
  #       label="Follow-up Question"
  #   )
  #   followup_answer_custom = gr.Textbox(
  #       label="Answer"
  #   )
  with gr.Tab("Settings"):
    api_type = gr.Radio(
      ["OpenAI", "Azure OpenAI"],
      value="OpenAI",
      label="Server",
      info="You can try changing this if responses are slow."
    )

  # Per-session state: diff bookkeeping plus the chains/LLM built by
  # load_chain on page load (and rebuilt when the server type changes).
  changes = gr.State()
  edited = gr.State()
  chain = gr.State()
  llm = gr.State()
  chain_intro = gr.State()
  chain_body1 = gr.State()

  chain_custom = gr.State()


  # api_key.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])
  api_type.change(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])

  # Grammar tab wiring: run the diff, then let clicks on either diff pane
  # pre-fill the follow-up question box.
  inputs = [content, chain]
  outputs = [output_before, output_after, changes, edited]
  # content.submit(run_diff, inputs=inputs, outputs=outputs)
  submit.click(run_diff, inputs=inputs, outputs=outputs)

  output_before.select(select_diff, changes, followup_question)
  output_after.select(select_diff, changes, followup_question)

  # run_followup takes an input_vars dict; the follow-up question here
  # has no template variables.
  empty_input = gr.State({})

  inputs2 = [followup_question, empty_input, chain, llm]
  outputs2 = followup_answer
  followup_question.submit(run_followup, inputs2, outputs2)
  followup_submit.click(run_followup, inputs2, outputs2)

  submit_intro.click(run, [content, chain_intro], output_intro)
  submit_body.click(run_body, [content, title, chain_body1, llm], output_body) # body part A only
  # submit_custom.click(run_custom, [content, llm, prompt], [output_custom, chain_custom]) # TODO standardize api--return memory instead of using chain?

  # followup_custom.submit(run_followup, [followup_custom, empty_input, chain_custom, llm], followup_answer_custom)

  # Build the chains once per session when the page loads.
  demo.load(load_chain, [api_key, api_type], [chain, llm, chain_intro, chain_body1])

# Optionally override the serving port via the SERVER_PORT environment
# variable; a non-empty value is converted to int for gr.launch().
port = os.environ.get("SERVER_PORT", None)
port = int(port) if port else port
demo.queue()
demo.launch(debug=True, server_port=port, prevent_thread_lock=True)