YDTsai commited on
Commit
cb8fdae
1 Parent(s): fe0ce42

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .ipynb_checkpoints
src/.ipynb_checkpoints/app-checkpoint.py DELETED
@@ -1,68 +0,0 @@
1
- import os
2
- import gradio as gr
3
- from suggest import Suggest
4
- from edit import Editor
5
- from config import configure_logging
6
-
7
-
8
configure_logging()


# Build the two-tab Gradio UI: "Edit" rewrites a paragraph, "Suggest" lists
# improvement ideas for a whole document.
with gr.Blocks() as demo:

    title = gr.Button("PaperGPT", interactive=False)
    # Pre-fill the key from the environment when available.
    key = gr.Textbox(label="openai_key", value=os.environ.get('OPENAI_API_KEY'))

    with gr.Tab("Edit"):

        handler = Editor()
        txt_in = gr.Textbox(label="Input", lines=11, max_lines=11, value=handler.sample_content)
        btn = gr.Button("Edit")
        txt_out = gr.Textbox(label="Output", lines=11, max_lines=11, value="GPT will serve as your editor and modify the paragraph for you.")
        btn.click(handler.generate, inputs=[txt_in, key], outputs=[txt_out])

    with gr.Tab("Suggest"):

        idea_list = []
        max_ideas = 20
        handler = Suggest(max_ideas)

        def select(name: str):
            """Show the details of the idea whose title equals *name*.

            Returns four textbox updates (thought, action, original, improved).
            When no idea matches, the textboxes are left unchanged.
            """
            # Suggest.generate() publishes its results through a `global`
            # statement, i.e. into the `suggest` module's namespace, not this
            # module's. Read them from there and fall back to the local list.
            import suggest as suggest_module

            ideas = getattr(suggest_module, "idea_list", None) or idea_list
            for idea in ideas:
                if idea['title'] == name:
                    return [
                        gr.Textbox.update(value=idea["thought"], label="thought", visible=True),
                        gr.Textbox.update(value=idea["action"], label="action", visible=True),
                        gr.Textbox.update(value=idea["original"], label="original", visible=True, max_lines=5, lines=5),
                        gr.Textbox.update(value=idea["improved"], label="improved", visible=True, max_lines=5, lines=5)
                    ]
            # No match: emit no-op updates rather than returning None for four outputs.
            return [gr.Textbox.update() for _ in range(4)]

        with gr.Row().style(equal_height=True):
            with gr.Column(scale=0.95):
                # The slice skips the beginning of the sample document.
                txt_in = gr.Textbox(label="Input", lines=11, max_lines=11, value=handler.sample_content[2048+2048+256-45:])
            with gr.Column(scale=0.05):
                # File extensions need a leading dot.
                upload = gr.File(file_count="single", file_types=[".tex", ".pdf"])
                btn = gr.Button("Analyze")
                upload.change(handler.read_file, inputs=upload, outputs=txt_in)

        # One hidden button per possible idea; generate() reveals as many as needed.
        textboxes = []
        sug = gr.Textbox("GPT will give suggestions and help you improve the paper quality.", interactive=False, show_label=False).style(text_align="center")
        with gr.Row():
            with gr.Column(scale=0.4):
                for _ in range(max_ideas):
                    t = gr.Button("", visible=False)
                    textboxes.append(t)
            with gr.Column(scale=0.6):
                thought = gr.Textbox(label="thought", visible=False, interactive=False)
                action = gr.Textbox(label="action", visible=False, interactive=False)
                original = gr.Textbox(label="original", visible=False, max_lines=5, lines=5, interactive=False)
                improved = gr.Textbox(label="improved", visible=False, max_lines=5, lines=5, interactive=False)

        btn.click(handler.generate, inputs=[txt_in, key], outputs=[sug, btn, thought, action, original, improved] + textboxes)
        for idea_button in textboxes:
            # The button's own value (the idea title) is passed to select().
            idea_button.click(select, inputs=[idea_button], outputs=[thought, action, original, improved])


# demo.launch(server_name="0.0.0.0", server_port=7653, share=True, enable_queue=True)
demo.launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/.ipynb_checkpoints/config-checkpoint.py DELETED
@@ -1,17 +0,0 @@
1
- import logging
2
-
3
-
4
def configure_logging(logging_level: str = "INFO"):
    """Configure the root logger with a timestamped, process-aware format.

    Args:
        logging_level: Name of a level defined by the ``logging`` module
            (case-insensitive), e.g. ``"info"`` or ``"DEBUG"``.

    Raises:
        ValueError: If *logging_level* does not name an integer level
            attribute of the ``logging`` module.
    """
    logging_level = logging_level.upper()
    numeric_level = getattr(logging, logging_level, None)

    if not isinstance(numeric_level, int):
        # Report the name the caller passed; numeric_level is not useful here
        # (it is None or some non-level attribute).
        raise ValueError(f"Invalid log level: {logging_level}")

    logging.basicConfig(
        level=numeric_level,
        datefmt="%Y-%m-%d %H:%M:%S",
        format=
        "[%(asctime)s] [%(process)s] [%(levelname)s] [%(module)s]: #%(funcName)s @%(lineno)d: %(message)s",
    )
    logging.info("Logging level: %s", logging_level)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/.ipynb_checkpoints/edit-checkpoint.py DELETED
@@ -1,43 +0,0 @@
1
- import logging
2
- import tiktoken
3
- import gradio as gr
4
- from langchain.text_splitter import CharacterTextSplitter
5
- from utils import fetch_chat
6
- from typing import List
7
-
8
-
9
class Editor():
    """Paragraph editor that asks an OpenAI chat model to polish text."""

    def __init__(self, model: str = "gpt-3.5-turbo"):
        """Load the tokenizer for *model* and the bundled sample abstract.

        Args:
            model: Chat model name; it selects the tiktoken encoding and is
                forwarded to ``fetch_chat``.
        """
        self.encoder = tiktoken.encoding_for_model(model)
        self.model = model
        # Path is relative to the current working directory.
        with open("./sample/sample_abstract.tex", "r") as f:
            self.sample_content = f.read()

    def split_chunk(self, text, chunk_size: int = 2000) -> List[str]:
        """Split *text* into pieces of at most *chunk_size* tokens.

        Args:
            text: The text to split.
            chunk_size: Chunk size handed to the tiktoken-based splitter.

        Returns:
            The list of text chunks, without overlap.
        """
        text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
            # Honour the caller's chunk size instead of a hard-coded 100.
            chunk_size=chunk_size, chunk_overlap=0
        )
        return text_splitter.split_text(text)

    def generate(self, text: str, openai_key: str):
        """Return an edited version of *text* produced by the chat model.

        Args:
            text: Paragraph to edit (plain text or LaTeX).
            openai_key: OpenAI API key forwarded to ``fetch_chat``.

        Raises:
            gr.Error: If the request fails or the API reports an error.
        """
        logging.info("start editing")

        try:
            prompt = f"""
            I am a computer science student.
            I am writing my research paper.
            You are my editor.
            Your goal is to improve my paper quality at your best.
            Please edit the following paragraph and return the modified paragraph.
            If the paragraph is written in latex, return the modified paragraph in latex.

            ```
            {text}
            ```
            """
            result = fetch_chat(prompt, openai_key, model=self.model)
        except Exception as e:
            raise gr.Error(str(e)) from e

        # fetch_chat hands back the API's error payload (not a string) when
        # every attempt fails; surface it as an error instead of showing it
        # as the edited paragraph.
        if not isinstance(result, str):
            raise gr.Error(str(result))
        return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/.ipynb_checkpoints/papergpt-checkpoint.ipynb DELETED
@@ -1,491 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "raw",
5
- "id": "b77c4818-9d92-4581-a2fb-2b38da776a8b",
6
- "metadata": {},
7
- "source": [
8
- "# python >= 3.8\n",
9
- "import sys\n",
10
- "!{sys.executable} -m pip install langchain gradio tiktoken unstructured"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": 1,
16
- "id": "80b2033d-b985-440d-a01f-e7f8547a6801",
17
- "metadata": {
18
- "tags": []
19
- },
20
- "outputs": [
21
- {
22
- "name": "stderr",
23
- "output_type": "stream",
24
- "text": [
25
- "/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/requests/__init__.py:102: RequestsDependencyWarning: urllib3 (1.26.8) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!\n",
26
- " warnings.warn(\"urllib3 ({}) or chardet ({})/charset_normalizer ({}) doesn't match a supported \"\n",
27
- "/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
28
- " from .autonotebook import tqdm as notebook_tqdm\n"
29
- ]
30
- }
31
- ],
32
- "source": [
33
- "# from langchain.text_splitter import LatexTextSplitter\n",
34
- "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
35
- "from typing import Any\n",
36
- "import requests\n",
37
- "import logging\n",
38
- "import json\n",
39
- "import tiktoken\n",
40
- "import gradio as gr\n",
41
- "from langchain.document_loaders import UnstructuredPDFLoader"
42
- ]
43
- },
44
- {
45
- "cell_type": "code",
46
- "execution_count": 2,
47
- "id": "720927dd-848e-47ff-a847-4e1097768561",
48
- "metadata": {
49
- "tags": []
50
- },
51
- "outputs": [],
52
- "source": [
53
- "turbo_encoding = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n",
54
- "with open(\"sample.tex\", \"r\") as f:\n",
55
- " content = f.read()"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": 3,
61
- "id": "13bc48fd-a048-4635-99c4-79ddb62d5c4d",
62
- "metadata": {
63
- "tags": []
64
- },
65
- "outputs": [],
66
- "source": [
67
- "class LatexTextSplitter(RecursiveCharacterTextSplitter):\n",
68
- " \"\"\"Attempts to split the text along Latex-formatted layout elements.\"\"\"\n",
69
- "\n",
70
- " def __init__(self, **kwargs: Any):\n",
71
- " \"\"\"Initialize a LatexTextSplitter.\"\"\"\n",
72
- " separators = [\n",
73
- " # First, try to split along Latex sections\n",
74
- " \"\\chapter{\",\n",
75
- " \"\\section{\",\n",
76
- " \"\\subsection{\",\n",
77
- " \"\\subsubsection{\",\n",
78
- "\n",
79
- " # Now split by environments\n",
80
- " \"\\begin{\"\n",
81
- " # \"\\n\\\\begin{enumerate}\",\n",
82
- " # \"\\n\\\\begin{itemize}\",\n",
83
- " # \"\\n\\\\begin{description}\",\n",
84
- " # \"\\n\\\\begin{list}\",\n",
85
- " # \"\\n\\\\begin{quote}\",\n",
86
- " # \"\\n\\\\begin{quotation}\",\n",
87
- " # \"\\n\\\\begin{verse}\",\n",
88
- " # \"\\n\\\\begin{verbatim}\",\n",
89
- "\n",
90
- " ## Now split by math environments\n",
91
- " # \"\\n\\\\begin{align}\",\n",
92
- " # \"$$\",\n",
93
- " # \"$\",\n",
94
- "\n",
95
- " # Now split by the normal type of lines\n",
96
- " \" \",\n",
97
- " \"\",\n",
98
- " ]\n",
99
- " super().__init__(separators=separators, **kwargs)\n",
100
- "\n",
101
- "\n",
102
- "def json_validator(text: str, openai_key: str, retry: int = 3):\n",
103
- " for _ in range(retry):\n",
104
- " try:\n",
105
- " return json.loads(text)\n",
106
- " except Exception:\n",
107
- " \n",
108
- " try:\n",
109
- " prompt = f\"Modify the following into a valid json format:\\n{text}\"\n",
110
- " prompt_token_length = len(turbo_encoding.encode(prompt))\n",
111
- "\n",
112
- " data = {\n",
113
- " \"model\": \"text-davinci-003\",\n",
114
- " \"prompt\": prompt,\n",
115
- " \"max_tokens\": 4097 - prompt_token_length - 64\n",
116
- " }\n",
117
- " headers = {\n",
118
- " \"Content-Type\": \"application/json\",\n",
119
- " \"Authorization\": f\"Bearer {openai_key}\"\n",
120
- " }\n",
121
- " for _ in range(retry):\n",
122
- " response = requests.post(\n",
123
- " 'https://api.openai.com/v1/completions',\n",
124
- " json=data,\n",
125
- " headers=headers,\n",
126
- " timeout=300\n",
127
- " )\n",
128
- " if response.status_code != 200:\n",
129
- " logging.warning(f'fetch openai chat retry: {response.text}')\n",
130
- " continue\n",
131
- " text = response.json()['choices'][0]['text']\n",
132
- " break\n",
133
- " except:\n",
134
- " return response.json()['error']\n",
135
- " \n",
136
- " return text"
137
- ]
138
- },
139
- {
140
- "cell_type": "code",
141
- "execution_count": 9,
142
- "id": "a34bf526-fb0c-4b5e-8c3d-bab7a13e5fe7",
143
- "metadata": {
144
- "tags": []
145
- },
146
- "outputs": [],
147
- "source": [
148
- "def analyze(latex_whole_document: str, openai_key: str, progress):\n",
149
- " \n",
150
- " logging.info(\"start analysis\")\n",
151
- " \n",
152
- " output_format = \"\"\"\n",
153
- "\n",
154
- " ```json\n",
155
- " [\n",
156
- " \\\\ Potential point for improvement 1\n",
157
- " {{\n",
158
- " \"title\": string \\\\ What this modification is about\n",
159
- " \"thought\": string \\\\ The reason why this should be improved\n",
160
- " \"action\": string \\\\ how to make improvement\n",
161
- " \"original\": string \\\\ the original latex snippet that can be improved\n",
162
- " \"improved\": string \\\\ the improved latex snippet which address your point\n",
163
- " }},\n",
164
- " {{}}\n",
165
- " ]\n",
166
- " ```\n",
167
- " \"\"\"\n",
168
- " \n",
169
- " chunk_size = 1000\n",
170
- " # for _ in range(5):\n",
171
- " # try:\n",
172
- " # latex_splitter = LatexTextSplitter(\n",
173
- " # chunk_size=min(chunk_size, len(latex_whole_document)),\n",
174
- " # chunk_overlap=0,\n",
175
- " # )\n",
176
- " # docs = latex_splitter.create_documents([latex_whole_document])\n",
177
- " # break\n",
178
- " # except:\n",
179
- " # chunk_size // 2\n",
180
- "\n",
181
- " latex_splitter = LatexTextSplitter(\n",
182
- " chunk_size=min(chunk_size, len(latex_whole_document)),\n",
183
- " chunk_overlap=0,\n",
184
- " )\n",
185
- " docs = latex_splitter.create_documents([latex_whole_document])\n",
186
- " \n",
187
- " progress(0.05)\n",
188
- " ideas = []\n",
189
- " for doc in progress.tqdm(docs):\n",
190
- "\n",
191
- " prompt = f\"\"\"\n",
192
- " I'm a computer science student.\n",
193
- " You are my editor.\n",
194
- " Your goal is to improve my paper quality at your best.\n",
195
- " \n",
196
- " \n",
197
- " ```\n",
198
- " {doc.page_content}\n",
199
- " ```\n",
200
- " The above is a segment of my research paper. If the end of the segment is not complete, just ignore it.\n",
201
- " Point out the parts that can be improved.\n",
202
- " Focus on grammar, writing, content, section structure.\n",
203
- " Ignore comments and those that are outside the document environment.\n",
204
- " List out all the points with a latex snippet which is the improved version addressing your point.\n",
205
- " Same paragraph should be only address once.\n",
206
- " Output the response in the following valid json format:\n",
207
- " {output_format}\n",
208
- "\n",
209
- " \"\"\"\n",
210
- " \n",
211
- " idea = fetch_chat(prompt, openai_key)\n",
212
- " if isinstance(idea, list):\n",
213
- " ideas += idea\n",
214
- " break\n",
215
- " else:\n",
216
- " raise gr.Error(idea)\n",
217
- "\n",
218
- " logging.info('complete analysis')\n",
219
- " return ideas\n",
220
- "\n",
221
- "\n",
222
- "def fetch_chat(prompt: str, openai_key: str, retry: int = 3):\n",
223
- " json = {\n",
224
- " \"model\": \"gpt-3.5-turbo-16k\",\n",
225
- " \"messages\": [{\"role\": \"user\", \"content\": prompt}]\n",
226
- " }\n",
227
- " headers = {\n",
228
- " \"Content-Type\": \"application/json\",\n",
229
- " \"Authorization\": f\"Bearer {openai_key}\"\n",
230
- " }\n",
231
- " for _ in range(retry):\n",
232
- " response = requests.post(\n",
233
- " 'https://api.openai.com/v1/chat/completions',\n",
234
- " json=json,\n",
235
- " headers=headers,\n",
236
- " timeout=300\n",
237
- " )\n",
238
- " if response.status_code != 200:\n",
239
- " logging.warning(f'fetch openai chat retry: {response.text}')\n",
240
- " continue\n",
241
- " result = response.json()['choices'][0]['message']['content']\n",
242
- " return json_validator(result, openai_key)\n",
243
- " \n",
244
- " return response.json()[\"error\"]\n",
245
- " \n",
246
- " \n",
247
- "def read_file(f: str):\n",
248
- " if f is None:\n",
249
- " return \"\"\n",
250
- " elif f.name.endswith('pdf'):\n",
251
- " loader = UnstructuredPDFLoader(f.name)\n",
252
- " pages = loader.load_and_split()\n",
253
- " return \"\\n\".join([p.page_content for p in pages])\n",
254
- " elif f.name.endswith('tex'):\n",
255
- " with open(f.name, \"r\") as f:\n",
256
- " return f.read()\n",
257
- " else:\n",
258
- " return \"Only support .tex & .pdf\""
259
- ]
260
- },
261
- {
262
- "cell_type": "code",
263
- "execution_count": 11,
264
- "id": "cec63e87-9741-4596-a3f1-a901830e3771",
265
- "metadata": {
266
- "tags": []
267
- },
268
- "outputs": [
269
- {
270
- "name": "stderr",
271
- "output_type": "stream",
272
- "text": [
273
- "/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/components/button.py:112: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n",
274
- " warnings.warn(\n",
275
- "/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/layouts.py:80: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n",
276
- " warnings.warn(\n",
277
- "/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/components/textbox.py:259: UserWarning: The `style` method is deprecated. Please set these arguments in the constructor instead.\n",
278
- " warnings.warn(\n"
279
- ]
280
- },
281
- {
282
- "name": "stdout",
283
- "output_type": "stream",
284
- "text": [
285
- "Running on local URL: http://0.0.0.0:7653\n",
286
- "Running on public URL: https://73992a9ff20adf33a3.gradio.live\n",
287
- "\n",
288
- "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
289
- ]
290
- },
291
- {
292
- "data": {
293
- "text/html": [
294
- "<div><iframe src=\"https://73992a9ff20adf33a3.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
295
- ],
296
- "text/plain": [
297
- "<IPython.core.display.HTML object>"
298
- ]
299
- },
300
- "metadata": {},
301
- "output_type": "display_data"
302
- },
303
- {
304
- "name": "stderr",
305
- "output_type": "stream",
306
- "text": [
307
- "WARNING:root:fetch openai chat retry: {\n",
308
- " \"error\": {\n",
309
- " \"message\": \"\",\n",
310
- " \"type\": \"invalid_request_error\",\n",
311
- " \"param\": null,\n",
312
- " \"code\": \"invalid_api_key\"\n",
313
- " }\n",
314
- "}\n",
315
- "\n",
316
- "WARNING:root:fetch openai chat retry: {\n",
317
- " \"error\": {\n",
318
- " \"message\": \"\",\n",
319
- " \"type\": \"invalid_request_error\",\n",
320
- " \"param\": null,\n",
321
- " \"code\": \"invalid_api_key\"\n",
322
- " }\n",
323
- "}\n",
324
- "\n",
325
- "WARNING:root:fetch openai chat retry: {\n",
326
- " \"error\": {\n",
327
- " \"message\": \"\",\n",
328
- " \"type\": \"invalid_request_error\",\n",
329
- " \"param\": null,\n",
330
- " \"code\": \"invalid_api_key\"\n",
331
- " }\n",
332
- "}\n",
333
- "\n",
334
- "Traceback (most recent call last):\n",
335
- " File \"/tmp/ipykernel_22031/279099274.py\", line 14, in generate\n",
336
- " idea_list = analyze(txt, openai_key, progress)\n",
337
- " File \"/tmp/ipykernel_22031/3345783910.py\", line 69, in analyze\n",
338
- " raise gr.Error(idea)\n",
339
- "gradio.exceptions.Error: {'message': '', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}\n",
340
- "\n",
341
- "During handling of the above exception, another exception occurred:\n",
342
- "\n",
343
- "Traceback (most recent call last):\n",
344
- " File \"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/routes.py\", line 437, in run_predict\n",
345
- " output = await app.get_blocks().process_api(\n",
346
- " File \"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/blocks.py\", line 1352, in process_api\n",
347
- " result = await self.call_function(\n",
348
- " File \"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/gradio/blocks.py\", line 1077, in call_function\n",
349
- " prediction = await anyio.to_thread.run_sync(\n",
350
- " File \"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/anyio/to_thread.py\", line 31, in run_sync\n",
351
- " return await get_asynclib().run_sync_in_worker_thread(\n",
352
- " File \"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/anyio/_backends/_asyncio.py\", line 937, in run_sync_in_worker_thread\n",
353
- " return await future\n",
354
- " File \"/home/hsiao1229/.local/share/virtualenvs/chatGPT-yp18Rznv/lib/python3.8/site-packages/anyio/_backends/_asyncio.py\", line 867, in run\n",
355
- " result = context.run(func, *args)\n",
356
- " File \"/tmp/ipykernel_22031/279099274.py\", line 37, in generate\n",
357
- " raise gr.Error(str(e))\n",
358
- "gradio.exceptions.Error: \"{'message': '', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}\"\n"
359
- ]
360
- }
361
- ],
362
- "source": [
363
- "idea_list = []\n",
364
- "max_ideas = 20\n",
365
- "\n",
366
- "\n",
367
- "with gr.Blocks() as demo:\n",
368
- " \n",
369
- " def generate(txt: str, openai_key: str, progress=gr.Progress()):\n",
370
- " \n",
371
- " if not openai_key:\n",
372
- " raise gr.Error(\"Please provide openai key !\")\n",
373
- " \n",
374
- " try:\n",
375
- " global idea_list\n",
376
- " idea_list = analyze(txt, openai_key, progress)\n",
377
- " k = min(len(idea_list), max_ideas)\n",
378
- "\n",
379
- " idea_buttons = [\n",
380
- " gr.Button.update(visible=True, value=i['title'])\n",
381
- " for e, i in enumerate(idea_list[:max_ideas])\n",
382
- " ]\n",
383
- " idea_buttons += [\n",
384
- " gr.Button.update(visible=False)\n",
385
- " ]*(max_ideas-len(idea_buttons))\n",
386
- "\n",
387
- " idea_details = [\n",
388
- " gr.Textbox.update(value=\"\", label=\"thought\", visible=True),\n",
389
- " gr.Textbox.update(value=\"\", label=\"action\", visible=True),\n",
390
- " gr.Textbox.update(value=\"\", label=\"original\", visible=True, max_lines=5, lines=5),\n",
391
- " gr.Textbox.update(value=\"\", label=\"improved\", visible=True, max_lines=5, lines=5)\n",
392
- " ]\n",
393
- "\n",
394
- " return [\n",
395
- " gr.Textbox.update(\"Suggestions\", interactive=False, show_label=False),\n",
396
- " gr.Button.update(visible=True, value=\"Analyze\")\n",
397
- " ] + idea_details + idea_buttons\n",
398
- " except Exception as e:\n",
399
- " raise gr.Error(str(e))\n",
400
- "\n",
401
- " def select(name: str):\n",
402
- " global idea_list\n",
403
- " for i in idea_list:\n",
404
- " if i['title'] == name:\n",
405
- " return [\n",
406
- " gr.Textbox.update(value=i[\"thought\"], label=\"thought\", visible=True),\n",
407
- " gr.Textbox.update(value=i[\"action\"], label=\"action\", visible=True),\n",
408
- " gr.Textbox.update(value=i[\"original\"], label=\"original\", visible=True, max_lines=5, lines=5),\n",
409
- " gr.Textbox.update(value=i[\"improved\"], label=\"improved\", visible=True, max_lines=5, lines=5)\n",
410
- " ]\n",
411
- " \n",
412
- " title = gr.Button(\"PaperGPT\", interactive=False).style(size=10)\n",
413
- " key = gr.Textbox(label=\"openai_key\")\n",
414
- " with gr.Row().style(equal_height=True):\n",
415
- " with gr.Column(scale=0.95):\n",
416
- " txt_in = gr.Textbox(label=\"Input\", lines=11, max_lines=11, value=content[2048+2048+256-45:])\n",
417
- " with gr.Column(scale=0.05):\n",
418
- " upload = gr.File(file_count=\"single\", file_types=[\"tex\", \".pdf\"])\n",
419
- " btn = gr.Button(\"Analyze\")\n",
420
- " upload.change(read_file, inputs=upload, outputs=txt_in)\n",
421
- "\n",
422
- " textboxes = []\n",
423
- " sug = gr.Textbox(\"Suggestions\", interactive=False, show_label=False).style(text_align=\"center\")\n",
424
- " with gr.Row():\n",
425
- " with gr.Column(scale=0.4):\n",
426
- " for i in range(max_ideas):\n",
427
- " t = gr.Button(\"\", visible=False)\n",
428
- " textboxes.append(t)\n",
429
- " with gr.Column(scale=0.6):\n",
430
- " thought = gr.Textbox(label=\"thought\", visible=False, interactive=False)\n",
431
- " action = gr.Textbox(label=\"action\", visible=False, interactive=False)\n",
432
- " original = gr.Textbox(label=\"original\", visible=False, max_lines=5, lines=5, interactive=False)\n",
433
- " improved = gr.Textbox(label=\"improved\", visible=False, max_lines=5, lines=5, interactive=False)\n",
434
- "\n",
435
- " btn.click(generate, inputs=[txt_in, key], outputs=[sug, btn, thought, action, original, improved] + textboxes)\n",
436
- " for i in textboxes:\n",
437
- " i.click(select, inputs=[i], outputs=[thought, action, original, improved])\n",
438
- " demo.launch(server_name=\"0.0.0.0\", server_port=7653, share=True, enable_queue=True)"
439
- ]
440
- },
441
- {
442
- "cell_type": "code",
443
- "execution_count": 10,
444
- "id": "8ac8aa92-f7a6-480c-a1b9-2f1c61426846",
445
- "metadata": {
446
- "tags": []
447
- },
448
- "outputs": [
449
- {
450
- "name": "stdout",
451
- "output_type": "stream",
452
- "text": [
453
- "Closing server running on port: 7653\n"
454
- ]
455
- }
456
- ],
457
- "source": [
458
- "demo.close()"
459
- ]
460
- },
461
- {
462
- "cell_type": "code",
463
- "execution_count": null,
464
- "id": "c9a19815-b8de-4a99-9fcf-0b1a0d3981a3",
465
- "metadata": {},
466
- "outputs": [],
467
- "source": []
468
- }
469
- ],
470
- "metadata": {
471
- "kernelspec": {
472
- "display_name": "Python 3 (ipykernel)",
473
- "language": "python",
474
- "name": "python3"
475
- },
476
- "language_info": {
477
- "codemirror_mode": {
478
- "name": "ipython",
479
- "version": 3
480
- },
481
- "file_extension": ".py",
482
- "mimetype": "text/x-python",
483
- "name": "python",
484
- "nbconvert_exporter": "python",
485
- "pygments_lexer": "ipython3",
486
- "version": "3.8.16"
487
- }
488
- },
489
- "nbformat": 4,
490
- "nbformat_minor": 5
491
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/.ipynb_checkpoints/suggest-checkpoint.py DELETED
@@ -1,176 +0,0 @@
1
- import requests
2
- import logging
3
- import json
4
- import tiktoken
5
- import gradio as gr
6
- from typing import Any, List
7
- from langchain.schema import Document
8
- from langchain.document_loaders import UnstructuredPDFLoader
9
- from langchain.text_splitter import RecursiveCharacterTextSplitter
10
-
11
- from utils import json_validator, fetch_chat
12
-
13
-
14
class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Attempts to split the text along Latex-formatted layout elements."""

    def __init__(self, **kwargs: Any):
        """Initialize a LatexTextSplitter.

        Args:
            **kwargs: Forwarded unchanged to ``RecursiveCharacterTextSplitter``
                (e.g. ``chunk_size``, ``chunk_overlap``).
        """
        # Backslashes are written as "\\" so each separator is the literal
        # LaTeX command. In particular "\b" would otherwise be a backspace.
        separators = [
            # First, try to split along Latex sections
            "\\chapter{",
            "\\section{",
            "\\subsection{",
            "\\subsubsection{",

            # Now split by environments
            # (the trailing comma matters: without it this string is
            # implicitly concatenated with the " " separator below)
            "\\begin{",
            # "\n\\begin{enumerate}",
            # "\n\\begin{itemize}",
            # "\n\\begin{description}",
            # "\n\\begin{list}",
            # "\n\\begin{quote}",
            # "\n\\begin{quotation}",
            # "\n\\begin{verse}",
            # "\n\\begin{verbatim}",

            ## Now split by math environments
            # "\n\\begin{align}",
            # "$$",
            # "$",

            # Now split by the normal type of lines
            " ",
            "",
        ]
        super().__init__(separators=separators, **kwargs)
47
-
48
-
49
class Suggest():
    """Generates improvement suggestions for a LaTeX/PDF paper via OpenAI."""

    def __init__(self, max_ideas: int, model: str = "gpt-3.5-turbo"):
        """Store settings and load the bundled sample document.

        Args:
            max_ideas: Upper bound on the number of idea buttons to fill.
            model: Chat model name; it selects the tiktoken encoding and is
                forwarded to ``fetch_chat``.
        """
        self.max_ideas = max_ideas
        self.encoder = tiktoken.encoding_for_model(model)
        self.model = model
        # Path is relative to the current working directory.
        with open("./sample/sample.tex", "r") as f:
            self.sample_content = f.read()

    def split_chunk(self, latex_whole_document: str, chunk_size: int = 2000, retry: int = 5) -> List[Document]:
        """Split the document into chunks, halving the chunk size on failure.

        Args:
            latex_whole_document: Full document text.
            chunk_size: Initial chunk size, capped at the document length.
            retry: Number of attempts; each failed attempt halves the size.

        Raises:
            Exception: If no attempt succeeds.
        """
        chunk_size = min(chunk_size, len(latex_whole_document))
        last_error = None

        for _ in range(retry):
            try:
                latex_splitter = LatexTextSplitter(
                    chunk_size=chunk_size,
                    chunk_overlap=0,
                )
                return latex_splitter.create_documents([latex_whole_document])
            except Exception as exc:
                # Catch Exception, not everything, so KeyboardInterrupt and
                # SystemExit still propagate. Never shrink to size zero.
                last_error = exc
                chunk_size = max(1, chunk_size // 2)

        raise Exception("Latex document split check failed.") from last_error

    def analyze(self, latex_whole_document: str, openai_key: str, progress: gr.Progress):
        """Ask the model for improvement ideas, one request per chunk.

        Stops once at least ``self.max_ideas`` ideas have been collected.

        Args:
            latex_whole_document: Full document text.
            openai_key: OpenAI API key.
            progress: Progress reporter; it is called with a fraction and its
                ``tqdm`` method wraps the chunk loop.

        Returns:
            A list of idea objects as parsed from the model's JSON reply.

        Raises:
            gr.Error: If a reply cannot be turned into a JSON list.
        """
        logging.info("start analysis")
        docs = self.split_chunk(latex_whole_document)
        progress(0.05)

        # NOTE(review): this is a plain string, so the doubled braces reach
        # the model literally as "{{" / "}}" - confirm that is intended.
        output_format = """

        ```json
        [
            \\ Potential point for improvement 1
            {{
                "title": string \\ What this modification is about
                "thought": string \\ The reason why this should be improved
                "action": string \\ how to make improvement
                "original": string \\ the original latex snippet that can be improved
                "improved": string \\ the improved latex snippet which address your point
            }},
            {{}}
        ]
        ```
        """

        ideas = []
        for doc in progress.tqdm(docs):

            prompt = f"""
            I'm a computer science student.
            You are my editor.
            Your goal is to improve my paper quality at your best.


            ```
            {doc.page_content}
            ```
            The above is a segment of my research paper. If the end of the segment is not complete, just ignore it.
            Point out the parts that can be improved.
            Focus on grammar, writing, content, section structure.
            Ignore comments and those that are outside the document environment.
            List out all the points with a latex snippet which is the improved version addressing your point.
            Same paragraph should be only address once.
            Output the response in the following valid json format:
            {output_format}

            """

            idea = fetch_chat(prompt, openai_key, model=self.model)
            idea = json_validator(idea, openai_key)
            if not isinstance(idea, list):
                # Anything but a list is an API error payload or an
                # unrepairable reply; show it to the user as text.
                raise gr.Error(str(idea))
            ideas += idea
            if len(ideas) >= self.max_ideas:
                break

        logging.info('complete analysis')
        return ideas

    def read_file(self, f):
        """Return the text of an uploaded ``.tex`` or ``.pdf`` file.

        Args:
            f: Uploaded file object exposing a ``name`` path attribute, or
                ``None`` when the upload was cleared.

        Returns:
            The file's text, ``""`` for ``None``, or a notice string for
            unsupported extensions.
        """
        if f is None:
            return ""
        elif f.name.endswith('pdf'):
            loader = UnstructuredPDFLoader(f.name)
            pages = loader.load_and_split()
            return "\n".join([p.page_content for p in pages])
        elif f.name.endswith('tex'):
            # Use a separate name so the parameter is not shadowed.
            with open(f.name, "r") as tex_file:
                return tex_file.read()
        else:
            return "Only support .tex & .pdf"

    def generate(self, txt: str, openai_key: str, progress=gr.Progress()):
        """Analyze *txt* and build the Gradio updates for the Suggest tab.

        Returns:
            Updates in this order: suggestion header, analyze button, four
            detail textboxes, then ``self.max_ideas`` idea buttons.

        Raises:
            gr.Error: If the key is missing or the analysis fails.
        """
        if not openai_key:
            raise gr.Error("Please provide openai key !")

        try:
            # Results are published as a module-level global of this module.
            global idea_list
            idea_list = self.analyze(txt, openai_key, progress)

            # Visible buttons for the ideas found, hidden ones for the rest.
            idea_buttons = [
                gr.Button.update(visible=True, value=idea['title'])
                for idea in idea_list[:self.max_ideas]
            ]
            idea_buttons += [
                gr.Button.update(visible=False)
            ] * (self.max_ideas - len(idea_buttons))

            idea_details = [
                gr.Textbox.update(value="", label="thought", visible=True),
                gr.Textbox.update(value="", label="action", visible=True),
                gr.Textbox.update(value="", label="original", visible=True, max_lines=5, lines=5),
                gr.Textbox.update(value="", label="improved", visible=True, max_lines=5, lines=5)
            ]

            return [
                gr.Textbox.update("Suggestions", interactive=False, show_label=False),
                gr.Button.update(visible=True, value="Analyze")
            ] + idea_details + idea_buttons
        except Exception as e:
            raise gr.Error(str(e)) from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/.ipynb_checkpoints/utils-checkpoint.py DELETED
@@ -1,79 +0,0 @@
1
- import json
2
- import requests
3
- import tiktoken
4
- import logging
5
-
6
-
7
def json_validator(
    text: str,
    openai_key: str,
    retry: int = 3,
    model: str = "text-davinci-003"
):
    """Parse *text* as JSON, asking an OpenAI completion model to repair it.

    Args:
        text: Candidate JSON text. A non-string value (such as an error
            payload that is already a Python object) is returned unchanged.
        openai_key: OpenAI API key used for repair requests.
        retry: Number of parse/repair rounds, and of HTTP attempts per round.
        model: Completion model used for the repair prompt.

    Returns:
        The parsed JSON value on success; the API's ``error`` payload if a
        repair request raised after a response was received; otherwise the
        last (still unparsable) text.
    """
    # Nothing to parse or repair: avoid sending a Python repr to the API.
    if not isinstance(text, (str, bytes, bytearray)):
        return text

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_key}"
    }
    # Loaded lazily so valid JSON never pays for (or fails on) the tokenizer.
    encoder = None

    for _ in range(retry):
        try:
            return json.loads(text)
        except ValueError:
            # json.JSONDecodeError is a ValueError; fall through to repair.
            pass

        response = None
        try:
            if encoder is None:
                encoder = tiktoken.encoding_for_model(model)
            prompt = f"Modify the following into a valid json format:\n{text}"
            prompt_token_length = len(encoder.encode(prompt))

            data = {
                "model": model,
                "prompt": prompt,
                # Leave the rest of the 4097-token window (minus a margin)
                # for the completion.
                "max_tokens": 4097 - prompt_token_length - 64
            }
            for _attempt in range(retry):
                response = requests.post(
                    'https://api.openai.com/v1/completions',
                    json=data,
                    headers=headers,
                    timeout=300
                )
                if response.status_code != 200:
                    logging.warning('fetch openai chat retry: %s', response.text)
                    continue
                text = response.json()['choices'][0]['text']
                break
        except Exception:
            # No response to report on (failure happened before any request
            # completed): propagate the real error.
            if response is None:
                raise
            return response.json()['error']

    return text
50
-
51
-
52
def fetch_chat(
    prompt: str,
    openai_key: str,
    retry: int = 3,
    model: str = "gpt-3.5-turbo-16k"
):
    """Send *prompt* as a single user message to the chat completions API.

    Args:
        prompt: Message content.
        openai_key: OpenAI API key.
        retry: Maximum number of HTTP attempts; must be at least 1.
        model: Chat model name.

    Returns:
        The reply text on success, or the ``error`` payload of the last
        response if every attempt returned a non-200 status.

    Raises:
        ValueError: If *retry* is less than 1.
        requests.RequestException: If no attempt produced a response.
    """
    if retry < 1:
        raise ValueError("retry must be at least 1")

    data = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}]
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_key}"
    }
    response = None
    last_error = None
    for _ in range(retry):
        try:
            response = requests.post(
                'https://api.openai.com/v1/chat/completions',
                json=data,
                headers=headers,
                timeout=300
            )
        except requests.RequestException as exc:
            # Timeouts and connection errors count as a failed attempt.
            logging.warning('fetch openai chat retry: %s', exc)
            last_error = exc
            continue
        if response.status_code != 200:
            logging.warning('fetch openai chat retry: %s', response.text)
            continue
        return response.json()['choices'][0]['message']['content']

    if response is None:
        # Every attempt failed before a response arrived.
        raise last_error
    return response.json()["error"]